diff --git a/browser_context.go b/browser_context.go index 58c7c5b..2563ecd 100644 --- a/browser_context.go +++ b/browser_context.go @@ -20,6 +20,7 @@ type browserContextImpl struct { options *BrowserNewContextOptions pages []Page routes []*routeHandlerEntry + webSocketRoutes []*webSocketRouteHandler ownedPage Page browser *browserImpl serviceWorkers []Worker @@ -44,7 +45,7 @@ func (b *browserContextImpl) SetDefaultNavigationTimeout(timeout float64) { func (b *browserContextImpl) setDefaultNavigationTimeoutImpl(timeout *float64) { b.timeoutSettings.SetDefaultNavigationTimeout(timeout) - b.channel.SendNoReply("setDefaultNavigationTimeoutNoReply", true, map[string]interface{}{ + b.channel.SendNoReplyInternal("setDefaultNavigationTimeoutNoReply", map[string]interface{}{ "timeout": timeout, }) } @@ -55,7 +56,7 @@ func (b *browserContextImpl) SetDefaultTimeout(timeout float64) { func (b *browserContextImpl) setDefaultTimeoutImpl(timeout *float64) { b.timeoutSettings.SetDefaultTimeout(timeout) - b.channel.SendNoReply("setDefaultTimeoutNoReply", true, map[string]interface{}{ + b.channel.SendNoReplyInternal("setDefaultTimeoutNoReply", map[string]interface{}{ "timeout": timeout, }) } @@ -541,7 +542,7 @@ func (b *browserContextImpl) onBinding(binding *bindingCallImpl) { if !ok || function == nil { return } - binding.Call(function) + go binding.Call(function) } func (b *browserContextImpl) onClose() { @@ -572,58 +573,56 @@ func (b *browserContextImpl) onPage(page Page) { } func (b *browserContextImpl) onRoute(route *routeImpl) { - go func() { - b.Lock() - route.context = b - page := route.Request().(*requestImpl).safePage() - routes := make([]*routeHandlerEntry, len(b.routes)) - copy(routes, b.routes) - b.Unlock() + b.Lock() + route.context = b + page := route.Request().(*requestImpl).safePage() + routes := make([]*routeHandlerEntry, len(b.routes)) + copy(routes, b.routes) + b.Unlock() - checkInterceptionIfNeeded := func() { - b.Lock() - defer b.Unlock() - if len(b.routes) == 0 { - _, err := b.connection.WrapAPICall(func() (interface{}, error) { - err := b.updateInterceptionPatterns() - return nil, err - }, true) - if err != nil { - logger.Printf("could not update interception patterns: %v\n", err) - } + checkInterceptionIfNeeded := func() { + b.Lock() + defer b.Unlock() + if len(b.routes) == 0 { + _, err := b.connection.WrapAPICall(func() (interface{}, error) { + err := b.updateInterceptionPatterns() + return nil, err + }, true) + if err != nil { + logger.Printf("could not update interception patterns: %v\n", err) } } + } - url := route.Request().URL() - for _, handlerEntry := range routes { - // If the page or the context was closed we stall all requests right away. - if (page != nil && page.closeWasCalled) || b.closeWasCalled { - return - } - if !handlerEntry.Matches(url) { - continue - } - if !slices.ContainsFunc(b.routes, func(entry *routeHandlerEntry) bool { - return entry == handlerEntry - }) { - continue - } - if handlerEntry.WillExceed() { - b.routes = slices.DeleteFunc(b.routes, func(rhe *routeHandlerEntry) bool { - return rhe == handlerEntry - }) - } - handled := handlerEntry.Handle(route) - checkInterceptionIfNeeded() - yes := <-handled - if yes { - return - } + url := route.Request().URL() + for _, handlerEntry := range routes { + // If the page or the context was closed we stall all requests right away. + if (page != nil && page.closeWasCalled) || b.closeWasCalled { + return } - // If the page is closed or unrouteAll() was called without waiting and interception disabled, - // the method will throw an error - silence it. - _ = route.internalContinue(true) - }() + if !handlerEntry.Matches(url) { + continue + } + if !slices.ContainsFunc(b.routes, func(entry *routeHandlerEntry) bool { + return entry == handlerEntry + }) { + continue + } + if handlerEntry.WillExceed() { + b.routes = slices.DeleteFunc(b.routes, func(rhe *routeHandlerEntry) bool { + return rhe == handlerEntry + }) + } + handled := handlerEntry.Handle(route) + checkInterceptionIfNeeded() + yes := <-handled + if yes { + return + } + } + // If the page is closed or unrouteAll() was called without waiting and interception disabled, + // the method will throw an error - silence it. + _ = route.internalContinue(true) } func (b *browserContextImpl) updateInterceptionPatterns() error { @@ -726,6 +725,40 @@ func (b *browserContextImpl) OnWebError(fn func(WebError)) { b.On("weberror", fn) } +func (b *browserContextImpl) RouteWebSocket(url interface{}, handler func(WebSocketRoute)) error { + b.Lock() + defer b.Unlock() + b.webSocketRoutes = slices.Insert(b.webSocketRoutes, 0, newWebSocketRouteHandler(newURLMatcher(url, b.options.BaseURL), handler)) + + return b.updateWebSocketInterceptionPatterns() +} + +func (b *browserContextImpl) onWebSocketRoute(wr WebSocketRoute) { + b.Lock() + index := slices.IndexFunc(b.webSocketRoutes, func(r *webSocketRouteHandler) bool { + return r.Matches(wr.URL()) + }) + if index == -1 { + b.Unlock() + _, err := wr.ConnectToServer() + if err != nil { + logger.Println(err) + } + return + } + handler := b.webSocketRoutes[index] + b.Unlock() + handler.Handle(wr) +} + +func (b *browserContextImpl) updateWebSocketInterceptionPatterns() error { + patterns := prepareWebSocketRouteHandlerInterceptionPatterns(b.webSocketRoutes) + _, err := b.channel.Send("setWebSocketInterceptionPatterns", map[string]interface{}{ + "patterns": patterns, + }) + return err +} + func (b *browserContextImpl) effectiveCloseReason() *string { b.Lock() defer b.Unlock() @@ -758,7 +791,7 @@ func newBrowserContext(parent *channelOwner, objectType string, guid string, ini bt.request = fromChannel(initializer["requestContext"]).(*apiRequestContextImpl) bt.clock = newClock(bt) bt.channel.On("bindingCall", func(params map[string]interface{}) { - go bt.onBinding(fromChannel(params["binding"]).(*bindingCallImpl)) + bt.onBinding(fromChannel(params["binding"]).(*bindingCallImpl)) }) bt.channel.On("close", bt.onClose) @@ -766,7 +799,14 @@ func newBrowserContext(parent *channelOwner, objectType string, guid string, ini bt.onPage(fromChannel(payload["page"]).(*pageImpl)) }) bt.channel.On("route", func(params map[string]interface{}) { - bt.onRoute(fromChannel(params["route"]).(*routeImpl)) + bt.channel.CreateTask(func() { + bt.onRoute(fromChannel(params["route"]).(*routeImpl)) + }) + }) + bt.channel.On("webSocketRoute", func(params map[string]interface{}) { + bt.channel.CreateTask(func() { + bt.onWebSocketRoute(fromChannel(params["webSocketRoute"]).(*webSocketRouteImpl)) + }) }) bt.channel.On("backgroundPage", bt.onBackgroundPage) bt.channel.On("serviceWorker", func(params map[string]interface{}) { diff --git a/channel.go b/channel.go index 9f3c26b..83d2e89 100644 --- a/channel.go +++ b/channel.go @@ -1,6 +1,9 @@ package playwright -import "encoding/json" +import ( + "encoding/json" + "fmt" +) type channel struct { eventEmitter @@ -16,6 +19,23 @@ func (c *channel) MarshalJSON() ([]byte, error) { }) } +// for catch errors of route handlers etc. +func (c *channel) CreateTask(fn func()) { + go func() { + defer func() { + if e := recover(); e != nil { + err, ok := e.(error) + if ok { + c.connection.err.Set(err) + } else { + c.connection.err.Set(fmt.Errorf("%v", e)) + } + } + }() + fn() + }() +} + func (c *channel) Send(method string, options ...interface{}) (interface{}, error) { return c.connection.WrapAPICall(func() (interface{}, error) { return c.innerSend(method, options...).GetResultValue() @@ -30,16 +50,33 @@ func (c *channel) SendReturnAsDict(method string, options ...interface{}) (map[s } func (c *channel) innerSend(method string, options ...interface{}) *protocolCallback { + if err := c.connection.err.Get(); err != nil { + c.connection.err.Set(nil) + pc := newProtocolCallback(false, c.connection.abort) + pc.SetError(err) + return pc + } params := transformOptions(options...) return c.connection.sendMessageToServer(c.owner, method, params, false) } -func (c *channel) SendNoReply(method string, isInternal bool, options ...interface{}) { +// SendNoReply ignores return value and errors +// almost equivalent to `send(...).catch(() => {})` +func (c *channel) SendNoReply(method string, options ...interface{}) { + c.innerSendNoReply(method, c.owner.isInternalType, options...) +} + +func (c *channel) SendNoReplyInternal(method string, options ...interface{}) { + c.innerSendNoReply(method, true, options...) +} + +func (c *channel) innerSendNoReply(method string, isInternal bool, options ...interface{}) { params := transformOptions(options...) _, err := c.connection.WrapAPICall(func() (interface{}, error) { return c.connection.sendMessageToServer(c.owner, method, params, true).GetResult() }, isInternal) if err != nil { + // ignore error actively, log only for debug logger.Printf("SendNoReply failed: %v\n", err) } } diff --git a/channel_owner.go b/channel_owner.go index 4e3a1f0..5159eb2 100644 --- a/channel_owner.go +++ b/channel_owner.go @@ -49,7 +49,7 @@ func (c *channelOwner) setEventSubscriptionMapping(mapping map[string]string) { func (c *channelOwner) updateSubscription(event string, enabled bool) { protocolEvent, ok := c.eventToSubscriptionMapping[event] if ok { - c.channel.SendNoReply("updateSubscription", true, map[string]interface{}{ + c.channel.SendNoReplyInternal("updateSubscription", map[string]interface{}{ "event": protocolEvent, "enabled": enabled, }) diff --git a/connection.go b/connection.go index ff2b1f6..ba1e365 100644 --- a/connection.go +++ b/connection.go @@ -34,6 +34,7 @@ type connection struct { tracingCount atomic.Int32 abort chan struct{} abortOnce sync.Once + err *safeValue[error] // for event listener error closedError *safeValue[error] } @@ -301,6 +302,7 @@ func newConnection(transport transport, localUtils ...*localUtilsImpl) *connecti objects: safe.NewSyncMap[string, *channelOwner](), transport: transport, isRemote: false, + err: &safeValue[error]{}, closedError: &safeValue[error]{}, } if len(localUtils) > 0 { @@ -393,7 +395,7 @@ func newProtocolCallback(noReply bool, abort <-chan struct{}) *protocolCallback } } return &protocolCallback{ - done: make(chan struct{}), + done: make(chan struct{}, 1), abort: abort, } } diff --git a/generated-interfaces.go b/generated-interfaces.go index 17b8071..28a9c83 100644 --- a/generated-interfaces.go +++ b/generated-interfaces.go @@ -398,6 +398,15 @@ type BrowserContext interface { // [this]: https://github.com/microsoft/playwright/issues/1090 RouteFromHAR(har string, options ...BrowserContextRouteFromHAROptions) error + // This method allows to modify websocket connections that are made by any page in the browser context. + // Note that only `WebSocket`s created after this method was called will be routed. It is recommended to call this + // method before creating any pages. + // + // 1. url: Only WebSockets with the url matching this pattern will be routed. A string pattern can be relative to the + // “[object Object]” context option. + // 2. handler: Handler function to route the WebSocket. + RouteWebSocket(url interface{}, handler func(WebSocketRoute)) error + // **NOTE** Service workers are only supported on Chromium-based browsers. // All existing service workers in the context. ServiceWorkers() []Worker @@ -3746,6 +3755,15 @@ type Page interface { // [this]: https://github.com/microsoft/playwright/issues/1090 RouteFromHAR(har string, options ...PageRouteFromHAROptions) error + // This method allows to modify websocket connections that are made by the page. + // Note that only `WebSocket`s created after this method was called will be routed. It is recommended to call this + // method before navigating the page. + // + // 1. url: Only WebSockets with the url matching this pattern will be routed. A string pattern can be relative to the + // “[object Object]” context option. + // 2. handler: Handler function to route the WebSocket. + RouteWebSocket(url interface{}, handler func(WebSocketRoute)) error + // Returns the buffer with the captured screenshot. Screenshot(options ...PageScreenshotOptions) ([]byte, error) @@ -4430,6 +4448,82 @@ type WebSocket interface { WaitForEvent(event string, options ...WebSocketWaitForEventOptions) (interface{}, error) } +// Whenever a [`WebSocket`] route is set up with +// [Page.RouteWebSocket] or [BrowserContext.RouteWebSocket], the `WebSocketRoute` object allows to handle the +// WebSocket, like an actual server would do. +// **Mocking** +// By default, the routed WebSocket will not connect to the server. This way, you can mock entire communcation over +// the WebSocket. Here is an example that responds to a `"request"` with a `"response"`. +// Since we do not call [WebSocketRoute.ConnectToServer] inside the WebSocket route handler, Playwright assumes that +// WebSocket will be mocked, and opens the WebSocket inside the page automatically. +// Here is another example that handles JSON messages: +// **Intercepting** +// Alternatively, you may want to connect to the actual server, but intercept messages in-between and modify or block +// them. Calling [WebSocketRoute.ConnectToServer] returns a server-side `WebSocketRoute` instance that you can send +// messages to, or handle incoming messages. +// Below is an example that modifies some messages sent by the page to the server. Messages sent from the server to +// the page are left intact, relying on the default forwarding. +// After connecting to the server, all **messages are forwarded** between the page and the server by default. +// However, if you call [WebSocketRoute.OnMessage] on the original route, messages from the page to the server **will +// not be forwarded** anymore, but should instead be handled by the “[object Object]”. +// Similarly, calling [WebSocketRoute.OnMessage] on the server-side WebSocket will **stop forwarding messages** from +// the server to the page, and “[object Object]” should take care of them. +// The following example blocks some messages in both directions. Since it calls [WebSocketRoute.OnMessage] in both +// directions, there is no automatic forwarding at all. +// +// [`WebSocket`]: https://developer.mozilla.org/en-US/docs/Web/API/WebSocket +type WebSocketRoute interface { + // Closes one side of the WebSocket connection. + Close(options ...WebSocketRouteCloseOptions) + + // By default, routed WebSocket does not connect to the server, so you can mock entire WebSocket communication. This + // method connects to the actual WebSocket server, and returns the server-side [WebSocketRoute] instance, giving the + // ability to send and receive messages from the server. + // Once connected to the server: + // - Messages received from the server will be **automatically forwarded** to the WebSocket in the page, unless + // [WebSocketRoute.OnMessage] is called on the server-side `WebSocketRoute`. + // - Messages sent by the [`WebSocket.send()`] call + // in the page will be **automatically forwarded** to the server, unless [WebSocketRoute.OnMessage] is called on + // the original `WebSocketRoute`. + // See examples at the top for more details. + // + // [`WebSocket.send()`]: https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/send + ConnectToServer() (WebSocketRoute, error) + + // Allows to handle [`WebSocket.close`]. + // By default, closing one side of the connection, either in the page or on the server, will close the other side. + // However, when [WebSocketRoute.OnClose] handler is set up, the default forwarding of closure is disabled, and + // handler should take care of it. + // + // handler: Function that will handle WebSocket closure. Received an optional + // [close code](https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/close#code) and an optional + // [close reason](https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/close#reason). + // + // [`WebSocket.close`]: https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/close + OnClose(handler func(*int, *string)) + + // This method allows to handle messages that are sent by the WebSocket, either from the page or from the server. + // When called on the original WebSocket route, this method handles messages sent from the page. You can handle this + // messages by responding to them with [WebSocketRoute.Send], forwarding them to the server-side connection returned + // by [WebSocketRoute.ConnectToServer] or do something else. + // Once this method is called, messages are not automatically forwarded to the server or to the page - you should do + // that manually by calling [WebSocketRoute.Send]. See examples at the top for more details. + // Calling this method again will override the handler with a new one. + // + // handler: Function that will handle messages. + OnMessage(handler func(interface{})) + + // Sends a message to the WebSocket. When called on the original WebSocket, sends the message to the page. When called + // on the result of [WebSocketRoute.ConnectToServer], sends the message to the server. See examples at the top for + // more details. + // + // message: Message to send. + Send(message interface{}) + + // URL of the WebSocket created in the page. + URL() string +} + // The Worker class represents a [WebWorker]. // `worker` event is emitted on the page object to signal a worker creation. `close` event is emitted on the worker // object when the worker is gone. diff --git a/generated-structs.go b/generated-structs.go index ec72d07..d5d89e1 100644 --- a/generated-structs.go +++ b/generated-structs.go @@ -4189,6 +4189,17 @@ type WebSocketWaitForEventOptions struct { Timeout *float64 `json:"timeout"` } +type WebSocketRouteCloseOptions struct { + // Optional [close code]. + // + // [close code]: https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/close#code + Code *int `json:"code"` + // Optional [close reason]. + // + // [close reason]: https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/close#reason + Reason *string `json:"reason"` +} + type ClientCertificate struct { // Exact origin that the certificate is valid for. Origin includes `https` protocol, a hostname and optionally a port. Origin string `json:"origin"` diff --git a/har_router.go b/har_router.go index efbe5b7..316da57 100644 --- a/har_router.go +++ b/har_router.go @@ -45,7 +45,7 @@ func (r *harRouter) addPageRoute(page Page) error { } func (r *harRouter) dispose() { - r.localUtils.HarClose(r.harId) + go r.localUtils.HarClose(r.harId) } func (r *harRouter) handle(route Route) error { diff --git a/local_utils.go b/local_utils.go index 2e6749e..395c09a 100644 --- a/local_utils.go +++ b/local_utils.go @@ -100,12 +100,13 @@ func (l *localUtilsImpl) HarLookup(option harLookupOptions) (*harLookupResult, e return &result, err } -func (l *localUtilsImpl) HarClose(harId string) { - l.channel.SendNoReply("harClose", true, []map[string]interface{}{ +func (l *localUtilsImpl) HarClose(harId string) error { + _, err := l.channel.Send("harClose", []map[string]interface{}{ { "harId": harId, }, }) + return err } func (l *localUtilsImpl) HarUnzip(zipFile, harFile string) error { @@ -139,7 +140,7 @@ func (l *localUtilsImpl) TraceDiscarded(stacksId string) error { } func (l *localUtilsImpl) AddStackToTracingNoReply(id uint32, stack []map[string]interface{}) { - l.channel.SendNoReply("addStackToTracingNoReply", true, map[string]interface{}{ + l.channel.SendNoReply("addStackToTracingNoReply", map[string]interface{}{ "callData": map[string]interface{}{ "id": id, "stack": stack, diff --git a/objectFactory.go b/objectFactory.go index fc5ef04..9474c54 100644 --- a/objectFactory.go +++ b/objectFactory.go @@ -62,6 +62,8 @@ func createObjectFactory(parent *channelOwner, objectType string, guid string, i return newTracing(parent, objectType, guid, initializer) case "WebSocket": return newWebsocket(parent, objectType, guid, initializer) + case "WebSocketRoute": + return newWebSocketRoute(parent, objectType, guid, initializer) case "Worker": return newWorker(parent, objectType, guid, initializer) case "WritableStream": diff --git a/page.go b/page.go index 692d647..5b1ed12 100644 --- a/page.go +++ b/page.go @@ -25,6 +25,7 @@ type pageImpl struct { workers []Worker mainFrame Frame routes []*routeHandlerEntry + webSocketRoutes []*webSocketRouteHandler viewportSize *Size ownedContext BrowserContext bindings *safe.SyncMap[string, BindingCallFunction] @@ -83,29 +84,27 @@ func (p *pageImpl) onLocatorHandlerTriggered(uid float64) { remove = Bool(true) } } - go func() { - defer func() { - if remove != nil && *remove { - delete(p.locatorHandlers, uid) - } - _, _ = p.connection.WrapAPICall(func() (interface{}, error) { - p.channel.SendNoReply("resolveLocatorHandlerNoReply", true, map[string]any{ - "uid": uid, - "remove": remove, - }) - return nil, nil - }, true) - }() - - handler.handler(handler.locator) + defer func() { + if remove != nil && *remove { + delete(p.locatorHandlers, uid) + } + _, _ = p.connection.WrapAPICall(func() (interface{}, error) { + _, err := p.channel.Send("resolveLocatorHandlerNoReply", map[string]any{ + "uid": uid, + "remove": remove, + }) + return nil, err + }, true) }() + + handler.handler(handler.locator) } func (p *pageImpl) RemoveLocatorHandler(locator Locator) error { for uid := range p.locatorHandlers { if p.locatorHandlers[uid].locator.equals(locator) { delete(p.locatorHandlers, uid) - p.channel.SendNoReply("unregisterLocatorHandler", false, map[string]any{ + p.channel.SendNoReply("unregisterLocatorHandler", map[string]any{ "uid": uid, }) return nil @@ -194,14 +193,14 @@ func (p *pageImpl) Frames() []Frame { func (p *pageImpl) SetDefaultNavigationTimeout(timeout float64) { p.timeoutSettings.SetDefaultNavigationTimeout(&timeout) - p.channel.SendNoReply("setDefaultNavigationTimeoutNoReply", true, map[string]interface{}{ + p.channel.SendNoReplyInternal("setDefaultNavigationTimeoutNoReply", map[string]interface{}{ "timeout": timeout, }) } func (p *pageImpl) SetDefaultTimeout(timeout float64) { p.timeoutSettings.SetDefaultTimeout(&timeout) - p.channel.SendNoReply("setDefaultTimeoutNoReply", true, map[string]interface{}{ + p.channel.SendNoReplyInternal("setDefaultTimeoutNoReply", map[string]interface{}{ "timeout": timeout, }) } @@ -800,7 +799,7 @@ func newPage(parent *channelOwner, objectType string, guid string, initializer m bt.keyboard = newKeyboard(bt.channel) bt.touchscreen = newTouchscreen(bt.channel) bt.channel.On("bindingCall", func(params map[string]interface{}) { - go bt.onBinding(fromChannel(params["binding"]).(*bindingCallImpl)) + bt.onBinding(fromChannel(params["binding"]).(*bindingCallImpl)) }) bt.channel.On("close", bt.onClose) bt.channel.On("crash", func() { @@ -819,7 +818,9 @@ func newPage(parent *channelOwner, objectType string, guid string, initializer m bt.onFrameDetached(fromChannel(ev["frame"]).(*frameImpl)) }) bt.channel.On("locatorHandlerTriggered", func(ev map[string]interface{}) { - bt.onLocatorHandlerTriggered(ev["uid"].(float64)) + bt.channel.CreateTask(func() { + bt.onLocatorHandlerTriggered(ev["uid"].(float64)) + }) }) bt.channel.On( "load", func(ev map[string]interface{}) { @@ -830,7 +831,9 @@ func newPage(parent *channelOwner, objectType string, guid string, initializer m bt.Emit("popup", fromChannel(ev["page"]).(*pageImpl)) }) bt.channel.On("route", func(ev map[string]interface{}) { - bt.onRoute(fromChannel(ev["route"]).(*routeImpl)) + bt.channel.CreateTask(func() { + bt.onRoute(fromChannel(ev["route"]).(*routeImpl)) + }) }) bt.channel.On("download", func(ev map[string]interface{}) { url := ev["url"].(string) @@ -845,6 +848,11 @@ func newPage(parent *channelOwner, objectType string, guid string, initializer m bt.channel.On("webSocket", func(ev map[string]interface{}) { bt.Emit("websocket", fromChannel(ev["webSocket"]).(*webSocketImpl)) }) + bt.channel.On("webSocketRoute", func(ev map[string]interface{}) { + bt.channel.CreateTask(func() { + bt.onWebSocketRoute(fromChannel(ev["webSocketRoute"]).(*webSocketRouteImpl)) + }) + }) bt.channel.On("worker", func(ev map[string]interface{}) { bt.onWorker(fromChannel(ev["worker"]).(*workerImpl)) @@ -887,7 +895,7 @@ func (p *pageImpl) onBinding(binding *bindingCallImpl) { if !ok || function == nil { return } - binding.Call(function) + go binding.Call(function) } func (p *pageImpl) onFrameAttached(frame *frameImpl) { @@ -911,55 +919,53 @@ func (p *pageImpl) onFrameDetached(frame *frameImpl) { } func (p *pageImpl) onRoute(route *routeImpl) { - go func() { - p.Lock() - route.context = p.browserContext - routes := make([]*routeHandlerEntry, len(p.routes)) - copy(routes, p.routes) - p.Unlock() + p.Lock() + route.context = p.browserContext + routes := make([]*routeHandlerEntry, len(p.routes)) + copy(routes, p.routes) + p.Unlock() - checkInterceptionIfNeeded := func() { - p.Lock() - defer p.Unlock() - if len(p.routes) == 0 { - _, err := p.connection.WrapAPICall(func() (interface{}, error) { - err := p.updateInterceptionPatterns() - return nil, err - }, true) - if err != nil { - logger.Printf("could not update interception patterns: %v\n", err) - } + checkInterceptionIfNeeded := func() { + p.Lock() + defer p.Unlock() + if len(p.routes) == 0 { + _, err := p.connection.WrapAPICall(func() (interface{}, error) { + err := p.updateInterceptionPatterns() + return nil, err + }, true) + if err != nil { + logger.Printf("could not update interception patterns: %v\n", err) } } + } - url := route.Request().URL() - for _, handlerEntry := range routes { - // If the page was closed we stall all requests right away. - if p.closeWasCalled || p.browserContext.closeWasCalled { - return - } - if !handlerEntry.Matches(url) { - continue - } - if !slices.ContainsFunc(p.routes, func(entry *routeHandlerEntry) bool { - return entry == handlerEntry - }) { - continue - } - if handlerEntry.WillExceed() { - p.routes = slices.DeleteFunc(p.routes, func(rhe *routeHandlerEntry) bool { - return rhe == handlerEntry - }) - } - handled := handlerEntry.Handle(route) - checkInterceptionIfNeeded() + url := route.Request().URL() + for _, handlerEntry := range routes { + // If the page was closed we stall all requests right away. + if p.closeWasCalled || p.browserContext.closeWasCalled { + return + } + if !handlerEntry.Matches(url) { + continue + } + if !slices.ContainsFunc(p.routes, func(entry *routeHandlerEntry) bool { + return entry == handlerEntry + }) { + continue + } + if handlerEntry.WillExceed() { + p.routes = slices.DeleteFunc(p.routes, func(rhe *routeHandlerEntry) bool { + return rhe == handlerEntry + }) + } + handled := handlerEntry.Handle(route) + checkInterceptionIfNeeded() - if <-handled { - return - } + if <-handled { + return } - p.browserContext.onRoute(route) - }() + } + p.browserContext.onRoute(route) } func (p *pageImpl) updateInterceptionPatterns() error { @@ -1345,3 +1351,34 @@ func (p *pageImpl) RequestGC() error { _, err := p.channel.Send("requestGC") return err } + +func (p *pageImpl) RouteWebSocket(url interface{}, handler func(WebSocketRoute)) error { + p.Lock() + defer p.Unlock() + p.webSocketRoutes = slices.Insert(p.webSocketRoutes, 0, newWebSocketRouteHandler(newURLMatcher(url, p.browserContext.options.BaseURL), handler)) + + return p.updateWebSocketInterceptionPatterns() +} + +func (p *pageImpl) onWebSocketRoute(wr WebSocketRoute) { + p.Lock() + index := slices.IndexFunc(p.webSocketRoutes, func(r *webSocketRouteHandler) bool { + return r.Matches(wr.URL()) + }) + if index == -1 { + p.Unlock() + p.browserContext.onWebSocketRoute(wr) + return + } + handler := p.webSocketRoutes[index] + p.Unlock() + handler.Handle(wr) +} + +func (p *pageImpl) updateWebSocketInterceptionPatterns() error { + patterns := prepareWebSocketRouteHandlerInterceptionPatterns(p.webSocketRoutes) + _, err := p.channel.Send("setWebSocketInterceptionPatterns", map[string]interface{}{ + "patterns": patterns, + }) + return err +} diff --git a/patches/main.patch b/patches/main.patch index 5051c9e..f31aaf1 100644 --- a/patches/main.patch +++ b/patches/main.patch @@ -322,7 +322,7 @@ index 59cf4c99c..95f686ffa 100644 :::note diff --git a/docs/src/api/class-browsercontext.md b/docs/src/api/class-browsercontext.md -index b504bf457..ffda8197f 100644 +index b504bf457..453add560 100644 --- a/docs/src/api/class-browsercontext.md +++ b/docs/src/api/class-browsercontext.md @@ -417,7 +417,7 @@ The order of evaluation of multiple scripts installed via [`method: BrowserConte @@ -343,15 +343,16 @@ index b504bf457..ffda8197f 100644 - `handler` <[function]\([Route]\)> handler function to route the request. -@@ -1269,6 +1269,7 @@ Optional setting to control resource content management. If `attach` is specifie +@@ -1353,7 +1353,7 @@ Handler function to route the WebSocket. - ## async method: BrowserContext.routeWebSocket + ### param: BrowserContext.routeWebSocket.handler * since: v1.48 -+* langs: js, python, csharp, java - - This method allows to modify websocket connections that are made by any page in the browser context. +-* langs: csharp, java ++* langs: csharp, java, go + - `handler` <[function]\([WebSocketRoute]\)> -@@ -1361,7 +1362,7 @@ Handler function to route the WebSocket. + Handler function to route the WebSocket. +@@ -1361,7 +1361,7 @@ Handler function to route the WebSocket. ## method: BrowserContext.serviceWorkers * since: v1.11 @@ -360,7 +361,7 @@ index b504bf457..ffda8197f 100644 - returns: <[Array]<[Worker]>> :::note -@@ -1548,6 +1549,13 @@ A glob pattern, regex pattern or predicate receiving [URL] used to register a ro +@@ -1548,6 +1548,13 @@ A glob pattern, regex pattern or predicate receiving [URL] used to register a ro Optional handler function used to register a routing with [`method: BrowserContext.route`]. @@ -374,7 +375,7 @@ index b504bf457..ffda8197f 100644 ### param: BrowserContext.unroute.handler * since: v1.8 * langs: csharp, java -@@ -1589,7 +1597,8 @@ Condition to wait for. +@@ -1589,7 +1596,8 @@ Condition to wait for. ## async method: BrowserContext.waitForConsoleMessage * since: v1.34 @@ -384,7 +385,7 @@ index b504bf457..ffda8197f 100644 - alias-python: expect_console_message - alias-csharp: RunAndWaitForConsoleMessage - returns: <[ConsoleMessage]> -@@ -1620,7 +1629,8 @@ Receives the [ConsoleMessage] object and resolves to truthy value when the waiti +@@ -1620,7 +1628,8 @@ Receives the [ConsoleMessage] object and resolves to truthy value when the waiti ## async method: BrowserContext.waitForEvent * since: v1.8 @@ -394,7 +395,7 @@ index b504bf457..ffda8197f 100644 - alias-python: expect_event - returns: <[any]> -@@ -1686,7 +1696,8 @@ Either a predicate that receives an event or an options object. Optional. +@@ -1686,7 +1695,8 @@ Either a predicate that receives an event or an options object. Optional. ## async method: BrowserContext.waitForPage * since: v1.9 @@ -404,7 +405,7 @@ index b504bf457..ffda8197f 100644 - alias-python: expect_page - alias-csharp: RunAndWaitForPage - returns: <[Page]> -@@ -1705,7 +1716,7 @@ Will throw an error if the context closes before new [Page] is created. +@@ -1705,7 +1715,7 @@ Will throw an error if the context closes before new [Page] is created. ### option: BrowserContext.waitForPage.predicate * since: v1.9 @@ -413,7 +414,7 @@ index b504bf457..ffda8197f 100644 - `predicate` <[function]\([Page]\):[boolean]> Receives the [Page] object and resolves to truthy value when the waiting should resolve. -@@ -1718,7 +1729,8 @@ Receives the [Page] object and resolves to truthy value when the waiting should +@@ -1718,7 +1728,8 @@ Receives the [Page] object and resolves to truthy value when the waiting should ## async method: BrowserContext.waitForEvent2 * since: v1.8 @@ -586,7 +587,7 @@ index 1ff5e5211..31dd3a4b0 100644 Expected options currently selected. diff --git a/docs/src/api/class-page.md b/docs/src/api/class-page.md -index 60512f51c..ea491d514 100644 +index f0c8b34b0..a6d46e70b 100644 --- a/docs/src/api/class-page.md +++ b/docs/src/api/class-page.md @@ -621,7 +621,7 @@ The order of evaluation of multiple scripts installed via [`method: BrowserConte @@ -721,15 +722,16 @@ index 60512f51c..ea491d514 100644 ### param: Page.route.handler * since: v1.8 * langs: csharp, java -@@ -3679,6 +3686,7 @@ Optional setting to control resource content management. If `attach` is specifie +@@ -3752,7 +3759,7 @@ Handler function to route the WebSocket. - ## async method: Page.routeWebSocket + ### param: Page.routeWebSocket.handler * since: v1.48 -+* langs: js, python, csharp, java - - This method allows to modify websocket connections that are made by the page. +-* langs: csharp, java ++* langs: csharp, java, go + - `handler` <[function]\([WebSocketRoute]\)> -@@ -4081,14 +4089,14 @@ await page.GotoAsync("https://www.microsoft.com"); + Handler function to route the WebSocket. +@@ -4081,14 +4088,14 @@ await page.GotoAsync("https://www.microsoft.com"); ### param: Page.setViewportSize.width * since: v1.10 @@ -746,7 +748,7 @@ index 60512f51c..ea491d514 100644 - `height` <[int]> Page height in pixels. -@@ -4275,6 +4283,13 @@ A glob pattern, regex pattern or predicate receiving [URL] to match while routin +@@ -4275,6 +4282,13 @@ A glob pattern, regex pattern or predicate receiving [URL] to match while routin Optional handler function to route the request. @@ -760,7 +762,7 @@ index 60512f51c..ea491d514 100644 ### param: Page.unroute.handler * since: v1.8 * langs: csharp, java -@@ -4313,7 +4328,8 @@ Performs action and waits for the Page to close. +@@ -4313,7 +4327,8 @@ Performs action and waits for the Page to close. ## async method: Page.waitForConsoleMessage * since: v1.9 @@ -770,7 +772,7 @@ index 60512f51c..ea491d514 100644 - alias-python: expect_console_message - alias-csharp: RunAndWaitForConsoleMessage - returns: <[ConsoleMessage]> -@@ -4344,7 +4360,8 @@ Receives the [ConsoleMessage] object and resolves to truthy value when the waiti +@@ -4344,7 +4359,8 @@ Receives the [ConsoleMessage] object and resolves to truthy value when the waiti ## async method: Page.waitForDownload * since: v1.9 @@ -780,7 +782,7 @@ index 60512f51c..ea491d514 100644 - alias-python: expect_download - alias-csharp: RunAndWaitForDownload - returns: <[Download]> -@@ -4375,7 +4392,8 @@ Receives the [Download] object and resolves to truthy value when the waiting sho +@@ -4375,7 +4391,8 @@ Receives the [Download] object and resolves to truthy value when the waiting sho ## async method: Page.waitForEvent * since: v1.8 @@ -790,7 +792,7 @@ index 60512f51c..ea491d514 100644 - alias-python: expect_event - returns: <[any]> -@@ -4428,7 +4446,8 @@ Either a predicate that receives an event or an options object. Optional. +@@ -4428,7 +4445,8 @@ Either a predicate that receives an event or an options object. Optional. ## async method: Page.waitForFileChooser * since: v1.9 @@ -800,7 +802,7 @@ index 60512f51c..ea491d514 100644 - alias-python: expect_file_chooser - alias-csharp: RunAndWaitForFileChooser - returns: <[FileChooser]> -@@ -4586,7 +4605,7 @@ await page.WaitForFunctionAsync("selector => !!document.querySelector(selector)" +@@ -4586,7 +4604,7 @@ await page.WaitForFunctionAsync("selector => !!document.querySelector(selector)" Optional argument to pass to [`param: expression`]. @@ -809,7 +811,7 @@ index 60512f51c..ea491d514 100644 * since: v1.8 ### option: Page.waitForFunction.polling = %%-csharp-java-wait-for-function-polling-%% -@@ -4683,6 +4702,11 @@ Console.WriteLine(await popup.TitleAsync()); // popup is ready to use. +@@ -4683,6 +4701,11 @@ Console.WriteLine(await popup.TitleAsync()); // popup is ready to use. ``` ### param: Page.waitForLoadState.state = %%-wait-for-load-state-state-%% @@ -821,7 +823,7 @@ index 60512f51c..ea491d514 100644 * since: v1.8 ### option: Page.waitForLoadState.timeout = %%-navigation-timeout-%% -@@ -4695,6 +4719,7 @@ Console.WriteLine(await popup.TitleAsync()); // popup is ready to use. +@@ -4695,6 +4718,7 @@ Console.WriteLine(await popup.TitleAsync()); // popup is ready to use. * since: v1.8 * deprecated: This method is inherently racy, please use [`method: Page.waitForURL`] instead. * langs: @@ -829,7 +831,7 @@ index 60512f51c..ea491d514 100644 * alias-python: expect_navigation * alias-csharp: RunAndWaitForNavigation - returns: <[null]|[Response]> -@@ -4779,7 +4804,8 @@ a navigation. +@@ -4779,7 +4803,8 @@ a navigation. ## async method: Page.waitForPopup * since: v1.9 @@ -839,7 +841,7 @@ index 60512f51c..ea491d514 100644 - alias-python: expect_popup - alias-csharp: RunAndWaitForPopup - returns: <[Page]> -@@ -4811,6 +4837,7 @@ Receives the [Page] object and resolves to truthy value when the waiting should +@@ -4811,6 +4836,7 @@ Receives the [Page] object and resolves to truthy value when the waiting should ## async method: Page.waitForRequest * since: v1.8 * langs: @@ -847,7 +849,7 @@ index 60512f51c..ea491d514 100644 * alias-python: expect_request * alias-csharp: RunAndWaitForRequest - returns: <[Request]> -@@ -4918,7 +4945,8 @@ changed by using the [`method: Page.setDefaultTimeout`] method. +@@ -4918,7 +4944,8 @@ changed by using the [`method: Page.setDefaultTimeout`] method. ## async method: Page.waitForRequestFinished * since: v1.12 @@ -857,7 +859,7 @@ index 60512f51c..ea491d514 100644 - alias-python: expect_request_finished - alias-csharp: RunAndWaitForRequestFinished - returns: <[Request]> -@@ -4950,6 +4978,7 @@ Receives the [Request] object and resolves to truthy value when the waiting shou +@@ -4950,6 +4977,7 @@ Receives the [Request] object and resolves to truthy value when the waiting shou ## async method: Page.waitForResponse * since: v1.8 * langs: @@ -865,7 +867,7 @@ index 60512f51c..ea491d514 100644 * alias-python: expect_response * alias-csharp: RunAndWaitForResponse - returns: <[Response]> -@@ -5313,7 +5342,8 @@ await page.WaitForURLAsync("**/target.html"); +@@ -5313,7 +5341,8 @@ await page.WaitForURLAsync("**/target.html"); ## async method: Page.waitForWebSocket * since: v1.9 @@ -875,7 +877,7 @@ index 60512f51c..ea491d514 100644 - alias-python: expect_websocket - alias-csharp: RunAndWaitForWebSocket - returns: <[WebSocket]> -@@ -5344,7 +5374,8 @@ Receives the [WebSocket] object and resolves to truthy value when the waiting sh +@@ -5344,7 +5373,8 @@ Receives the [WebSocket] object and resolves to truthy value when the waiting sh ## async method: Page.waitForWorker * since: v1.9 @@ -885,7 +887,7 @@ index 60512f51c..ea491d514 100644 - alias-python: expect_worker - alias-csharp: RunAndWaitForWorker - returns: <[Worker]> -@@ -5386,7 +5417,8 @@ This does not contain ServiceWorkers +@@ -5386,7 +5416,8 @@ This does not contain ServiceWorkers ## async method: Page.waitForEvent2 * since: v1.8 @@ -1126,16 +1128,32 @@ index 3b3308b88..e21ea93ef 100644 - returns: <[any]> diff --git a/docs/src/api/class-websocketroute.md b/docs/src/api/class-websocketroute.md -index b827db25d..e5fd459c4 100644 +index e23316ebc..ae3d20b9f 100644 --- a/docs/src/api/class-websocketroute.md +++ b/docs/src/api/class-websocketroute.md -@@ -1,5 +1,6 @@ - # class: WebSocketRoute +@@ -325,7 +325,7 @@ Function that will handle WebSocket closure. Received an optional [close code](h + + ### param: WebSocketRoute.onClose.handler * since: v1.48 -+* langs: js, python, csharp, java +-* langs: java ++* langs: java, go + - `handler` <[function]\([null]|[int], [null]|[string]\)> + + Function that will handle WebSocket closure. Received an optional [close code](https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/close#code) and an optional [close reason](https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/close#reason). +@@ -355,6 +355,13 @@ Calling this method again will override the handler with a new one. - Whenever a [`WebSocket`](https://developer.mozilla.org/en-US/docs/Web/API/WebSocket) route is set up with [`method: Page.routeWebSocket`] or [`method: BrowserContext.routeWebSocket`], the `WebSocketRoute` object allows to handle the WebSocket, like an actual server would do. + Function that will handle messages. ++### param: WebSocketRoute.onMessage.handler ++* since: v1.48 ++* langs: go ++- `handler` <[function]\([string]|[Buffer]\)> ++ ++Function that will handle messages. ++ + ### param: WebSocketRoute.onMessage.handler + * since: v1.48 + * langs: csharp, java diff --git a/docs/src/api/params.md b/docs/src/api/params.md index e86c0a19f..08559dc31 100644 --- a/docs/src/api/params.md @@ -1504,10 +1522,10 @@ index e86c0a19f..08559dc31 100644 Firefox user preferences. Learn more about the Firefox user preferences at diff --git a/utils/doclint/generateGoApi.js b/utils/doclint/generateGoApi.js new file mode 100644 -index 000000000..e24572392 +index 000000000..cfb254664 --- /dev/null +++ b/utils/doclint/generateGoApi.js -@@ -0,0 +1,866 @@ +@@ -0,0 +1,868 @@ +/** + * Copyright (c) Microsoft Corporation. + * @@ -1972,6 +1990,8 @@ index 000000000..e24572392 + returns.pop(); + if (parent.name === 'Locator' && (name === 'Page' || name === 'All')) // Locator.Page() (Page, error) + returns.push('error'); ++ if (parent.name === 'WebSocketRoute' && ['Send', 'Close'].includes(name)) ++ returns.pop(); + + // render args + let args = []; @@ -2210,7 +2230,7 @@ index 000000000..e24572392 + return 'interface{}'; + else if (type.expression === '[path]|[Array]<[path]>|[Object]|[Array]<[Object]>') + return 'interface{}'; -+ else if (type.expression === '[function]([null]|[number], [null]|[string])') ++ else if (type.expression === '[function]([null]|[int], [null]|[string])') + return 'func(*int, *string)'; + + let isNullableEnum = false; diff --git a/playwright b/playwright index ceb756d..dc80964 160000 --- a/playwright +++ b/playwright @@ -1 +1 @@ -Subproject commit ceb756dad3a3089d470890fd5c75aa585a47cb7c +Subproject commit dc80964a3f84dc120b5fed8837ff492a38ddb26e diff --git a/route.go b/route.go index 927eb26..c9e6ee9 100644 --- a/route.go +++ b/route.go @@ -253,15 +253,18 @@ func (r *routeImpl) internalContinue(isFallback bool) error { func (r *routeImpl) redirectedNavigationRequest(url string) error { return r.handleRoute(func() error { - _, err := r.channel.Send("redirectNavigationRequest", map[string]interface{}{ - "url": url, + return r.raceWithPageClose(func() error { + _, err := r.channel.Send("redirectNavigationRequest", map[string]interface{}{ + "url": url, + }) + return err }) - return err }) } func newRoute(parent *channelOwner, objectType string, guid string, initializer map[string]interface{}) *routeImpl { bt := &routeImpl{} bt.createChannelOwner(bt, parent, objectType, guid, initializer) + bt.markAsInternalType() return bt } diff --git a/run.go b/run.go index 11b2ee7..198e3fa 100644 --- a/run.go +++ b/run.go @@ -16,7 +16,7 @@ import ( ) const ( - playwrightCliVersion = "1.48.1" + playwrightCliVersion = "1.48.2" ) var ( diff --git a/selectors.go b/selectors.go index f79e07c..1151647 100644 --- a/selectors.go +++ b/selectors.go @@ -11,7 +11,7 @@ type selectorsOwnerImpl struct { } func (s *selectorsOwnerImpl) setTestIdAttributeName(name string) { - s.channel.SendNoReply("setTestIdAttributeName", false, map[string]interface{}{ + s.channel.SendNoReply("setTestIdAttributeName", map[string]interface{}{ "testIdAttributeName": name, }) } @@ -71,7 +71,7 @@ func (s *selectorsImpl) SetTestIdAttribute(name string) { func (s *selectorsImpl) addChannel(channel *selectorsOwnerImpl) { s.channels.Store(channel.guid, channel) for _, params := range s.registrations { - channel.channel.SendNoReply("register", false, params) + channel.channel.SendNoReply("register", params) channel.setTestIdAttributeName(getTestIdAttributeName()) } } diff --git a/tests/browser_context_test.go b/tests/browser_context_test.go index f0d459b..0b2e828 100644 --- a/tests/browser_context_test.go +++ b/tests/browser_context_test.go @@ -765,3 +765,18 @@ func TestBrowserContextShouldRetryECONNRESET(t *testing.T) { require.Equal(t, []byte("Hello!"), body) require.Equal(t, int32(4), requestCount.Load()) } + +func TestBrowserContextShouldShowErrorAfterFulfill(t *testing.T) { + BeforeEach(t) + + require.NoError(t, page.Route("**/*", func(route playwright.Route) { + require.NoError(t, route.Continue()) + panic("Exception text!?") + })) + + _, err := page.Goto(server.EMPTY_PAGE) + require.NoError(t, err) + // Any next API call should throw because handler did throw during previous goto() + _, err = page.Goto(server.EMPTY_PAGE) + require.ErrorContains(t, err, "Exception text!?") +} diff --git a/tests/route_web_socket_test.go b/tests/route_web_socket_test.go new file mode 100644 index 0000000..978b7bf --- /dev/null +++ b/tests/route_web_socket_test.go @@ -0,0 +1,340 @@ +package playwright_test + +import ( + "fmt" + "net/http" + "regexp" + "testing" + "time" + + "github.com/coder/websocket" + "github.com/playwright-community/playwright-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func assertSlicesEqual(t *testing.T, expected []interface{}, cb func() (interface{}, error)) { + t.Helper() + + require.EventuallyWithT(t, func(collect *assert.CollectT) { + actual, err := cb() + require.NoError(t, err) + require.ElementsMatch(t, expected, actual) + }, 5*time.Second, 200*time.Millisecond) +} + +func setupWS(t *testing.T, target playwright.Page, port string, protocol string) { + t.Helper() + + _, err := target.Goto("about:blank") + require.NoError(t, err) + _, err = target.Evaluate(` + ({ port, binaryType }) => { + window.log = []; + window.ws = new WebSocket('ws://localhost:' + port + '/ws'); + window.ws.binaryType = binaryType; + window.ws.addEventListener('open', () => window.log.push('open')); + window.ws.addEventListener('close', event => window.log.push(`+"`close code=${event.code} reason=${event.reason} wasClean=${event.wasClean}`"+`)); + window.ws.addEventListener('error', event => window.log.push(`+"`error`"+`)); + window.ws.addEventListener('message', async event => { + let data; + console.log(event); + if (typeof event.data === 'string') + data = event.data; + else if (event.data instanceof Blob) + data = 'blob:' + await event.data.text(); + else + data = 'arraybuffer:' + await (new Blob([event.data])).text(); + window.log.push(`+"`message: data=${data} origin=${event.origin} lastEventId=${event.lastEventId}`"+`); + }); + window.wsOpened = new Promise(f => window.ws.addEventListener('open', () => f())); + }`, map[string]interface{}{"port": port, "binaryType": protocol}) + require.NoError(t, err) +} + +func TestShouldWorkWithWSClose(t *testing.T) { + BeforeEach(t) + + wsRouteChan := make(chan playwright.WebSocketRoute, 1) + + handleWS := func(ws playwright.WebSocketRoute) { + _, _ = ws.ConnectToServer() + wsRouteChan <- ws + } + + require.NoError(t, page.RouteWebSocket(regexp.MustCompile(".*"), handleWS)) + + wsConnChan := server.WaitForWebSocketConnection() + setupWS(t, page, server.PORT, "blob") + <-wsConnChan + + wsRoute := <-wsRouteChan + wsRoute.Send("hello") + + assertSlicesEqual(t, []interface{}{"open", "message: data=hello origin=ws://localhost:" + server.PORT + " lastEventId="}, func() (interface{}, error) { + return page.Evaluate(`window.log`) + }) + + closedError := make(chan *websocket.CloseError, 1) + server.OnceWebSocketClose(func(err *websocket.CloseError) { + closedError <- err + }) + + wsRoute.Close(playwright.WebSocketRouteCloseOptions{ + Code: playwright.Int(3009), + Reason: playwright.String("oops"), + }) + + assertSlicesEqual(t, + []interface{}{ + "open", + "message: data=hello origin=ws://localhost:" + server.PORT + " lastEventId=", + "close code=3009 reason=oops wasClean=true", + }, + func() (interface{}, error) { + return page.Evaluate(`window.log`) + }) + + result := <-closedError + require.Equal(t, "3009, oops", fmt.Sprintf("%d, %s", result.Code, result.Reason)) +} + +func TestShouldPatterMatch(t *testing.T) { + BeforeEach(t) + + require.NoError(t, page.RouteWebSocket(regexp.MustCompile(".*/ws$"), func(ws playwright.WebSocketRoute) { + go func() { + _, _ = ws.ConnectToServer() + }() + })) + require.NoError(t, page.RouteWebSocket("**/mock-ws", func(ws playwright.WebSocketRoute) { + ws.OnMessage(func(i interface{}) { + ws.Send("mock-response") + }) + })) + + wsConnChan := server.WaitForWebSocketConnection() + _, err := page.Goto("about:blank") + require.NoError(t, err) + + _, err = page.Evaluate(` + async ({ port }) => { + window.log = []; + window.ws1 = new WebSocket('ws://localhost:' + port + '/ws'); + window.ws1.addEventListener('message', event => window.log.push(`+"`ws1:${event.data}`"+`)); + window.ws2 = new WebSocket('ws://localhost:' + port + '/something/something/mock-ws'); + window.ws2.addEventListener('message', event => window.log.push(`+"`ws2:${event.data}`"+`)); + await Promise.all([ + new Promise(f => window.ws1.addEventListener('open', f)), + new Promise(f => window.ws2.addEventListener('open', f)), + ]); + } + `, map[string]interface{}{"port": server.PORT}) + require.NoError(t, err) + + <-wsConnChan + server.OnWebSocketMessage(func(c *websocket.Conn, r *http.Request, msgType websocket.MessageType, msg []byte) { + err := c.Write(r.Context(), websocket.MessageText, []byte("response")) + t.Log(err) + }) + + _, err = page.Evaluate(`window.ws1.send('request')`) + require.NoError(t, err) + assertSlicesEqual(t, []interface{}{"ws1:response"}, func() (interface{}, error) { + return page.Evaluate(`window.log`) + }) + + _, err = page.Evaluate(`window.ws2.send('request')`) + require.NoError(t, err) + assertSlicesEqual(t, []interface{}{"ws1:response", "ws2:mock-response"}, func() (interface{}, error) { + return page.Evaluate(`window.log`) + }) +} + +func TestRouteWebSocketShouldWorkWithServer(t *testing.T) { + BeforeEach(t) + + wsRouteChan := make(chan playwright.WebSocketRoute, 1) + + handleWS := func(ws playwright.WebSocketRoute) { + server, err := ws.ConnectToServer() + require.NoError(t, err) + + ws.OnMessage(func(message interface{}) { + msg := message.(string) + switch msg { + case "to-respond": + ws.Send("response") + return + case "to-block": + return + case "to-modify": + server.Send("modified") + return + default: + server.Send(message) + } + }) + + server.OnMessage(func(message interface{}) { + msg := message.(string) + switch msg { + case "to-block": + return + case "to-modify": + ws.Send("modified") + return + default: + ws.Send(message) + } + }) + + server.Send("fake") + wsRouteChan <- ws + } + + log := newSyncSlice[string]() + + server.OnWebSocketMessage(func(c *websocket.Conn, r *http.Request, msgType websocket.MessageType, msg []byte) { + log.Append(fmt.Sprintf("message: %s", msg)) + }) + server.OnWebSocketClose(func(err *websocket.CloseError) { + log.Append(fmt.Sprintf("close: code=%d reason=%s", err.Code, err.Reason)) + }) + + require.NoError(t, page.RouteWebSocket(regexp.MustCompile(".*"), handleWS)) + + wsConnChan := server.WaitForWebSocketConnection() + + setupWS(t, page, server.PORT, "blob") + ws := <-wsConnChan + require.EventuallyWithT(t, func(collect *assert.CollectT) { + require.ElementsMatch(t, []string{"message: fake"}, log.Get()) + }, 5*time.Second, 200*time.Millisecond) + + ws.SendMessage(websocket.MessageText, []byte("to-modify")) + ws.SendMessage(websocket.MessageText, []byte("to-block")) + ws.SendMessage(websocket.MessageText, []byte("pass-server")) + + assertSlicesEqual(t, []interface{}{ + "open", + "message: data=modified origin=ws://localhost:" + server.PORT + " lastEventId=", + "message: data=pass-server origin=ws://localhost:" + server.PORT + " lastEventId=", + }, func() (interface{}, error) { + return page.Evaluate(`window.log`) + }) + + _, err := page.Evaluate(` + () => { + window.ws.send('to-respond'); + window.ws.send('to-modify'); + window.ws.send('to-block'); + window.ws.send('pass-client'); + }`) + require.NoError(t, err) + + require.EventuallyWithT(t, func(collect *assert.CollectT) { + require.ElementsMatch(t, []string{"message: fake", "message: modified", "message: pass-client"}, log.Get()) + }, 5*time.Second, 200*time.Millisecond) + + assertSlicesEqual(t, []interface{}{ + "open", + "message: data=modified origin=ws://localhost:" + server.PORT + " lastEventId=", + "message: data=pass-server origin=ws://localhost:" + server.PORT + " lastEventId=", + "message: data=response origin=ws://localhost:" + server.PORT + " lastEventId=", + }, func() (interface{}, error) { + return page.Evaluate(`window.log`) + }) + + route := <-wsRouteChan + route.Send("another") + assertSlicesEqual(t, []interface{}{ + "open", + "message: data=modified origin=ws://localhost:" + server.PORT + " lastEventId=", + "message: data=pass-server origin=ws://localhost:" + server.PORT + " lastEventId=", + "message: data=response origin=ws://localhost:" + server.PORT + " lastEventId=", + "message: data=another origin=ws://localhost:" + server.PORT + " lastEventId=", + }, func() (interface{}, error) { + return page.Evaluate(`window.log`) + }) + + _, err = page.Evaluate(` + () => { + window.ws.send('pass-client-2'); + }`) + require.NoError(t, err) + + require.EventuallyWithT(t, func(collect *assert.CollectT) { + require.ElementsMatch(t, []string{"message: fake", "message: modified", "message: pass-client", "message: pass-client-2"}, log.Get()) + }, 5*time.Second, 200*time.Millisecond) + + _, err = page.Evaluate(` + () => { + window.ws.close(3009, 'problem'); + }`) + require.NoError(t, err) + + require.EventuallyWithT(t, func(collect *assert.CollectT) { + require.ElementsMatch(t, []string{ + "message: fake", + "message: modified", + "message: pass-client", + "message: pass-client-2", + "close: code=3009 reason=problem", + }, log.Get()) + }, 5*time.Second, 200*time.Millisecond) +} + +func TestRouteWebSocketShouldWorkWithoutServer(t *testing.T) { + BeforeEach(t) + + wsRouteChan := make(chan playwright.WebSocketRoute, 1) + + handleWS := func(ws playwright.WebSocketRoute) { + ws.OnMessage(func(message interface{}) { + if message.(string) == "to-respond" { + ws.Send("response") + } + }) + + wsRouteChan <- ws + } + + require.NoError(t, page.RouteWebSocket(regexp.MustCompile(".*"), handleWS)) + setupWS(t, page, server.PORT, "blob") + + _, err := page.Evaluate(` + async () => { + await window.wsOpened; + window.ws.send('to-respond'); + window.ws.send('to-block'); + window.ws.send('to-respond'); + }`) + require.NoError(t, err) + + assertSlicesEqual(t, []interface{}{ + "open", + "message: data=response origin=ws://localhost:" + server.PORT + " lastEventId=", + "message: data=response origin=ws://localhost:" + server.PORT + " lastEventId=", + }, func() (interface{}, error) { + return page.Evaluate(`window.log`) + }) + + route := <-wsRouteChan + route.Send("another") + // wait for the message to be processed + time.Sleep(100 * time.Millisecond) + route.Close(playwright.WebSocketRouteCloseOptions{ + Code: playwright.Int(3008), + Reason: playwright.String("oops"), + }) + assertSlicesEqual(t, []interface{}{ + "open", + "message: data=response origin=ws://localhost:" + server.PORT + " lastEventId=", + "message: data=response origin=ws://localhost:" + server.PORT + " lastEventId=", + "message: data=another origin=ws://localhost:" + server.PORT + " lastEventId=", + "close code=3008 reason=oops wasClean=true", + }, func() (interface{}, error) { + return page.Evaluate(`window.log`) + }) +} diff --git a/tests/utils_test.go b/tests/utils_test.go index a6c9d6e..2acc2f2 100644 --- a/tests/utils_test.go +++ b/tests/utils_test.go @@ -86,6 +86,7 @@ func (t *testServer) AfterEach() { t.requestSubscriberes = make(map[string][]chan *http.Request) t.eventEmitter.RemoveListeners("connection") t.eventEmitter.RemoveListeners("message") + t.eventEmitter.RemoveListeners("close") t.testServer.CloseClientConnections() } @@ -97,10 +98,31 @@ func (t *testServer) SetTLSConfig(config *tls.Config) { t.testServer.TLS = config } +type wsConnection struct { + Conn *websocket.Conn + Req *http.Request +} + +func (c *wsConnection) SendMessage(msgType websocket.MessageType, data []byte) { + err := c.Conn.Write(c.Req.Context(), msgType, data) + if err != nil { + log.Println("testServer: could not write ws message:", err) + return + } +} + func (t *testServer) OnceWebSocketConnection(handler func(c *websocket.Conn, r *http.Request)) { t.eventEmitter.Once("connection", handler) } +func (t *testServer) OnWebSocketClose(handler func(err *websocket.CloseError)) { + t.eventEmitter.On("close", handler) +} + +func (t *testServer) OnceWebSocketClose(handler func(err *websocket.CloseError)) { + t.eventEmitter.Once("close", handler) +} + func (t *testServer) OnWebSocketMessage(handler func(c *websocket.Conn, r *http.Request, msgType websocket.MessageType, msg []byte)) { t.eventEmitter.On("message", handler) } @@ -119,10 +141,10 @@ func (t *testServer) SendOnWebSocketConnection(msgType websocket.MessageType, da }) } -func (t *testServer) WaitForWebSocketConnection() <-chan *websocket.Conn { - channel := make(chan *websocket.Conn) +func (t *testServer) WaitForWebSocketConnection() <-chan *wsConnection { + channel := make(chan *wsConnection) t.OnceWebSocketConnection(func(c *websocket.Conn, r *http.Request) { - channel <- c + channel <- &wsConnection{Conn: c, Req: r} close(channel) }) return channel @@ -143,7 +165,13 @@ func (t *testServer) wsHandler(w http.ResponseWriter, r *http.Request) { for { typ, message, err := c.Read(r.Context()) if err != nil { - if websocket.CloseStatus(err) != websocket.StatusNormalClosure && websocket.CloseStatus(err) != websocket.StatusNoStatusRcvd { + closeErr := new(websocket.CloseError) + if errors.As(err, closeErr) { + t.eventEmitter.Emit("close", closeErr) + } + switch websocket.CloseStatus(err) { + case websocket.StatusNormalClosure, websocket.StatusGoingAway, websocket.StatusNoStatusRcvd: + default: log.Println("testServer: could not read ws message:", err) } break diff --git a/websocket_route.go b/websocket_route.go new file mode 100644 index 0000000..bb74ab9 --- /dev/null +++ b/websocket_route.go @@ -0,0 +1,220 @@ +package playwright + +import ( + "encoding/base64" + "fmt" + "regexp" + "sync/atomic" +) + +type webSocketRouteImpl struct { + channelOwner + connected *atomic.Bool + server WebSocketRoute + onPageMessage func(interface{}) + onPageClose func(code *int, reason *string) + onServerMessage func(interface{}) + onServerClose func(code *int, reason *string) +} + +func newWebSocketRoute(parent *channelOwner, objectType string, guid string, initializer map[string]interface{}) *webSocketRouteImpl { + route := &webSocketRouteImpl{ + connected: &atomic.Bool{}, + } + route.createChannelOwner(route, parent, objectType, guid, initializer) + route.markAsInternalType() + + route.server = newServerWebSocketRoute(route) + + route.channel.On("messageFromPage", func(event map[string]interface{}) { + msg, err := untransformWebSocketMessage(event) + if err != nil { + panic(fmt.Errorf("Could not decode WebSocket message: %w", err)) + } + if route.onPageMessage != nil { + route.onPageMessage(msg) + } else if route.connected.Load() { + go route.channel.SendNoReply("sendToServer", event) + } + }) + + route.channel.On("messageFromServer", func(event map[string]interface{}) { + msg, err := untransformWebSocketMessage(event) + if err != nil { + panic(fmt.Errorf("Could not decode WebSocket message: %w", err)) + } + if route.onServerMessage != nil { + route.onServerMessage(msg) + } else { + go route.channel.SendNoReply("sendToPage", event) + } + }) + + route.channel.On("closePage", func(event map[string]interface{}) { + if route.onPageClose != nil { + route.onPageClose(event["code"].(*int), event["reason"].(*string)) + } else { + go route.channel.SendNoReply("closeServer", event) + } + }) + + route.channel.On("closeServer", func(event map[string]interface{}) { + if route.onServerClose != nil { + route.onServerClose(event["code"].(*int), event["reason"].(*string)) + } else { + go route.channel.SendNoReply("closePage", event) + } + }) + + return route +} + +func (r *webSocketRouteImpl) Close(options ...WebSocketRouteCloseOptions) { + r.channel.SendNoReply("closePage", options, map[string]interface{}{"wasClean": true}) +} + +func (r *webSocketRouteImpl) ConnectToServer() (WebSocketRoute, error) { + if r.connected.Load() { + return nil, fmt.Errorf("Already connected to the server") + } + r.channel.SendNoReply("connect") + r.connected.Store(true) + return r.server, nil +} + +func (r *webSocketRouteImpl) OnClose(handler func(code *int, reason *string)) { + r.onPageClose = handler +} + +func (r *webSocketRouteImpl) OnMessage(handler func(interface{})) { + r.onPageMessage = handler +} + +func (r *webSocketRouteImpl) Send(message interface{}) { + data, err := transformWebSocketMessage(message) + if err != nil { + panic(fmt.Errorf("Could not encode WebSocket message: %w", err)) + } + go r.channel.SendNoReply("sendToPage", data) +} + +func (r *webSocketRouteImpl) URL() string { + return r.initializer["url"].(string) +} + +func (r *webSocketRouteImpl) afterHandle() error { + if r.connected.Load() { + return nil + } + // Ensure that websocket is "open" and can send messages without an actual server connection. + _, err := r.channel.Send("ensureOpened") + return err +} + +type serverWebSocketRouteImpl struct { + webSocketRoute *webSocketRouteImpl +} + +func newServerWebSocketRoute(route *webSocketRouteImpl) *serverWebSocketRouteImpl { + return &serverWebSocketRouteImpl{webSocketRoute: route} +} + +func (s *serverWebSocketRouteImpl) OnMessage(handler func(interface{})) { + s.webSocketRoute.onServerMessage = handler +} + +func (s *serverWebSocketRouteImpl) OnClose(handler func(code *int, reason *string)) { + s.webSocketRoute.onServerClose = handler +} + +func (s *serverWebSocketRouteImpl) ConnectToServer() (WebSocketRoute, error) { + return nil, fmt.Errorf("ConnectToServer must be called on the page-side WebSocketRoute") +} + +func (s *serverWebSocketRouteImpl) URL() string { + return s.webSocketRoute.URL() +} + +func (s *serverWebSocketRouteImpl) Close(options ...WebSocketRouteCloseOptions) { + go s.webSocketRoute.channel.SendNoReply("close", options, map[string]interface{}{"wasClean": true}) +} + +func (s *serverWebSocketRouteImpl) Send(message interface{}) { + data, err := transformWebSocketMessage(message) + if err != nil { + panic(fmt.Errorf("Could not encode WebSocket message: %w", err)) + } + go s.webSocketRoute.channel.SendNoReply("sendToServer", data) +} + +func transformWebSocketMessage(message interface{}) (map[string]interface{}, error) { + data := map[string]interface{}{} + switch v := message.(type) { + case []byte: + data["isBase64"] = true + data["message"] = base64.StdEncoding.EncodeToString(v) + case string: + data["isBase64"] = false + data["message"] = v + default: + return nil, fmt.Errorf("Unsupported message type: %T", v) + } + return data, nil +} + +func untransformWebSocketMessage(data map[string]interface{}) (interface{}, error) { + if data["isBase64"].(bool) { + return base64.StdEncoding.DecodeString(data["message"].(string)) + } + return data["message"], nil +} + +type webSocketRouteHandler struct { + matcher *urlMatcher + handler func(WebSocketRoute) +} + +func newWebSocketRouteHandler(matcher *urlMatcher, handler func(WebSocketRoute)) *webSocketRouteHandler { + return &webSocketRouteHandler{matcher: matcher, handler: handler} +} + +func (h *webSocketRouteHandler) Handle(route WebSocketRoute) { + h.handler(route) + err := route.(*webSocketRouteImpl).afterHandle() + if err != nil { + panic(fmt.Errorf("Could not handle WebSocketRoute: %w", err)) + } +} + +func (h *webSocketRouteHandler) Matches(wsURL string) bool { + return h.matcher.Matches(wsURL) +} + +func prepareWebSocketRouteHandlerInterceptionPatterns(handlers []*webSocketRouteHandler) []map[string]interface{} { + patterns := []map[string]interface{}{} + all := false + for _, handler := range handlers { + switch handler.matcher.raw.(type) { + case *regexp.Regexp: + pattern, flags := convertRegexp(handler.matcher.raw.(*regexp.Regexp)) + patterns = append(patterns, map[string]interface{}{ + "regexSource": pattern, + "regexFlags": flags, + }) + case string: + patterns = append(patterns, map[string]interface{}{ + "glob": handler.matcher.raw.(string), + }) + default: + all = true + } + } + if all { + return []map[string]interface{}{ + { + "glob": "**/*", + }, + } + } + return patterns +}