Skip to content

Commit

Permalink
fix: data races in some situations (#476)
Browse files Browse the repository at this point in the history
  • Loading branch information
canstand authored Jul 16, 2024
1 parent 17169dc commit 9a40c60
Show file tree
Hide file tree
Showing 7 changed files with 186 additions and 49 deletions.
22 changes: 13 additions & 9 deletions browser_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strings"
"sync"

"github.com/playwright-community/playwright-go/internal/safe"
"golang.org/x/exp/slices"
)

Expand All @@ -23,7 +24,7 @@ type browserContextImpl struct {
browser *browserImpl
serviceWorkers []Worker
backgroundPages []Page
bindings map[string]BindingCallFunction
bindings *safe.SyncMap[string, BindingCallFunction]
tracing *tracingImpl
request *apiRequestContextImpl
harRecorders map[string]harRecordingMetadata
Expand Down Expand Up @@ -240,18 +241,21 @@ func (b *browserContextImpl) ExposeBinding(name string, binding BindingCallFunct
needsHandle = handle[0]
}
for _, page := range b.Pages() {
if _, ok := page.(*pageImpl).bindings[name]; ok {
if _, ok := page.(*pageImpl).bindings.Load(name); ok {
return fmt.Errorf("Function '%s' has been already registered in one of the pages", name)
}
}
if _, ok := b.bindings[name]; ok {
if _, ok := b.bindings.Load(name); ok {
return fmt.Errorf("Function '%s' has been already registered", name)
}
b.bindings[name] = binding
_, err := b.channel.Send("exposeBinding", map[string]interface{}{
"name": name,
"needsHandle": needsHandle,
})
if err != nil {
return err
}
b.bindings.Store(name, binding)
return err
}

Expand Down Expand Up @@ -533,11 +537,11 @@ func (b *browserContextImpl) StorageState(paths ...string) (*StorageState, error
}

func (b *browserContextImpl) onBinding(binding *bindingCallImpl) {
function := b.bindings[binding.initializer["name"].(string)]
if function == nil {
function, ok := b.bindings.Load(binding.initializer["name"].(string))
if !ok || function == nil {
return
}
go binding.Call(function)
binding.Call(function)
}

func (b *browserContextImpl) onClose() {
Expand Down Expand Up @@ -740,7 +744,7 @@ func newBrowserContext(parent *channelOwner, objectType string, guid string, ini
pages: make([]Page, 0),
backgroundPages: make([]Page, 0),
routes: make([]*routeHandlerEntry, 0),
bindings: make(map[string]BindingCallFunction),
bindings: safe.NewSyncMap[string, BindingCallFunction](),
harRecorders: make(map[string]harRecordingMetadata),
closed: make(chan struct{}, 1),
harRouters: make([]*harRouter, 0),
Expand All @@ -754,7 +758,7 @@ func newBrowserContext(parent *channelOwner, objectType string, guid string, ini
bt.request = fromChannel(initializer["requestContext"]).(*apiRequestContextImpl)
bt.clock = newClock(bt)
bt.channel.On("bindingCall", func(params map[string]interface{}) {
bt.onBinding(fromChannel(params["binding"]).(*bindingCallImpl))
go bt.onBinding(fromChannel(params["binding"]).(*bindingCallImpl))
})

bt.channel.On("close", bt.onClose)
Expand Down
4 changes: 2 additions & 2 deletions channel_owner.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (c *channelOwner) dispose(reason ...string) {
if c.parent != nil {
delete(c.parent.objects, c.guid)
}
delete(c.connection.objects, c.guid)
c.connection.objects.Delete(c.guid)
if len(reason) > 0 {
c.wasCollected = reason[0] == "gc"
}
Expand Down Expand Up @@ -89,7 +89,7 @@ func (c *channelOwner) createChannelOwner(self interface{}, parent *channelOwner
c.parent.objects[guid] = c
}
if c.connection != nil {
c.connection.objects[guid] = c
c.connection.objects.Store(guid, c)
}
c.channel = newChannel(c, self)
c.eventToSubscriptionMapping = map[string]string{}
Expand Down
22 changes: 12 additions & 10 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/go-stack/stack"
"github.com/playwright-community/playwright-go/internal/safe"
)

var (
Expand All @@ -27,10 +28,10 @@ type result struct {
type connection struct {
transport transport
apiZone sync.Map
objects map[string]*channelOwner
objects *safe.SyncMap[string, *channelOwner]
lastID atomic.Uint32
rootObject *rootChannelOwner
callbacks sync.Map
callbacks *safe.SyncMap[uint32, *protocolCallback]
afterClose func()
onClose func() error
isRemote bool
Expand Down Expand Up @@ -97,21 +98,21 @@ func (c *connection) Dispatch(msg *message) {
method := msg.Method
if msg.ID != 0 {
cb, _ := c.callbacks.LoadAndDelete(uint32(msg.ID))
if cb.(*protocolCallback).noReply {
if cb.noReply {
return
}
if msg.Error != nil {
cb.(*protocolCallback).SetResult(result{
cb.SetResult(result{
Error: parseError(msg.Error.Error),
})
} else {
cb.(*protocolCallback).SetResult(result{
cb.SetResult(result{
Data: c.replaceGuidsWithChannels(msg.Result),
})
}
return
}
object := c.objects[msg.GUID]
object, _ := c.objects.Load(msg.GUID)
if method == "__create__" {
c.createRemoteObject(
object, msg.Params["type"].(string), msg.Params["guid"].(string), msg.Params["initializer"],
Expand All @@ -122,7 +123,7 @@ func (c *connection) Dispatch(msg *message) {
return
}
if method == "__adopt__" {
child, ok := c.objects[msg.Params["guid"].(string)]
child, ok := c.objects.Load(msg.Params["guid"].(string))
if !ok {
return
}
Expand Down Expand Up @@ -205,7 +206,7 @@ func (c *connection) replaceGuidsWithChannels(payload interface{}) interface{} {
if v.Kind() == reflect.Map {
mapV := payload.(map[string]interface{})
if guid, hasGUID := mapV["guid"]; hasGUID {
if channelOwner, ok := c.objects[guid.(string)]; ok {
if channelOwner, ok := c.objects.Load(guid.(string)); ok {
return channelOwner.channel
}
}
Expand Down Expand Up @@ -254,7 +255,7 @@ func (c *connection) sendMessageToServer(object *channelOwner, method string, pa
return nil, fmt.Errorf("could not send message: %w", err)
}

return cb.(*protocolCallback), nil
return cb, nil
}

func (c *connection) setInTracing(isTracing bool) {
Expand Down Expand Up @@ -327,7 +328,8 @@ func serializeCallLocation(caller stack.Call) map[string]interface{} {
func newConnection(transport transport, localUtils ...*localUtilsImpl) *connection {
connection := &connection{
abort: make(chan struct{}, 1),
objects: make(map[string]*channelOwner),
callbacks: safe.NewSyncMap[uint32, *protocolCallback](),
objects: safe.NewSyncMap[string, *channelOwner](),
transport: transport,
isRemote: false,
closedError: &safeValue[error]{},
Expand Down
43 changes: 25 additions & 18 deletions event_emitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type (
hasInit bool
}
eventRegister struct {
sync.Mutex
listeners []listener
}
listener struct {
Expand All @@ -33,18 +34,15 @@ type (

func (e *eventEmitter) Emit(name string, payload ...interface{}) (hasListener bool) {
e.eventsMutex.Lock()
defer e.eventsMutex.Unlock()
e.init()

evt, ok := e.events[name]
if !ok {
e.eventsMutex.Unlock()
return
}

hasListener = evt.count() > 0

evt.callHandlers(payload...)
return
e.eventsMutex.Unlock()
return evt.callHandlers(payload...) > 0
}

func (e *eventEmitter) Once(name string, handler interface{}) {
Expand All @@ -60,10 +58,11 @@ func (e *eventEmitter) RemoveListener(name string, handler interface{}) {
defer e.eventsMutex.Unlock()
e.init()

if _, ok := e.events[name]; !ok {
return
if evt, ok := e.events[name]; ok {
evt.Lock()
defer evt.Unlock()
evt.removeHandler(handler)
}
e.events[name].removeHandler(handler)
}

// ListenerCount count the listeners by name, count all if name is empty
Expand All @@ -90,6 +89,7 @@ func (e *eventEmitter) ListenerCount(name string) int {

func (e *eventEmitter) addEvent(name string, handler interface{}, once bool) {
e.eventsMutex.Lock()
defer e.eventsMutex.Unlock()
e.init()

if _, ok := e.events[name]; !ok {
Expand All @@ -98,7 +98,6 @@ func (e *eventEmitter) addEvent(name string, handler interface{}, once bool) {
}
}
e.events[name].addHandler(handler, once)
e.eventsMutex.Unlock()
}

func (e *eventEmitter) init() {
Expand All @@ -108,23 +107,27 @@ func (e *eventEmitter) init() {
}
}

func (e *eventRegister) addHandler(handler interface{}, once bool) {
e.listeners = append(e.listeners, listener{handler: handler, once: once})
func (er *eventRegister) addHandler(handler interface{}, once bool) {
er.Lock()
defer er.Unlock()
er.listeners = append(er.listeners, listener{handler: handler, once: once})
}

func (e *eventRegister) count() int {
return len(e.listeners)
func (er *eventRegister) count() int {
er.Lock()
defer er.Unlock()
return len(er.listeners)
}

func (e *eventRegister) removeHandler(handler interface{}) {
handlerPtr := reflect.ValueOf(handler).Pointer()

e.listeners = slices.DeleteFunc[[]listener](e.listeners, func(l listener) bool {
e.listeners = slices.DeleteFunc(e.listeners, func(l listener) bool {
return reflect.ValueOf(l.handler).Pointer() == handlerPtr
})
}

func (e *eventRegister) callHandlers(payloads ...interface{}) {
func (er *eventRegister) callHandlers(payloads ...interface{}) int {
payloadV := make([]reflect.Value, 0)

for _, p := range payloads {
Expand All @@ -136,10 +139,14 @@ func (e *eventRegister) callHandlers(payloads ...interface{}) {
handlerV.Call(payloadV[:int(math.Min(float64(handlerV.Type().NumIn()), float64(len(payloadV))))])
}

for _, l := range e.listeners {
er.Lock()
defer er.Unlock()
count := len(er.listeners)
for _, l := range er.listeners {
if l.once {
defer e.removeHandler(l.handler)
defer er.removeHandler(l.handler)
}
handle(l)
}
return count
}
91 changes: 91 additions & 0 deletions internal/safe/map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package safe

import (
"sync"

"golang.org/x/exp/maps"
)

// SyncMap is a thread-safe map
type SyncMap[K comparable, V any] struct {
sync.RWMutex
m map[K]V
}

// NewSyncMap creates a new thread-safe map
func NewSyncMap[K comparable, V any]() *SyncMap[K, V] {
return &SyncMap[K, V]{
m: make(map[K]V),
}
}

func (m *SyncMap[K, V]) Store(k K, v V) {
m.Lock()
defer m.Unlock()
m.m[k] = v
}

func (m *SyncMap[K, V]) Load(k K) (v V, ok bool) {
m.RLock()
defer m.RUnlock()
v, ok = m.m[k]
return
}

// LoadOrStore returns the existing value for the key if present. Otherwise, it stores and returns the given value.
func (m *SyncMap[K, V]) LoadOrStore(k K, v V) (actual V, loaded bool) {
m.Lock()
defer m.Unlock()
actual, loaded = m.m[k]
if loaded {
return
}
m.m[k] = v
return v, false
}

// LoadAndDelete deletes the value for a key, and returns the previous value if any.
func (m *SyncMap[K, V]) LoadAndDelete(k K) (v V, loaded bool) {
m.Lock()
defer m.Unlock()
v, loaded = m.m[k]
if loaded {
delete(m.m, k)
}
return
}

func (m *SyncMap[K, V]) Delete(k K) {
m.Lock()
defer m.Unlock()
delete(m.m, k)
}

func (m *SyncMap[K, V]) Clear() {
m.Lock()
defer m.Unlock()
maps.Clear(m.m)
}

func (m *SyncMap[K, V]) Len() int {
m.RLock()
defer m.RUnlock()
return len(m.m)
}

func (m *SyncMap[K, V]) Clone() map[K]V {
m.RLock()
defer m.RUnlock()
return maps.Clone(m.m)
}

func (m *SyncMap[K, V]) Range(f func(k K, v V) bool) {
m.RLock()
defer m.RUnlock()

for k, v := range m.m {
if !f(k, v) {
break
}
}
}
Loading

0 comments on commit 9a40c60

Please sign in to comment.