diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 71a1531b..e142c135 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -8,6 +8,10 @@ kotlinx-bcv = "0.14.0" ktor = "2.3.8" +netty = "4.1.107.Final" + +bouncycastle = "1.77" + turbine = "1.0.0" rsocket-java = "1.1.3" @@ -44,6 +48,11 @@ ktor-server-cio = { module = "io.ktor:ktor-server-cio", version.ref = "ktor" } ktor-server-netty = { module = "io.ktor:ktor-server-netty", version.ref = "ktor" } ktor-server-jetty = { module = "io.ktor:ktor-server-jetty", version.ref = "ktor" } +netty-handler = { module = "io.netty:netty-handler", version.ref = "netty" } +netty-codec-http = { module = "io.netty:netty-codec-http", version.ref = "netty" } + +bouncycastle = { module = "org.bouncycastle:bcpkix-jdk18on", version.ref = "bouncycastle" } + turbine = { module = "app.cash.turbine:turbine", version.ref = "turbine" } rsocket-java-core = { module = 'io.rsocket:rsocket-core', version.ref = "rsocket-java" } diff --git a/rsocket-internal-io/api/rsocket-internal-io.api b/rsocket-internal-io/api/rsocket-internal-io.api index fcdfae40..24cb3ef6 100644 --- a/rsocket-internal-io/api/rsocket-internal-io.api +++ b/rsocket-internal-io/api/rsocket-internal-io.api @@ -24,6 +24,8 @@ public final class io/rsocket/kotlin/internal/io/ChannelsKt { public final class io/rsocket/kotlin/internal/io/ContextKt { public static final fun childContext (Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext; + public static final fun invokeOnCancellation (Lkotlinx/coroutines/CoroutineScope;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun invokeOnCancellation$default (Lkotlinx/coroutines/CoroutineScope;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V public static final fun supervisorContext (Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext; } diff --git a/rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt b/rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt index 295caa8c..ce6c4222 100644 --- a/rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt +++ b/rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt @@ -21,3 +21,23 @@ import kotlin.coroutines.* public fun CoroutineContext.supervisorContext(): CoroutineContext = plus(SupervisorJob(get(Job))) public fun CoroutineContext.childContext(): CoroutineContext = plus(Job(get(Job))) + +public inline fun CoroutineScope.invokeOnCancellation( + context: CoroutineContext = EmptyCoroutineContext, + crossinline block: suspend () -> Unit, +) { + launch(context) { + try { + awaitCancellation() + } catch (cause: Throwable) { + withContext(NonCancellable) { + try { + block() + } catch (suppressed: Throwable) { + cause.addSuppressed(suppressed) + } + } + throw cause + } + } +} diff --git a/rsocket-transports/netty-tcp/api/rsocket-transport-netty-tcp.api b/rsocket-transports/netty-tcp/api/rsocket-transport-netty-tcp.api new file mode 100644 index 00000000..f3f40991 --- /dev/null +++ b/rsocket-transports/netty-tcp/api/rsocket-transport-netty-tcp.api @@ -0,0 +1,47 @@ +public abstract interface class io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTarget : io/rsocket/kotlin/transport/RSocketClientTarget { + public abstract fun getRemoteAddress ()Ljava/net/SocketAddress; +} + +public abstract interface class io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport$Factory; + public fun target (Ljava/lang/String;I)Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTarget; +} + +public final class io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun bootstrap (Lkotlin/jvm/functions/Function1;)V + public abstract fun channel (Lkotlin/reflect/KClass;)V + public abstract fun channelFactory (Lio/netty/channel/ChannelFactory;)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun ssl (Lkotlin/jvm/functions/Function1;)V +} + +public abstract interface class io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerInstance : io/rsocket/kotlin/transport/RSocketServerInstance { + public abstract fun getLocalAddress ()Ljava/net/InetSocketAddress; +} + +public abstract interface class io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTarget : io/rsocket/kotlin/transport/RSocketServerTarget { + public abstract fun getLocalAddress ()Ljava/net/InetSocketAddress; +} + +public abstract interface class io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport$Factory; + public fun target ()Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTarget; + public fun target (Ljava/lang/String;I)Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport;Ljava/lang/String;IILjava/lang/Object;)Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTarget; +} + +public final class io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun bootstrap (Lkotlin/jvm/functions/Function1;)V + public abstract fun channel (Lkotlin/reflect/KClass;)V + public abstract fun channelFactory (Lio/netty/channel/ChannelFactory;)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun ssl (Lkotlin/jvm/functions/Function1;)V +} + diff --git a/rsocket-transports/netty-tcp/build.gradle.kts b/rsocket-transports/netty-tcp/build.gradle.kts new file mode 100644 index 00000000..8f945dbc --- /dev/null +++ b/rsocket-transports/netty-tcp/build.gradle.kts @@ -0,0 +1,39 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import rsocketbuild.* + +plugins { + id("rsocketbuild.multiplatform-library") +} + +description = "rsocket-kotlin Netty TCP client/server transport implementation" + +kotlin { + jvmTarget() + + sourceSets { + jvmMain.dependencies { + implementation(projects.rsocketInternalIo) + api(projects.rsocketCore) + api(libs.netty.handler) + } + jvmTest.dependencies { + implementation(projects.rsocketTransportTests) + implementation(libs.bouncycastle) + } + } +} diff --git a/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpChannelHandler.kt b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpChannelHandler.kt new file mode 100644 index 00000000..b6708537 --- /dev/null +++ b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpChannelHandler.kt @@ -0,0 +1,96 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.tcp + +import io.ktor.utils.io.core.* +import io.netty.buffer.* +import io.netty.channel.* +import io.netty.channel.socket.* +import io.netty.handler.codec.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.channels.* +import kotlinx.coroutines.channels.Channel +import java.net.* +import kotlin.coroutines.* + +internal class NettyTcpChannelHandler( + private val sslContext: SslContext?, + private val remoteAddress: SocketAddress?, +) : ChannelInitializer() { + private val frames = channelForCloseable(Channel.UNLIMITED) + + @RSocketTransportApi + fun connect( + context: CoroutineContext, + channel: DuplexChannel, + ): NettyTcpSession = NettyTcpSession( + coroutineContext = context, + channel = channel, + frames = frames + ) + + override fun initChannel(ch: DuplexChannel): Unit = with(ch.pipeline()) { + if (sslContext != null) { + val sslHandler = if ( + remoteAddress is InetSocketAddress && + ch.parent() == null // not server + ) { + sslContext.newHandler(ch.alloc(), remoteAddress.hostName, remoteAddress.port) + } else { + sslContext.newHandler(ch.alloc()) + } + addLast("ssl", sslHandler) + } + addLast( + "rsocket-length-encoder", + LengthFieldPrepender( + /* lengthFieldLength = */ 3 + ) + ) + addLast( + "rsocket-length-decoder", + LengthFieldBasedFrameDecoder( + /* maxFrameLength = */ Int.MAX_VALUE, + /* lengthFieldOffset = */ 0, + /* lengthFieldLength = */ 3, + /* lengthAdjustment = */ 0, + /* initialBytesToStrip = */ 3 + ) + ) + addLast( + "rsocket-frame-receiver", + IncomingFramesChannelHandler(frames) + ) + } + + private class IncomingFramesChannelHandler( + private val channel: SendChannel, + ) : SimpleChannelInboundHandler() { + override fun channelInactive(ctx: ChannelHandlerContext) { + channel.close() //TODO? + super.channelInactive(ctx) + } + + override fun channelRead0(ctx: ChannelHandlerContext, msg: ByteBuf) { + channel.trySend(buildPacket { + writeFully(msg.nioBuffer()) + }).getOrThrow() + } + } +} diff --git a/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport.kt b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport.kt new file mode 100644 index 00000000..35068180 --- /dev/null +++ b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport.kt @@ -0,0 +1,157 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.tcp + +import io.netty.bootstrap.* +import io.netty.channel.* +import io.netty.channel.ChannelFactory +import io.netty.channel.nio.* +import io.netty.channel.socket.* +import io.netty.channel.socket.nio.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import java.net.* +import kotlin.coroutines.* +import kotlin.reflect.* + +public sealed interface NettyTcpClientTarget : RSocketClientTarget { + public val remoteAddress: SocketAddress +} + +public sealed interface NettyTcpClientTransport : RSocketTransport< + InetSocketAddress, + NettyTcpClientTarget> { + + public fun target(hostname: String, port: Int): NettyTcpClientTarget = target(InetSocketAddress(hostname, port)) + + public companion object Factory : RSocketTransportFactory< + InetSocketAddress, + NettyTcpClientTarget, + NettyTcpClientTransport, + NettyTcpClientTransportBuilder>(::NettyTcpClientTransportBuilderImpl) +} + +public sealed interface NettyTcpClientTransportBuilder : RSocketTransportBuilder< + InetSocketAddress, + NettyTcpClientTarget, + NettyTcpClientTransport> { + + public fun channel(cls: KClass) + public fun channelFactory(factory: ChannelFactory) + public fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) + + public fun bootstrap(block: Bootstrap.() -> Unit) + public fun ssl(block: SslContextBuilder.() -> Unit) +} + +private class NettyTcpClientTransportBuilderImpl : NettyTcpClientTransportBuilder { + private var channelFactory: ChannelFactory? = null + private var eventLoopGroup: EventLoopGroup? = null + private var manageEventLoopGroup: Boolean = false + private var bootstrap: (Bootstrap.() -> Unit)? = null + private var ssl: (SslContextBuilder.() -> Unit)? = null + + override fun channel(cls: KClass) { + this.channelFactory = ReflectiveChannelFactory(cls.java) + } + + override fun channelFactory(factory: ChannelFactory) { + this.channelFactory = factory + } + + override fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) { + this.eventLoopGroup = group + this.manageEventLoopGroup = manage + } + + override fun bootstrap(block: Bootstrap.() -> Unit) { + bootstrap = block + } + + override fun ssl(block: SslContextBuilder.() -> Unit) { + ssl = block + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): NettyTcpClientTransport { + val group = eventLoopGroup ?: NioEventLoopGroup() + val factory = channelFactory ?: ReflectiveChannelFactory(NioSocketChannel::class.java) + + val transportContext = context.supervisorContext() + group.asCoroutineDispatcher() + if (manageEventLoopGroup) CoroutineScope(transportContext).invokeOnCancellation { + group.shutdownGracefully().awaitFuture() + } + + val sslContext = ssl?.let { + SslContextBuilder + .forClient() + .apply(it) + .build() + } + + val bootstrap = Bootstrap().apply { + bootstrap?.invoke(this) + group(group) + channelFactory(factory) + } + + return NettyTcpClientTransportImpl( + coroutineContext = transportContext, + sslContext = sslContext, + bootstrap = bootstrap + ) + } +} + +private class NettyTcpClientTransportImpl( + override val coroutineContext: CoroutineContext, + private val sslContext: SslContext?, + private val bootstrap: Bootstrap, +) : NettyTcpClientTransport { + override fun target(address: InetSocketAddress): NettyTcpClientTarget = NettyTcpClientTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + remoteAddress = address, + sslContext = sslContext, + bootstrap = bootstrap + ) +} + +private class NettyTcpClientTargetImpl( + override val coroutineContext: CoroutineContext, + override val remoteAddress: SocketAddress, + private val sslContext: SslContext?, + private val bootstrap: Bootstrap, +) : NettyTcpClientTarget { + @RSocketTransportApi + override suspend fun createSession(): RSocketTransportSession { + ensureActive() + + val handler = NettyTcpChannelHandler( + sslContext = sslContext, + remoteAddress = remoteAddress + ) + val future = bootstrap.clone().apply { + handler(handler) + }.connect(remoteAddress) + + future.awaitFuture() + + return handler.connect(coroutineContext.childContext(), future.channel() as DuplexChannel) + } +} diff --git a/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport.kt b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport.kt new file mode 100644 index 00000000..039ba11c --- /dev/null +++ b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport.kt @@ -0,0 +1,215 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.tcp + +import io.netty.bootstrap.* +import io.netty.channel.* +import io.netty.channel.ChannelFactory +import io.netty.channel.nio.* +import io.netty.channel.socket.* +import io.netty.channel.socket.nio.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import java.net.* +import javax.net.ssl.* +import kotlin.coroutines.* +import kotlin.reflect.* + +public sealed interface NettyTcpServerInstance : RSocketServerInstance { + public val localAddress: InetSocketAddress +} + +public sealed interface NettyTcpServerTarget : RSocketServerTarget { + public val localAddress: InetSocketAddress? +} + +public sealed interface NettyTcpServerTransport : RSocketTransport< + InetSocketAddress?, + NettyTcpServerTarget> { + + public fun target(): NettyTcpServerTarget = target(null) + public fun target(hostname: String = "0.0.0.0", port: Int = 0): NettyTcpServerTarget = target(InetSocketAddress(hostname, port)) + + public companion object Factory : RSocketTransportFactory< + InetSocketAddress?, + NettyTcpServerTarget, + NettyTcpServerTransport, + NettyTcpServerTransportBuilder>(::NettyTcpServerTransportBuilderImpl) +} + +public sealed interface NettyTcpServerTransportBuilder : + RSocketTransportBuilder< + InetSocketAddress?, + NettyTcpServerTarget, + NettyTcpServerTransport> { + + public fun channel(cls: KClass) + public fun channelFactory(factory: ChannelFactory) + public fun eventLoopGroup(parentGroup: EventLoopGroup, childGroup: EventLoopGroup, manage: Boolean) + public fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) + + public fun bootstrap(block: ServerBootstrap.() -> Unit) + public fun ssl(block: SslContextBuilder.() -> Unit) +} + +private class NettyTcpServerTransportBuilderImpl : NettyTcpServerTransportBuilder { + private var channelFactory: ChannelFactory? = null + private var parentEventLoopGroup: EventLoopGroup? = null + private var childEventLoopGroup: EventLoopGroup? = null + private var manageEventLoopGroup: Boolean = false + private var bootstrap: (ServerBootstrap.() -> Unit)? = null + private var ssl: (SslContextBuilder.() -> Unit)? = null + + override fun channel(cls: KClass) { + this.channelFactory = ReflectiveChannelFactory(cls.java) + } + + override fun channelFactory(factory: ChannelFactory) { + this.channelFactory = factory + } + + override fun eventLoopGroup(parentGroup: EventLoopGroup, childGroup: EventLoopGroup, manage: Boolean) { + this.parentEventLoopGroup = parentGroup + this.childEventLoopGroup = childGroup + this.manageEventLoopGroup = manage + } + + override fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) { + this.parentEventLoopGroup = group + this.childEventLoopGroup = group + this.manageEventLoopGroup = manage + } + + override fun bootstrap(block: ServerBootstrap.() -> Unit) { + bootstrap = block + } + + override fun ssl(block: SslContextBuilder.() -> Unit) { + ssl = block + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): NettyTcpServerTransport { + val parentGroup = parentEventLoopGroup ?: NioEventLoopGroup() + val childGroup = childEventLoopGroup ?: NioEventLoopGroup() + val factory = channelFactory ?: ReflectiveChannelFactory(NioServerSocketChannel::class.java) + + val transportContext = context.supervisorContext() + childGroup.asCoroutineDispatcher() + if (manageEventLoopGroup) CoroutineScope(transportContext).invokeOnCancellation { + childGroup.shutdownGracefully().awaitFuture() + parentGroup.shutdownGracefully().awaitFuture() + } + + val sslContext = ssl?.let { + SslContextBuilder + .forServer(KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())) + .apply(it) + .build() + } + + val bootstrap = ServerBootstrap().apply { + bootstrap?.invoke(this) + group(parentGroup, childGroup) + channelFactory(factory) + } + + return NettyTcpServerTransportImpl( + coroutineContext = transportContext, + bootstrap = bootstrap, + sslContext = sslContext + ) + } +} + +private class NettyTcpServerTransportImpl( + override val coroutineContext: CoroutineContext, + private val bootstrap: ServerBootstrap, + private val sslContext: SslContext?, +) : NettyTcpServerTransport { + override fun target(address: InetSocketAddress?): NettyTcpServerTarget { + return NettyTcpServerTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + localAddress = address, + bootstrap = bootstrap, + sslContext = sslContext + ) + } +} + +private class NettyTcpServerTargetImpl( + override val coroutineContext: CoroutineContext, + override val localAddress: InetSocketAddress?, + private val bootstrap: ServerBootstrap, + private val sslContext: SslContext?, +) : NettyTcpServerTarget { + @RSocketTransportApi + override suspend fun startServer(acceptor: RSocketServerAcceptor): NettyTcpServerInstance { + ensureActive() + + val instanceContext = coroutineContext.supervisorContext() + try { + val future = bootstrap.clone().apply { + childHandler(AcceptorChannelHandler(instanceContext, sslContext, acceptor)) + }.bind(localAddress ?: InetSocketAddress(0)) + + try { + future.awaitFuture() + } catch (cause: Throwable) { + instanceContext.job.cancel("Failed to bind", cause) + throw cause + } + + return NettyTcpServerInstanceImpl( + coroutineContext = instanceContext, + channel = future.channel() as ServerChannel + ) + } catch (cause: Throwable) { + instanceContext.job.cancel("Failed to bind", cause) + throw cause + } + } +} + +@RSocketTransportApi +private class AcceptorChannelHandler( + override val coroutineContext: CoroutineContext, + private val sslContext: SslContext?, + private val acceptor: RSocketServerAcceptor, +) : ChannelInitializer(), CoroutineScope { + override fun initChannel(ch: DuplexChannel) { + val handler = NettyTcpChannelHandler( + sslContext = sslContext, + remoteAddress = null + ) + ch.pipeline().addLast(handler) + val connection = handler.connect(coroutineContext.childContext(), ch) + launch { acceptor.acceptSession(connection) } + } +} + +private class NettyTcpServerInstanceImpl( + override val coroutineContext: CoroutineContext, + private val channel: ServerChannel, +) : NettyTcpServerInstance { + override val localAddress: InetSocketAddress get() = channel.localAddress() as InetSocketAddress + + init { + linkCompletionWith(channel) + } +} diff --git a/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpSession.kt b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpSession.kt new file mode 100644 index 00000000..f5209872 --- /dev/null +++ b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpSession.kt @@ -0,0 +1,44 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.tcp + +import io.ktor.utils.io.core.* +import io.netty.buffer.* +import io.netty.channel.socket.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.channels.* +import kotlin.coroutines.* + +@RSocketTransportApi +internal class NettyTcpSession( + override val coroutineContext: CoroutineContext, + private val channel: DuplexChannel, + private val frames: ReceiveChannel, +) : RSocketTransportSession.Sequential { + + init { + linkCompletionWith(channel) + } + + override suspend fun sendFrame(frame: ByteReadPacket) { + channel.writeAndFlush(Unpooled.wrappedBuffer(frame.readByteBuffer())).awaitFuture() + } + + override suspend fun receiveFrame(): ByteReadPacket { + return frames.receive() + } +} diff --git a/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/utils.kt b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/utils.kt new file mode 100644 index 00000000..be54b620 --- /dev/null +++ b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/utils.kt @@ -0,0 +1,45 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.tcp + +import io.netty.channel.* +import io.netty.util.concurrent.* +import io.rsocket.kotlin.internal.io.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +@Suppress("UNCHECKED_CAST") +internal suspend inline fun Future.awaitFuture(): T = suspendCancellableCoroutine { cont -> + addListener { + when { + it.isSuccess -> cont.resume(it.now as T) + else -> cont.resumeWithException(it.cause()) + } + } + cont.invokeOnCancellation { + cancel(true) + } +} + +internal fun CoroutineScope.linkCompletionWith(channel: Channel) { + channel.closeFuture().addListener { + cancel("Netty channel closed", it.cause()) + } + invokeOnCancellation { + channel.close().awaitFuture() + } +} diff --git a/rsocket-transports/netty-tcp/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpTransportTest.kt b/rsocket-transports/netty-tcp/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpTransportTest.kt new file mode 100644 index 00000000..f1b43e9c --- /dev/null +++ b/rsocket-transports/netty-tcp/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpTransportTest.kt @@ -0,0 +1,65 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.tcp + +import io.netty.channel.nio.* +import io.netty.handler.ssl.util.* +import io.rsocket.kotlin.transport.tests.* +import kotlin.concurrent.* + +private val eventLoop = NioEventLoopGroup().also { + Runtime.getRuntime().addShutdownHook(thread(start = false) { + it.shutdownGracefully().await(1000) + }) +} +private val certificates = SelfSignedCertificate() + +class NettyTcpTransportTest : TransportTest() { + override suspend fun before() { + val server = startServer( + NettyTcpServerTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + }.target() + ) + client = connectClient( + NettyTcpClientTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + }.target(server.localAddress) + ) + } +} + +class NettyTcpSslTransportTest : TransportTest() { + override suspend fun before() { + val server = startServer( + NettyTcpServerTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + ssl { + keyManager(certificates.certificate(), certificates.privateKey()) + } + }.target() + ) + client = connectClient( + NettyTcpClientTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + ssl { + trustManager(InsecureTrustManagerFactory.INSTANCE) + } + }.target(server.localAddress) + ) + } +} diff --git a/rsocket-transports/netty-websocket/api/rsocket-transport-netty-websocket.api b/rsocket-transports/netty-websocket/api/rsocket-transport-netty-websocket.api new file mode 100644 index 00000000..5dce871e --- /dev/null +++ b/rsocket-transports/netty-websocket/api/rsocket-transport-netty-websocket.api @@ -0,0 +1,51 @@ +public abstract interface class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTarget : io/rsocket/kotlin/transport/RSocketClientTarget { + public abstract fun getConfig ()Lio/netty/handler/codec/http/websocketx/WebSocketClientProtocolConfig; +} + +public abstract interface class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport$Factory; + public fun target (Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTarget; +} + +public final class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun bootstrap (Lkotlin/jvm/functions/Function1;)V + public abstract fun channel (Lkotlin/reflect/KClass;)V + public abstract fun channelFactory (Lio/netty/channel/ChannelFactory;)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun ssl (Lkotlin/jvm/functions/Function1;)V + public abstract fun webSockets (Lkotlin/jvm/functions/Function1;)V +} + +public abstract interface class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerInstance : io/rsocket/kotlin/transport/RSocketServerInstance { + public abstract fun getConfig ()Lio/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig; + public abstract fun getLocalAddress ()Ljava/net/InetSocketAddress; +} + +public abstract interface class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTarget : io/rsocket/kotlin/transport/RSocketServerTarget { + public abstract fun getConfig ()Lio/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig; + public abstract fun getLocalAddress ()Ljava/net/InetSocketAddress; +} + +public abstract interface class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport$Factory; + public fun target ()Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTarget; + public fun target (Ljava/lang/String;I)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport;Ljava/lang/String;IILjava/lang/Object;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTarget; +} + +public final class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun bootstrap (Lkotlin/jvm/functions/Function1;)V + public abstract fun channel (Lkotlin/reflect/KClass;)V + public abstract fun channelFactory (Lio/netty/channel/ChannelFactory;)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun ssl (Lkotlin/jvm/functions/Function1;)V + public abstract fun webSockets (Lkotlin/jvm/functions/Function1;)V +} + diff --git a/rsocket-transports/netty-websocket/build.gradle.kts b/rsocket-transports/netty-websocket/build.gradle.kts new file mode 100644 index 00000000..7bc03ca2 --- /dev/null +++ b/rsocket-transports/netty-websocket/build.gradle.kts @@ -0,0 +1,40 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import rsocketbuild.* + +plugins { + id("rsocketbuild.multiplatform-library") +} + +description = "rsocket-kotlin Netty WebSocket client/server transport implementation" + +kotlin { + jvmTarget() + + sourceSets { + jvmMain.dependencies { + implementation(projects.rsocketInternalIo) + api(projects.rsocketCore) + api(libs.netty.handler) + api(libs.netty.codec.http) + } + jvmTest.dependencies { + implementation(projects.rsocketTransportTests) + implementation(libs.bouncycastle) + } + } +} diff --git a/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketChannelHandler.kt b/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketChannelHandler.kt new file mode 100644 index 00000000..175145d9 --- /dev/null +++ b/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketChannelHandler.kt @@ -0,0 +1,104 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.ktor.utils.io.core.* +import io.netty.channel.* +import io.netty.channel.socket.* +import io.netty.handler.codec.http.* +import io.netty.handler.codec.http.websocketx.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.Channel +import java.net.* +import kotlin.coroutines.* + + +internal class NettyWebSocketChannelHandler( + private val sslContext: SslContext?, + private val remoteAddress: SocketAddress?, + private val httpHandler: ChannelHandler, + private val webSocketHandler: ChannelHandler, +) : ChannelInitializer() { + private val frames = channelForCloseable(Channel.UNLIMITED) + private val handshakeDeferred = CompletableDeferred() + + @RSocketTransportApi + suspend fun connect( + context: CoroutineContext, + channel: DuplexChannel, + ): NettyWebSocketSession { + handshakeDeferred.await() + + return NettyWebSocketSession( + coroutineContext = context.childContext(), + channel = channel, + frames = frames + ) + } + + override fun initChannel(ch: DuplexChannel): Unit = with(ch.pipeline()) { + if (sslContext != null) { + val sslHandler = if ( + remoteAddress is InetSocketAddress && + ch.parent() == null // not server + ) { + sslContext.newHandler(ch.alloc(), remoteAddress.hostName, remoteAddress.port) + } else { + sslContext.newHandler(ch.alloc()) + } + addLast("ssl", sslHandler) + } + addLast("http", httpHandler) + addLast(HttpObjectAggregator(65536)) //TODO size? + addLast("websocket", webSocketHandler) + + addLast( + "rsocket-frame-receiver", + IncomingFramesChannelHandler() + ) + } + + private inner class IncomingFramesChannelHandler : SimpleChannelInboundHandler() { + override fun channelInactive(ctx: ChannelHandlerContext) { + frames.close() //TODO? + super.channelInactive(ctx) + } + + override fun channelRead0(ctx: ChannelHandlerContext, msg: WebSocketFrame) { + if (msg !is BinaryWebSocketFrame && msg !is TextWebSocketFrame) { + error("wrong frame type") + } + + frames.trySend(buildPacket { + writeFully(msg.content().nioBuffer()) + }).getOrThrow() + } + + override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) { + if ( + evt is WebSocketServerProtocolHandler.HandshakeComplete || + evt == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE + ) { + handshakeDeferred.complete(Unit) + } + //TODO: handle timeout - ? + } + } +} diff --git a/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport.kt b/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport.kt new file mode 100644 index 00000000..818363b0 --- /dev/null +++ b/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport.kt @@ -0,0 +1,172 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.netty.bootstrap.* +import io.netty.channel.* +import io.netty.channel.ChannelFactory +import io.netty.channel.nio.* +import io.netty.channel.socket.* +import io.netty.channel.socket.nio.* +import io.netty.handler.codec.http.* +import io.netty.handler.codec.http.websocketx.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import java.net.* +import kotlin.coroutines.* +import kotlin.reflect.* + +public sealed interface NettyWebSocketClientTarget : RSocketClientTarget { + public val config: WebSocketClientProtocolConfig +} + +public sealed interface NettyWebSocketClientTransport : RSocketTransport< + WebSocketClientProtocolConfig, + NettyWebSocketClientTarget> { + + public fun target(target: WebSocketClientProtocolConfig.Builder.() -> Unit): NettyWebSocketClientTarget = + target(WebSocketClientProtocolConfig.newBuilder().apply(target).build()) + + public companion object Factory : RSocketTransportFactory< + WebSocketClientProtocolConfig, + NettyWebSocketClientTarget, + NettyWebSocketClientTransport, + NettyWebSocketClientTransportBuilder>(::NettyWebSocketClientTransportBuilderImpl) { + } +} + +public sealed interface NettyWebSocketClientTransportBuilder : + RSocketTransportBuilder { + + public fun channel(cls: KClass) + public fun channelFactory(factory: ChannelFactory) + public fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) + + public fun bootstrap(block: Bootstrap.() -> Unit) + public fun ssl(block: SslContextBuilder.() -> Unit) + public fun webSockets(block: WebSocketClientProtocolConfig.Builder.() -> Unit) +} + +private class NettyWebSocketClientTransportBuilderImpl : NettyWebSocketClientTransportBuilder { + private var channelFactory: ChannelFactory? = null + private var eventLoopGroup: EventLoopGroup? = null + private var manageEventLoopGroup: Boolean = false + private var bootstrap: (Bootstrap.() -> Unit)? = null + private var ssl: (SslContextBuilder.() -> Unit)? = null + private var webSockets: (WebSocketClientProtocolConfig.Builder.() -> Unit)? = null + + override fun channel(cls: KClass) { + this.channelFactory = ReflectiveChannelFactory(cls.java) + } + + override fun channelFactory(factory: ChannelFactory) { + this.channelFactory = factory + } + + override fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) { + this.eventLoopGroup = group + this.manageEventLoopGroup = manage + } + + override fun bootstrap(block: Bootstrap.() -> Unit) { + bootstrap = block + } + + override fun ssl(block: SslContextBuilder.() -> Unit) { + ssl = block + } + + override fun webSockets(block: WebSocketClientProtocolConfig.Builder.() -> Unit) { + webSockets = block + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): NettyWebSocketClientTransport { + val group = eventLoopGroup ?: NioEventLoopGroup() + val factory = channelFactory ?: ReflectiveChannelFactory(NioSocketChannel::class.java) + + val transportContext = context.supervisorContext() + group.asCoroutineDispatcher() + if (manageEventLoopGroup) CoroutineScope(transportContext).invokeOnCancellation { + group.shutdownGracefully().awaitFuture() + } + + val sslContext = ssl?.let { + SslContextBuilder + .forClient() + .apply(it) + .build() + } + + val bootstrap = Bootstrap().apply { + bootstrap?.invoke(this) + group(group) + channelFactory(factory) + } + + return NettyWebSocketClientTransportImpl( + coroutineContext = transportContext, + sslContext = sslContext, + bootstrap = bootstrap, + webSocketConfig = webSockets + ) + } +} + +private class NettyWebSocketClientTransportImpl( + override val coroutineContext: CoroutineContext, + private val sslContext: SslContext?, + private val bootstrap: Bootstrap, + private val webSocketConfig: (WebSocketClientProtocolConfig.Builder.() -> Unit)?, +) : NettyWebSocketClientTransport { + override fun target(address: WebSocketClientProtocolConfig): NettyWebSocketClientTarget = NettyWebSocketClientTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + config = when (webSocketConfig) { + null -> address + else -> address.toBuilder().apply(webSocketConfig).build() + }, + sslContext = sslContext, + bootstrap = bootstrap, + ) +} + +private class NettyWebSocketClientTargetImpl( + override val coroutineContext: CoroutineContext, + override val config: WebSocketClientProtocolConfig, + private val sslContext: SslContext?, + private val bootstrap: Bootstrap, +) : NettyWebSocketClientTarget { + + @RSocketTransportApi + override suspend fun createSession(): RSocketTransportSession { + val remoteAddress = InetSocketAddress(config.webSocketUri().host, config.webSocketUri().port) + val handler = NettyWebSocketChannelHandler( + sslContext = sslContext, + remoteAddress = remoteAddress, + httpHandler = HttpClientCodec(), + webSocketHandler = WebSocketClientProtocolHandler(config), + ) + val future = bootstrap.clone().apply { + handler(handler) + }.connect(remoteAddress) + + future.awaitFuture() + + return handler.connect(coroutineContext, future.channel() as DuplexChannel) + } +} diff --git a/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport.kt b/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport.kt new file mode 100644 index 00000000..b72016de --- /dev/null +++ b/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport.kt @@ -0,0 +1,230 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.netty.bootstrap.* +import io.netty.channel.* +import io.netty.channel.ChannelFactory +import io.netty.channel.nio.* +import io.netty.channel.socket.* +import io.netty.channel.socket.nio.* +import io.netty.handler.codec.http.* +import io.netty.handler.codec.http.websocketx.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import java.net.* +import javax.net.ssl.* +import kotlin.coroutines.* +import kotlin.reflect.* + +public sealed interface NettyWebSocketServerInstance : RSocketServerInstance { + public val localAddress: InetSocketAddress + public val config: WebSocketServerProtocolConfig +} + +public sealed interface NettyWebSocketServerTarget : RSocketServerTarget { + public val localAddress: InetSocketAddress? + public val config: WebSocketServerProtocolConfig +} + +public sealed interface NettyWebSocketServerTransport : RSocketTransport< + InetSocketAddress?, + NettyWebSocketServerTarget> { + + public fun target(): NettyWebSocketServerTarget = target(null) + public fun target(hostname: String = "0.0.0.0", port: Int = 0): NettyWebSocketServerTarget = target(InetSocketAddress(hostname, port)) + + public companion object Factory : RSocketTransportFactory< + InetSocketAddress?, + NettyWebSocketServerTarget, + NettyWebSocketServerTransport, + NettyWebSocketServerTransportBuilder>(::NettyWebSocketServerTransportBuilderImpl) +} + +public sealed interface NettyWebSocketServerTransportBuilder : + RSocketTransportBuilder { + + public fun channel(cls: KClass) + public fun channelFactory(factory: ChannelFactory) + public fun eventLoopGroup(parentGroup: EventLoopGroup, childGroup: EventLoopGroup, manage: Boolean) + public fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) + + public fun bootstrap(block: ServerBootstrap.() -> Unit) + public fun ssl(block: SslContextBuilder.() -> Unit) + public fun webSockets(block: WebSocketServerProtocolConfig.Builder.() -> Unit) +} + +private class NettyWebSocketServerTransportBuilderImpl : NettyWebSocketServerTransportBuilder { + private var channelFactory: ChannelFactory? = null + private var parentEventLoopGroup: EventLoopGroup? = null + private var childEventLoopGroup: EventLoopGroup? = null + private var manageEventLoopGroup: Boolean = false + private var bootstrap: (ServerBootstrap.() -> Unit)? = null + private var ssl: (SslContextBuilder.() -> Unit)? = null + private var webSockets: (WebSocketServerProtocolConfig.Builder.() -> Unit)? = null + + override fun channel(cls: KClass) { + this.channelFactory = ReflectiveChannelFactory(cls.java) + } + + override fun channelFactory(factory: ChannelFactory) { + this.channelFactory = factory + } + + override fun eventLoopGroup(parentGroup: EventLoopGroup, childGroup: EventLoopGroup, manage: Boolean) { + this.parentEventLoopGroup = parentGroup + this.childEventLoopGroup = childGroup + this.manageEventLoopGroup = manage + } + + override fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) { + this.parentEventLoopGroup = group + this.childEventLoopGroup = group + this.manageEventLoopGroup = manage + } + + override fun bootstrap(block: ServerBootstrap.() -> Unit) { + bootstrap = block + } + + override fun ssl(block: SslContextBuilder.() -> Unit) { + ssl = block + } + + override fun webSockets(block: WebSocketServerProtocolConfig.Builder.() -> Unit) { + webSockets = block + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): NettyWebSocketServerTransport { + val parentGroup = parentEventLoopGroup ?: NioEventLoopGroup() + val childGroup = childEventLoopGroup ?: NioEventLoopGroup() + val factory = channelFactory ?: ReflectiveChannelFactory(NioServerSocketChannel::class.java) + + val transportContext = context.supervisorContext() + childGroup.asCoroutineDispatcher() + if (manageEventLoopGroup) CoroutineScope(transportContext).invokeOnCancellation { + childGroup.shutdownGracefully().awaitFuture() + parentGroup.shutdownGracefully().awaitFuture() + } + + val sslContext = ssl?.let { + SslContextBuilder + .forServer(KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())) + .apply(it) + .build() + } + + val bootstrap = ServerBootstrap().apply { + bootstrap?.invoke(this) + group(parentGroup, childGroup) + channelFactory(factory) + } + + return NettyWebSocketServerTransportImpl( + coroutineContext = transportContext, + bootstrap = bootstrap, + sslContext = sslContext, + webSocketConfig = webSockets + ) + } +} + +private class NettyWebSocketServerTransportImpl( + override val coroutineContext: CoroutineContext, + private val bootstrap: ServerBootstrap, + private val sslContext: SslContext?, + private val webSocketConfig: (WebSocketServerProtocolConfig.Builder.() -> Unit)?, +) : NettyWebSocketServerTransport { + override fun target(address: InetSocketAddress?): NettyWebSocketServerTarget = NettyWebSocketServerTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + localAddress = address, + config = WebSocketServerProtocolConfig.newBuilder().also { webSocketConfig?.invoke(it) }.build(), + bootstrap = bootstrap, + sslContext = sslContext + ) +} + +private class NettyWebSocketServerTargetImpl( + override val coroutineContext: CoroutineContext, + override val localAddress: InetSocketAddress?, + override val config: WebSocketServerProtocolConfig, + private val bootstrap: ServerBootstrap, + private val sslContext: SslContext?, +) : NettyWebSocketServerTarget { + + @RSocketTransportApi + override suspend fun startServer(acceptor: RSocketServerAcceptor): NettyWebSocketServerInstance { + val instanceContext = coroutineContext.supervisorContext() + try { + val future = bootstrap.clone().apply { + localAddress(localAddress ?: InetSocketAddress(0)) + childHandler(AcceptorChannelHandler(instanceContext, sslContext, acceptor, config)) + }.bind() + + try { + future.awaitFuture() + } catch (cause: Throwable) { + instanceContext.job.cancel("Failed to bind", cause) + throw cause + } + + return NettyWebSocketServerInstanceImpl( + coroutineContext = instanceContext, + channel = future.channel() as ServerChannel, + config = config + ) + } catch (cause: Throwable) { + instanceContext.job.cancel("Failed to bind", cause) + throw cause + } + } +} + +@RSocketTransportApi +private class AcceptorChannelHandler( + override val coroutineContext: CoroutineContext, + private val sslContext: SslContext?, + private val acceptor: RSocketServerAcceptor, + private val config: WebSocketServerProtocolConfig, +) : ChannelInitializer(), CoroutineScope { + override fun initChannel(ch: DuplexChannel) { + val handler = NettyWebSocketChannelHandler( + sslContext = sslContext, + remoteAddress = null, + httpHandler = HttpServerCodec(), + webSocketHandler = WebSocketServerProtocolHandler(config) + ) + ch.pipeline().addLast(handler) + launch { + acceptor.acceptSession(handler.connect(coroutineContext, ch)) + } + } +} + +private class NettyWebSocketServerInstanceImpl( + override val coroutineContext: CoroutineContext, + private val channel: ServerChannel, + override val config: WebSocketServerProtocolConfig, +) : NettyWebSocketServerInstance { + override val localAddress: InetSocketAddress get() = channel.localAddress() as InetSocketAddress + + init { + linkCompletionWith(channel) + } +} diff --git a/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketSession.kt b/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketSession.kt new file mode 100644 index 00000000..903c3f8c --- /dev/null +++ b/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketSession.kt @@ -0,0 +1,45 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.ktor.utils.io.core.* +import io.netty.buffer.* +import io.netty.channel.socket.* +import io.netty.handler.codec.http.websocketx.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.channels.* +import kotlin.coroutines.* + +@RSocketTransportApi +internal class NettyWebSocketSession( + override val coroutineContext: CoroutineContext, + private val channel: DuplexChannel, + private val frames: ReceiveChannel, +) : RSocketTransportSession.Sequential { + + init { + linkCompletionWith(channel) + } + + override suspend fun sendFrame(frame: ByteReadPacket) { + channel.writeAndFlush(BinaryWebSocketFrame(Unpooled.wrappedBuffer(frame.readByteBuffer()))).awaitFuture() + } + + override suspend fun receiveFrame(): ByteReadPacket { + return frames.receive() + } +} diff --git a/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/utils.kt b/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/utils.kt new file mode 100644 index 00000000..58aa21bb --- /dev/null +++ b/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/utils.kt @@ -0,0 +1,45 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.netty.channel.* +import io.netty.util.concurrent.* +import io.rsocket.kotlin.internal.io.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +@Suppress("UNCHECKED_CAST") +internal suspend inline fun Future.awaitFuture(): T = suspendCancellableCoroutine { cont -> + addListener { + when { + it.isSuccess -> cont.resume(it.now as T) + else -> cont.resumeWithException(it.cause()) + } + } + cont.invokeOnCancellation { + cancel(true) + } +} + +internal fun CoroutineScope.linkCompletionWith(channel: Channel) { + channel.closeFuture().addListener { + cancel("Netty channel closed", it.cause()) + } + invokeOnCancellation { + channel.close().awaitFuture() + } +} diff --git a/rsocket-transports/netty-websocket/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketTransportTest.kt b/rsocket-transports/netty-websocket/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketTransportTest.kt new file mode 100644 index 00000000..f8fa7d96 --- /dev/null +++ b/rsocket-transports/netty-websocket/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketTransportTest.kt @@ -0,0 +1,69 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.netty.channel.nio.* +import io.netty.handler.ssl.util.* +import io.rsocket.kotlin.transport.tests.* +import kotlin.concurrent.* + +private val eventLoop = NioEventLoopGroup().also { + Runtime.getRuntime().addShutdownHook(thread(start = false) { + it.shutdownGracefully().await(1000) + }) +} +private val certificates = SelfSignedCertificate() + +class NettyWebSocketTransportTest : TransportTest() { + override suspend fun before() { + val server = startServer( + NettyWebSocketServerTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + }.target() + ) + client = connectClient( + NettyWebSocketClientTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + }.target { + webSocketUri("ws://localhost:${server.localAddress.port}") + } + ) + } +} + +class NettyWebSocketSslTransportTest : TransportTest() { + override suspend fun before() { + val server = startServer( + NettyWebSocketServerTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + ssl { + keyManager(certificates.certificate(), certificates.privateKey()) + } + }.target() + ) + client = connectClient( + NettyWebSocketClientTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + ssl { + trustManager(InsecureTrustManagerFactory.INSTANCE) + } + }.target { + webSocketUri("ws://localhost:${server.localAddress.port}") + } + ) + } +} diff --git a/settings.gradle.kts b/settings.gradle.kts index 511c69cf..cf90bce0 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -44,6 +44,9 @@ projects { module("ktor-websocket-server") module("ktor-websocket-internal") + module("netty-tcp") + module("netty-websocket") + module("nodejs-tcp") }