Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add WebSocket routing #503

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 92 additions & 52 deletions browser_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type browserContextImpl struct {
options *BrowserNewContextOptions
pages []Page
routes []*routeHandlerEntry
webSocketRoutes []*webSocketRouteHandler
ownedPage Page
browser *browserImpl
serviceWorkers []Worker
Expand All @@ -44,7 +45,7 @@ func (b *browserContextImpl) SetDefaultNavigationTimeout(timeout float64) {

func (b *browserContextImpl) setDefaultNavigationTimeoutImpl(timeout *float64) {
b.timeoutSettings.SetDefaultNavigationTimeout(timeout)
b.channel.SendNoReply("setDefaultNavigationTimeoutNoReply", true, map[string]interface{}{
b.channel.SendNoReplyInternal("setDefaultNavigationTimeoutNoReply", map[string]interface{}{
"timeout": timeout,
})
}
Expand All @@ -55,7 +56,7 @@ func (b *browserContextImpl) SetDefaultTimeout(timeout float64) {

func (b *browserContextImpl) setDefaultTimeoutImpl(timeout *float64) {
b.timeoutSettings.SetDefaultTimeout(timeout)
b.channel.SendNoReply("setDefaultTimeoutNoReply", true, map[string]interface{}{
b.channel.SendNoReplyInternal("setDefaultTimeoutNoReply", map[string]interface{}{
"timeout": timeout,
})
}
Expand Down Expand Up @@ -541,7 +542,7 @@ func (b *browserContextImpl) onBinding(binding *bindingCallImpl) {
if !ok || function == nil {
return
}
binding.Call(function)
go binding.Call(function)
}

func (b *browserContextImpl) onClose() {
Expand Down Expand Up @@ -572,58 +573,56 @@ func (b *browserContextImpl) onPage(page Page) {
}

func (b *browserContextImpl) onRoute(route *routeImpl) {
go func() {
b.Lock()
route.context = b
page := route.Request().(*requestImpl).safePage()
routes := make([]*routeHandlerEntry, len(b.routes))
copy(routes, b.routes)
b.Unlock()
b.Lock()
route.context = b
page := route.Request().(*requestImpl).safePage()
routes := make([]*routeHandlerEntry, len(b.routes))
copy(routes, b.routes)
b.Unlock()

checkInterceptionIfNeeded := func() {
b.Lock()
defer b.Unlock()
if len(b.routes) == 0 {
_, err := b.connection.WrapAPICall(func() (interface{}, error) {
err := b.updateInterceptionPatterns()
return nil, err
}, true)
if err != nil {
logger.Printf("could not update interception patterns: %v\n", err)
}
checkInterceptionIfNeeded := func() {
b.Lock()
defer b.Unlock()
if len(b.routes) == 0 {
_, err := b.connection.WrapAPICall(func() (interface{}, error) {
err := b.updateInterceptionPatterns()
return nil, err
}, true)
if err != nil {
logger.Printf("could not update interception patterns: %v\n", err)
}
}
}

url := route.Request().URL()
for _, handlerEntry := range routes {
// If the page or the context was closed we stall all requests right away.
if (page != nil && page.closeWasCalled) || b.closeWasCalled {
return
}
if !handlerEntry.Matches(url) {
continue
}
if !slices.ContainsFunc(b.routes, func(entry *routeHandlerEntry) bool {
return entry == handlerEntry
}) {
continue
}
if handlerEntry.WillExceed() {
b.routes = slices.DeleteFunc(b.routes, func(rhe *routeHandlerEntry) bool {
return rhe == handlerEntry
})
}
handled := handlerEntry.Handle(route)
checkInterceptionIfNeeded()
yes := <-handled
if yes {
return
}
url := route.Request().URL()
for _, handlerEntry := range routes {
// If the page or the context was closed we stall all requests right away.
if (page != nil && page.closeWasCalled) || b.closeWasCalled {
return
}
// If the page is closed or unrouteAll() was called without waiting and interception disabled,
// the method will throw an error - silence it.
_ = route.internalContinue(true)
}()
if !handlerEntry.Matches(url) {
continue
}
if !slices.ContainsFunc(b.routes, func(entry *routeHandlerEntry) bool {
return entry == handlerEntry
}) {
continue
}
if handlerEntry.WillExceed() {
b.routes = slices.DeleteFunc(b.routes, func(rhe *routeHandlerEntry) bool {
return rhe == handlerEntry
})
}
handled := handlerEntry.Handle(route)
checkInterceptionIfNeeded()
yes := <-handled
if yes {
return
}
}
// If the page is closed or unrouteAll() was called without waiting and interception disabled,
// the method will throw an error - silence it.
_ = route.internalContinue(true)
}

func (b *browserContextImpl) updateInterceptionPatterns() error {
Expand Down Expand Up @@ -726,6 +725,40 @@ func (b *browserContextImpl) OnWebError(fn func(WebError)) {
b.On("weberror", fn)
}

func (b *browserContextImpl) RouteWebSocket(url interface{}, handler func(WebSocketRoute)) error {
b.Lock()
defer b.Unlock()
b.webSocketRoutes = slices.Insert(b.webSocketRoutes, 0, newWebSocketRouteHandler(newURLMatcher(url, b.options.BaseURL), handler))

return b.updateWebSocketInterceptionPatterns()
}

func (b *browserContextImpl) onWebSocketRoute(wr WebSocketRoute) {
b.Lock()
index := slices.IndexFunc(b.webSocketRoutes, func(r *webSocketRouteHandler) bool {
return r.Matches(wr.URL())
})
if index == -1 {
b.Unlock()
_, err := wr.ConnectToServer()
if err != nil {
logger.Println(err)
}
return
}
handler := b.webSocketRoutes[index]
b.Unlock()
handler.Handle(wr)
}

func (b *browserContextImpl) updateWebSocketInterceptionPatterns() error {
patterns := prepareWebSocketRouteHandlerInterceptionPatterns(b.webSocketRoutes)
_, err := b.channel.Send("setWebSocketInterceptionPatterns", map[string]interface{}{
"patterns": patterns,
})
return err
}

func (b *browserContextImpl) effectiveCloseReason() *string {
b.Lock()
defer b.Unlock()
Expand Down Expand Up @@ -758,15 +791,22 @@ func newBrowserContext(parent *channelOwner, objectType string, guid string, ini
bt.request = fromChannel(initializer["requestContext"]).(*apiRequestContextImpl)
bt.clock = newClock(bt)
bt.channel.On("bindingCall", func(params map[string]interface{}) {
go bt.onBinding(fromChannel(params["binding"]).(*bindingCallImpl))
bt.onBinding(fromChannel(params["binding"]).(*bindingCallImpl))
})

bt.channel.On("close", bt.onClose)
bt.channel.On("page", func(payload map[string]interface{}) {
bt.onPage(fromChannel(payload["page"]).(*pageImpl))
})
bt.channel.On("route", func(params map[string]interface{}) {
bt.onRoute(fromChannel(params["route"]).(*routeImpl))
bt.channel.CreateTask(func() {
bt.onRoute(fromChannel(params["route"]).(*routeImpl))
})
})
bt.channel.On("webSocketRoute", func(params map[string]interface{}) {
bt.channel.CreateTask(func() {
bt.onWebSocketRoute(fromChannel(params["webSocketRoute"]).(*webSocketRouteImpl))
})
})
bt.channel.On("backgroundPage", bt.onBackgroundPage)
bt.channel.On("serviceWorker", func(params map[string]interface{}) {
Expand Down
41 changes: 39 additions & 2 deletions channel.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package playwright

import "encoding/json"
import (
"encoding/json"
"fmt"
)

type channel struct {
eventEmitter
Expand All @@ -16,6 +19,23 @@ func (c *channel) MarshalJSON() ([]byte, error) {
})
}

// for catch errors of route handlers etc.
func (c *channel) CreateTask(fn func()) {
go func() {
defer func() {
if e := recover(); e != nil {
err, ok := e.(error)
if ok {
c.connection.err.Set(err)
} else {
c.connection.err.Set(fmt.Errorf("%v", e))
}
}
}()
fn()
}()
}

func (c *channel) Send(method string, options ...interface{}) (interface{}, error) {
return c.connection.WrapAPICall(func() (interface{}, error) {
return c.innerSend(method, options...).GetResultValue()
Expand All @@ -30,16 +50,33 @@ func (c *channel) SendReturnAsDict(method string, options ...interface{}) (map[s
}

func (c *channel) innerSend(method string, options ...interface{}) *protocolCallback {
if err := c.connection.err.Get(); err != nil {
c.connection.err.Set(nil)
pc := newProtocolCallback(false, c.connection.abort)
pc.SetError(err)
return pc
}
params := transformOptions(options...)
return c.connection.sendMessageToServer(c.owner, method, params, false)
}

func (c *channel) SendNoReply(method string, isInternal bool, options ...interface{}) {
// SendNoReply ignores return value and errors
// almost equivalent to `send(...).catch(() => {})`
func (c *channel) SendNoReply(method string, options ...interface{}) {
c.innerSendNoReply(method, c.owner.isInternalType, options...)
}

func (c *channel) SendNoReplyInternal(method string, options ...interface{}) {
c.innerSendNoReply(method, true, options...)
}

func (c *channel) innerSendNoReply(method string, isInternal bool, options ...interface{}) {
params := transformOptions(options...)
_, err := c.connection.WrapAPICall(func() (interface{}, error) {
return c.connection.sendMessageToServer(c.owner, method, params, true).GetResult()
}, isInternal)
if err != nil {
// ignore error actively, log only for debug
logger.Printf("SendNoReply failed: %v\n", err)
}
}
Expand Down
2 changes: 1 addition & 1 deletion channel_owner.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (c *channelOwner) setEventSubscriptionMapping(mapping map[string]string) {
func (c *channelOwner) updateSubscription(event string, enabled bool) {
protocolEvent, ok := c.eventToSubscriptionMapping[event]
if ok {
c.channel.SendNoReply("updateSubscription", true, map[string]interface{}{
c.channel.SendNoReplyInternal("updateSubscription", map[string]interface{}{
"event": protocolEvent,
"enabled": enabled,
})
Expand Down
4 changes: 3 additions & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type connection struct {
tracingCount atomic.Int32
abort chan struct{}
abortOnce sync.Once
err *safeValue[error] // for event listener error
closedError *safeValue[error]
}

Expand Down Expand Up @@ -301,6 +302,7 @@ func newConnection(transport transport, localUtils ...*localUtilsImpl) *connecti
objects: safe.NewSyncMap[string, *channelOwner](),
transport: transport,
isRemote: false,
err: &safeValue[error]{},
closedError: &safeValue[error]{},
}
if len(localUtils) > 0 {
Expand Down Expand Up @@ -393,7 +395,7 @@ func newProtocolCallback(noReply bool, abort <-chan struct{}) *protocolCallback
}
}
return &protocolCallback{
done: make(chan struct{}),
done: make(chan struct{}, 1),
abort: abort,
}
}
Loading