diff --git a/src/network/Proxy.ts b/src/network/Proxy.ts index 1a6ff46f16..973c7f525a 100644 --- a/src/network/Proxy.ts +++ b/src/network/Proxy.ts @@ -1,5 +1,12 @@ import type { AddressInfo, Socket } from 'net'; -import type { Host, Port, Address, ConnectionInfo, TLSConfig } from './types'; +import type { + Host, + Port, + Address, + ConnectionInfo, + TLSConfig, + ConnectionEstablishedCallback, +} from './types'; import type { ConnectionsForward } from './ConnectionForward'; import type { NodeId } from '../nodes/types'; import type { Timer } from '../types'; @@ -48,6 +55,7 @@ class Proxy { proxy: new Map(), reverse: new Map(), }; + protected connectionEstablishedCallback: ConnectionEstablishedCallback; constructor({ authToken, @@ -56,6 +64,7 @@ class Proxy { connEndTime = 1000, connPunchIntervalTime = 1000, connKeepAliveIntervalTime = 1000, + connectionEstablishedCallback = () => {}, logger, }: { authToken: string; @@ -64,6 +73,7 @@ class Proxy { connEndTime?: number; connPunchIntervalTime?: number; connKeepAliveIntervalTime?: number; + connectionEstablishedCallback?: ConnectionEstablishedCallback; logger?: Logger; }) { this.logger = logger ?? new Logger(Proxy.name); @@ -77,6 +87,7 @@ class Proxy { this.server = http.createServer(); this.server.on('request', this.handleRequest); this.server.on('connect', this.handleConnectForward); + this.connectionEstablishedCallback = connectionEstablishedCallback; this.logger.info(`Created ${Proxy.name}`); } @@ -521,6 +532,14 @@ class Proxy { timer, ); conn.compose(clientSocket); + // With the connection composed without error we can assume that the + // connection was established and verified + await this.connectionEstablishedCallback({ + remoteNodeId: conn.getServerNodeIds()[0], + remoteHost: conn.host, + remotePort: conn.port, + type: 'forward', + }); } protected async establishConnectionForward( @@ -687,6 +706,14 @@ class Proxy { timer, ); await conn.compose(utpConn, timer); + // With the connection composed without error we can assume that the + // connection was established and verified + await this.connectionEstablishedCallback({ + remoteNodeId: conn.getClientNodeIds()[0], + remoteHost: conn.host, + remotePort: conn.port, + type: 'reverse', + }); } protected async establishConnectionReverse( diff --git a/src/network/types.ts b/src/network/types.ts index 40d672a85a..a5a62b4c20 100644 --- a/src/network/types.ts +++ b/src/network/types.ts @@ -55,6 +55,15 @@ type ConnectionInfo = { remotePort: Port; }; +type ConnectionData = { + remoteNodeId: NodeId; + remoteHost: Host; + remotePort: Port; + type: 'forward' | 'reverse'; +}; + +type ConnectionEstablishedCallback = (data: ConnectionData) => any; + type PingMessage = { type: 'ping'; }; @@ -73,6 +82,8 @@ export type { TLSConfig, ProxyConfig, ConnectionInfo, + ConnectionData, + ConnectionEstablishedCallback, PingMessage, PongMessage, NetworkMessage, diff --git a/tests/network/Proxy.test.ts b/tests/network/Proxy.test.ts index fc8055ea7d..7e8b12d464 100644 --- a/tests/network/Proxy.test.ts +++ b/tests/network/Proxy.test.ts @@ -1,6 +1,6 @@ import type { AddressInfo, Socket } from 'net'; import type { KeyPairPem } from '@/keys/types'; -import type { Host, Port } from '@/network/types'; +import type { ConnectionData, Host, Port } from '@/network/types'; import net from 'net'; import http from 'http'; import tls from 'tls'; @@ -2973,4 +2973,120 @@ describe(Proxy.name, () => { utpSocket.unref(); await serverClose(); }); + test('connectionEstablishedCallback is called when a ReverseConnection is established', async () => { + const clientKeyPair = await keysUtils.generateKeyPair(1024); + const clientKeyPairPem = keysUtils.keyPairToPem(clientKeyPair); + const clientCert = keysUtils.generateCertificate( + clientKeyPair.publicKey, + clientKeyPair.privateKey, + clientKeyPair.privateKey, + 86400, + ); + const clientCertPem = keysUtils.certToPem(clientCert); + const { + serverListen, + serverClose, + serverConnP, + serverConnEndP, + serverConnClosedP, + serverHost, + serverPort, + } = tcpServer(); + await serverListen(0, localHost); + const clientNodeId = keysUtils.certNodeId(clientCert)!; + let callbackData: ConnectionData | undefined; + const proxy = new Proxy({ + logger: logger, + authToken: '', + connectionEstablishedCallback: (data) => { + callbackData = data; + }, + }); + await proxy.start({ + serverHost: serverHost(), + serverPort: serverPort(), + proxyHost: localHost, + tlsConfig: { + keyPrivatePem: keyPairPem.privateKey, + certChainPem: certPem, + }, + }); + + const proxyHost = proxy.getProxyHost(); + const proxyPort = proxy.getProxyPort(); + const { p: clientReadyP, resolveP: resolveClientReadyP } = promise(); + const { p: clientSecureConnectP, resolveP: resolveClientSecureConnectP } = + promise(); + const { p: clientCloseP, resolveP: resolveClientCloseP } = promise(); + const utpSocket = UTP({ allowHalfOpen: true }); + const utpSocketBind = promisify(utpSocket.bind).bind(utpSocket); + const handleMessage = async (data: Buffer) => { + const msg = networkUtils.unserializeNetworkMessage(data); + if (msg.type === 'ping') { + resolveClientReadyP(); + await send(networkUtils.pongBuffer); + } + }; + utpSocket.on('message', handleMessage); + const send = async (data: Buffer) => { + const utpSocketSend = promisify(utpSocket.send).bind(utpSocket); + await utpSocketSend(data, 0, data.byteLength, proxyPort, proxyHost); + }; + await utpSocketBind(0, localHost); + const utpSocketPort = utpSocket.address().port; + await proxy.openConnectionReverse( + localHost, + utpSocketPort as Port, + ); + const utpConn = utpSocket.connect(proxyPort, proxyHost); + const tlsSocket = tls.connect( + { + key: Buffer.from(clientKeyPairPem.privateKey, 'ascii'), + cert: Buffer.from(clientCertPem, 'ascii'), + socket: utpConn, + rejectUnauthorized: false, + }, + () => { + resolveClientSecureConnectP(); + }, + ); + let tlsSocketEnded = false; + tlsSocket.on('end', () => { + tlsSocketEnded = true; + if (utpConn.destroyed) { + tlsSocket.destroy(); + } else { + tlsSocket.end(); + tlsSocket.destroy(); + } + }); + tlsSocket.on('close', () => { + resolveClientCloseP(); + }); + await send(networkUtils.pingBuffer); + expect(proxy.getConnectionReverseCount()).toBe(1); + await clientReadyP; + await clientSecureConnectP; + await serverConnP; + await proxy.closeConnectionReverse( + localHost, + utpSocketPort as Port, + ); + expect(proxy.getConnectionReverseCount()).toBe(0); + await clientCloseP; + await serverConnEndP; + await serverConnClosedP; + expect(tlsSocketEnded).toBe(true); + utpSocket.off('message', handleMessage); + utpSocket.close(); + utpSocket.unref(); + await proxy.stop(); + await serverClose(); + + // Checking callback data + expect(callbackData?.remoteNodeId.equals(clientNodeId)).toBe(true); + expect(callbackData?.remoteHost).toEqual(localHost); + expect(callbackData?.remotePort).toEqual(utpSocketPort); + expect(callbackData?.type).toEqual('reverse'); + }); });