diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index d6006cc2..958f6442 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -9,6 +9,10 @@ kotlinx-bcv = "0.13.2" ktor = "2.3.6" +netty = "4.1.101.Final" + +bouncycastle = "1.77" + turbine = "1.0.0" rsocket-java = "1.1.3" @@ -44,6 +48,11 @@ ktor-server-cio = { module = "io.ktor:ktor-server-cio", version.ref = "ktor" } ktor-server-netty = { module = "io.ktor:ktor-server-netty", version.ref = "ktor" } ktor-server-jetty = { module = "io.ktor:ktor-server-jetty", version.ref = "ktor" } +netty-handler = { module = "io.netty:netty-handler", version.ref = "netty" } +netty-codec-http = { module = "io.netty:netty-codec-http", version.ref = "netty" } + +bouncycastle = { module = "org.bouncycastle:bcpkix-jdk18on", version.ref = "bouncycastle" } + turbine = { module = "app.cash.turbine:turbine", version.ref = "turbine" } rsocket-java-core = { module = 'io.rsocket:rsocket-core', version.ref = "rsocket-java" } diff --git a/rsocket-internal-io/api/rsocket-internal-io.api b/rsocket-internal-io/api/rsocket-internal-io.api index 4be03cf6..93037f89 100644 --- a/rsocket-internal-io/api/rsocket-internal-io.api +++ b/rsocket-internal-io/api/rsocket-internal-io.api @@ -5,6 +5,8 @@ public final class io/rsocket/kotlin/internal/io/ChannelsKt { public final class io/rsocket/kotlin/internal/io/ContextKt { public static final fun childContext (Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext; + public static final fun invokeOnCancellation (Lkotlinx/coroutines/CoroutineScope;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun invokeOnCancellation$default (Lkotlinx/coroutines/CoroutineScope;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V public static final fun supervisorContext (Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext; } diff --git a/rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt b/rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt index c6be8758..22d91ad8 100644 --- a/rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt +++ b/rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt @@ -21,3 +21,23 @@ import kotlin.coroutines.* public fun CoroutineContext.supervisorContext(): CoroutineContext = plus(SupervisorJob(get(Job))) public fun CoroutineContext.childContext(): CoroutineContext = plus(Job(get(Job))) + +public inline fun CoroutineScope.invokeOnCancellation( + context: CoroutineContext = EmptyCoroutineContext, + crossinline block: suspend () -> Unit, +) { + launch(context) { + try { + awaitCancellation() + } catch (cause: Throwable) { + withContext(NonCancellable) { + try { + block() + } catch (suppressed: Throwable) { + cause.addSuppressed(suppressed) + } + } + throw cause + } + } +} diff --git a/rsocket-transport-netty-tcp/api/rsocket-transport-netty-tcp.api b/rsocket-transport-netty-tcp/api/rsocket-transport-netty-tcp.api new file mode 100644 index 00000000..939602dd --- /dev/null +++ b/rsocket-transport-netty-tcp/api/rsocket-transport-netty-tcp.api @@ -0,0 +1,69 @@ +public abstract interface class io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport : io/rsocket/kotlin/transport/RSocketClientTransport { + public static final field Factory Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport$Factory; + public abstract fun getRemoteAddress ()Ljava/net/SocketAddress; +} + +public final class io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { + public final fun invoke (Lkotlin/coroutines/CoroutineContext;Ljava/lang/String;ILkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport; +} + +public abstract interface class io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder, io/rsocket/kotlin/transport/RSocketTransportEngineBuilder { + public abstract fun bootstrap (Lkotlin/jvm/functions/Function1;)V + public abstract fun channel (Lkotlin/reflect/KClass;)V + public abstract fun channelFactory (Lio/netty/channel/ChannelFactory;)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun ssl (Lkotlin/jvm/functions/Function1;)V +} + +public abstract interface class io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransportEngine : io/rsocket/kotlin/transport/RSocketTransportEngine { + public static final field Factory Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransportEngine$Factory; + public abstract fun createTransport (Ljava/lang/String;I)Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport; +} + +public final class io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransportEngine$DefaultImpls { + public static fun createTransport (Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransportEngine;Ljava/lang/String;I)Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport; +} + +public final class io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransportEngine$Factory : io/rsocket/kotlin/transport/RSocketTransportEngineFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerInstance : io/rsocket/kotlin/transport/RSocketServerInstance { + public abstract fun getLocalAddress ()Ljava/net/SocketAddress; +} + +public abstract interface class io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport : io/rsocket/kotlin/transport/RSocketServerTransport { + public static final field Factory Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport$Factory; + public abstract fun getLocalAddress ()Ljava/net/SocketAddress; +} + +public final class io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { + public final fun invoke (Lkotlin/coroutines/CoroutineContext;Ljava/lang/String;ILkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport; + public final fun invoke (Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport; + public static synthetic fun invoke$default (Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport$Factory;Lkotlin/coroutines/CoroutineContext;Ljava/lang/String;ILkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport; + public static synthetic fun invoke$default (Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport$Factory;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport; +} + +public abstract interface class io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder, io/rsocket/kotlin/transport/RSocketTransportEngineBuilder { + public abstract fun bootstrap (Lkotlin/jvm/functions/Function1;)V + public abstract fun channel (Lkotlin/reflect/KClass;)V + public abstract fun channelFactory (Lio/netty/channel/ChannelFactory;)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun ssl (Lkotlin/jvm/functions/Function1;)V +} + +public abstract interface class io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransportEngine : io/rsocket/kotlin/transport/RSocketTransportEngine { + public static final field Factory Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransportEngine$Factory; + public abstract fun createTransport ()Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport; + public abstract fun createTransport (Ljava/lang/String;I)Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport; +} + +public final class io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransportEngine$DefaultImpls { + public static fun createTransport (Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransportEngine;)Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport; + public static fun createTransport (Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransportEngine;Ljava/lang/String;I)Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport; + public static synthetic fun createTransport$default (Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransportEngine;Ljava/lang/String;IILjava/lang/Object;)Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport; +} + +public final class io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransportEngine$Factory : io/rsocket/kotlin/transport/RSocketTransportEngineFactory { +} + diff --git a/rsocket-transport-netty-tcp/build.gradle.kts b/rsocket-transport-netty-tcp/build.gradle.kts new file mode 100644 index 00000000..30732bc7 --- /dev/null +++ b/rsocket-transport-netty-tcp/build.gradle.kts @@ -0,0 +1,44 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import rsocketbuild.* + +plugins { + id("rsocketbuild.template.library") +} + +kotlin { + jvmTarget() + + sourceSets { + jvmMain { + dependencies { + implementation(projects.rsocketInternalIo) + + api(projects.rsocketCore) + api(libs.netty.handler) + } + } + jvmTest { + dependencies { + implementation(projects.rsocketTransportTests) + implementation(libs.bouncycastle) + } + } + } +} + +description = "rsocket-kotlin Netty TCP transport implementation" diff --git a/rsocket-transport-netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpChannelHandler.kt b/rsocket-transport-netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpChannelHandler.kt new file mode 100644 index 00000000..1d6dafbd --- /dev/null +++ b/rsocket-transport-netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpChannelHandler.kt @@ -0,0 +1,96 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.tcp + +import io.ktor.utils.io.core.* +import io.netty.buffer.* +import io.netty.channel.* +import io.netty.channel.socket.* +import io.netty.handler.codec.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.channels.* +import kotlinx.coroutines.channels.Channel +import java.net.* +import kotlin.coroutines.* + +internal class NettyTcpChannelHandler( + private val sslContext: SslContext?, + private val remoteAddress: SocketAddress?, +) : ChannelInitializer() { + private val frames = channelForCloseable(Channel.UNLIMITED) + + @RSocketTransportApi + fun connect( + context: CoroutineContext, + channel: DuplexChannel, + ): NettyTcpSession = NettyTcpSession( + coroutineContext = context, + channel = channel, + frames = frames + ) + + override fun initChannel(ch: DuplexChannel): Unit = with(ch.pipeline()) { + if (sslContext != null) { + val sslHandler = if ( + remoteAddress is InetSocketAddress && + ch.parent() == null // not server + ) { + sslContext.newHandler(ch.alloc(), remoteAddress.hostName, remoteAddress.port) + } else { + sslContext.newHandler(ch.alloc()) + } + addLast("ssl", sslHandler) + } + addLast( + "rsocket-length-encoder", + LengthFieldPrepender( + /* lengthFieldLength = */ 3 + ) + ) + addLast( + "rsocket-length-decoder", + LengthFieldBasedFrameDecoder( + /* maxFrameLength = */ Int.MAX_VALUE, + /* lengthFieldOffset = */ 0, + /* lengthFieldLength = */ 3, + /* lengthAdjustment = */ 0, + /* initialBytesToStrip = */ 3 + ) + ) + addLast( + "rsocket-frame-receiver", + IncomingFramesChannelHandler(frames) + ) + } + + private class IncomingFramesChannelHandler( + private val channel: SendChannel, + ) : SimpleChannelInboundHandler() { + override fun channelInactive(ctx: ChannelHandlerContext) { + channel.close() //TODO? + super.channelInactive(ctx) + } + + override fun channelRead0(ctx: ChannelHandlerContext, msg: ByteBuf) { + channel.trySend(buildPacket { + writeFully(msg.nioBuffer()) + }).getOrThrow() + } + } +} diff --git a/rsocket-transport-netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport.kt b/rsocket-transport-netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport.kt new file mode 100644 index 00000000..5afedbd4 --- /dev/null +++ b/rsocket-transport-netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport.kt @@ -0,0 +1,204 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.tcp + +import io.netty.bootstrap.* +import io.netty.channel.* +import io.netty.channel.ChannelFactory +import io.netty.channel.nio.* +import io.netty.channel.socket.* +import io.netty.channel.socket.nio.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import java.net.* +import kotlin.coroutines.* +import kotlin.reflect.* + +public sealed interface NettyTcpClientTransport : RSocketClientTransport { + public val remoteAddress: SocketAddress + + public companion object Factory : RSocketTransportFactory< + SocketAddress, + NettyTcpClientTransport, + NettyTcpClientTransportBuilder>(::NettyTcpClientTransportBuilderImpl) { + + public fun invoke( + context: CoroutineContext, + hostname: String, + port: Int, + block: NettyTcpClientTransportBuilder.() -> Unit, + ): NettyTcpClientTransport = invoke(context, InetSocketAddress(hostname, port), block) + } +} + +public sealed interface NettyTcpClientTransportEngine : RSocketTransportEngine { + public fun createTransport( + hostname: String, + port: Int, + ): NettyTcpClientTransport = createTransport(InetSocketAddress(hostname, port)) + + public companion object Factory : RSocketTransportEngineFactory< + SocketAddress, + NettyTcpClientTransport, + NettyTcpClientTransportEngine, + NettyTcpClientTransportBuilder>(::NettyTcpClientTransportBuilderImpl) +} + +public sealed interface NettyTcpClientTransportBuilder : + RSocketTransportBuilder, + RSocketTransportEngineBuilder { + + public fun channel(cls: KClass) + public fun channelFactory(factory: ChannelFactory) + public fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) + + public fun bootstrap(block: Bootstrap.() -> Unit) + public fun ssl(block: SslContextBuilder.() -> Unit) +} + +private class NettyTcpClientTransportBuilderImpl : NettyTcpClientTransportBuilder { + private var channelFactory: ChannelFactory? = null + private var eventLoopGroup: EventLoopGroup? = null + private var manageEventLoopGroup: Boolean = false + private var bootstrap: (Bootstrap.() -> Unit)? = null + private var ssl: (SslContextBuilder.() -> Unit)? = null + + override fun channel(cls: KClass) { + this.channelFactory = ReflectiveChannelFactory(cls.java) + } + + override fun channelFactory(factory: ChannelFactory) { + this.channelFactory = factory + } + + override fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) { + this.eventLoopGroup = group + this.manageEventLoopGroup = manage + } + + override fun bootstrap(block: Bootstrap.() -> Unit) { + bootstrap = block + } + + override fun ssl(block: SslContextBuilder.() -> Unit) { + ssl = block + } + + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext, target: SocketAddress): NettyTcpClientTransport { + return build(context).buildTransport(target) + } + + @RSocketTransportApi + override fun buildEngine(context: CoroutineContext): NettyTcpClientTransportEngine { + return build(context).buildEngine() + } + + private fun build(context: CoroutineContext): NettyTcpClientTransportResources { + val group = eventLoopGroup ?: NioEventLoopGroup() + val factory = channelFactory ?: ReflectiveChannelFactory(NioSocketChannel::class.java) + + val transportContext = context.supervisorContext() + group.asCoroutineDispatcher() + if (manageEventLoopGroup) CoroutineScope(transportContext).invokeOnCancellation { + group.shutdownGracefully().awaitFuture() + } + + val sslContext = ssl?.let { + SslContextBuilder + .forClient() + .apply(it) + .build() + } + + val bootstrap = Bootstrap().apply { + bootstrap?.invoke(this) + group(group) + channelFactory(factory) + } + + return NettyTcpClientTransportResources( + coroutineContext = transportContext, + sslContext = sslContext, + bootstrap = bootstrap + ) + } +} + +private class NettyTcpClientTransportResources( + private val coroutineContext: CoroutineContext, + private val sslContext: SslContext?, + private val bootstrap: Bootstrap, +) { + fun buildTransport(address: SocketAddress): NettyTcpClientTransport { + return NettyTcpClientTransportImpl( + coroutineContext = coroutineContext, + remoteAddress = address, + sslContext = sslContext, + bootstrap = bootstrap + ) + } + + fun buildEngine(): NettyTcpClientTransportEngine { + return NettyTcpClientTransportEngineImpl( + coroutineContext = coroutineContext, + sslContext = sslContext, + bootstrap = bootstrap + ) + } +} + +private class NettyTcpClientTransportEngineImpl( + override val coroutineContext: CoroutineContext, + private val sslContext: SslContext?, + private val bootstrap: Bootstrap, +) : NettyTcpClientTransportEngine { + override fun createTransport(target: SocketAddress): NettyTcpClientTransport { + return NettyTcpClientTransportImpl( + coroutineContext = coroutineContext.supervisorContext(), + remoteAddress = target, + sslContext = sslContext, + bootstrap = bootstrap + ) + } +} + +private class NettyTcpClientTransportImpl( + override val coroutineContext: CoroutineContext, + override val remoteAddress: SocketAddress, + private val sslContext: SslContext?, + private val bootstrap: Bootstrap, +) : NettyTcpClientTransport { + @RSocketTransportApi + override suspend fun createSession(): RSocketTransportSession { + ensureActive() + + val handler = NettyTcpChannelHandler( + sslContext = sslContext, + remoteAddress = remoteAddress + ) + val future = bootstrap.clone().apply { + handler(handler) + }.connect(remoteAddress) + + future.awaitFuture() + + return handler.connect(coroutineContext.childContext(), future.channel() as DuplexChannel) + } +} diff --git a/rsocket-transport-netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport.kt b/rsocket-transport-netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport.kt new file mode 100644 index 00000000..be076b36 --- /dev/null +++ b/rsocket-transport-netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport.kt @@ -0,0 +1,263 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.tcp + +import io.netty.bootstrap.* +import io.netty.channel.* +import io.netty.channel.ChannelFactory +import io.netty.channel.nio.* +import io.netty.channel.socket.* +import io.netty.channel.socket.nio.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import java.net.* +import javax.net.ssl.* +import kotlin.coroutines.* +import kotlin.reflect.* + +public sealed interface NettyTcpServerInstance : RSocketServerInstance { + public val localAddress: SocketAddress +} + +public sealed interface NettyTcpServerTransport : RSocketServerTransport { + public val localAddress: SocketAddress? + + public companion object Factory : RSocketTransportFactory< + SocketAddress?, + NettyTcpServerTransport, + NettyTcpServerTransportBuilder>(::NettyTcpServerTransportBuilderImpl) { + public operator fun invoke( + context: CoroutineContext, + hostname: String = "0.0.0.0", + port: Int = 0, + block: NettyTcpServerTransportBuilder.() -> Unit = {}, + ): NettyTcpServerTransport = invoke(context, InetSocketAddress(hostname, port), block) + + public operator fun invoke( + context: CoroutineContext, + block: NettyTcpServerTransportBuilder.() -> Unit = {}, + ): NettyTcpServerTransport = invoke(context, null, block) + } +} + +public sealed interface NettyTcpServerTransportEngine : RSocketTransportEngine { + public fun createTransport( + hostname: String = "0.0.0.0", + port: Int = 0, + ): NettyTcpServerTransport = createTransport(InetSocketAddress(hostname, port)) + + public fun createTransport(): NettyTcpServerTransport = createTransport(null) + + public companion object Factory : RSocketTransportEngineFactory< + SocketAddress?, + NettyTcpServerTransport, + NettyTcpServerTransportEngine, + NettyTcpServerTransportBuilder>(::NettyTcpServerTransportBuilderImpl) +} + +public sealed interface NettyTcpServerTransportBuilder : + RSocketTransportBuilder, + RSocketTransportEngineBuilder { + + public fun channel(cls: KClass) + public fun channelFactory(factory: ChannelFactory) + public fun eventLoopGroup(parentGroup: EventLoopGroup, childGroup: EventLoopGroup, manage: Boolean) + public fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) + + public fun bootstrap(block: ServerBootstrap.() -> Unit) + public fun ssl(block: SslContextBuilder.() -> Unit) +} + +private class NettyTcpServerTransportBuilderImpl : NettyTcpServerTransportBuilder { + private var channelFactory: ChannelFactory? = null + private var parentEventLoopGroup: EventLoopGroup? = null + private var childEventLoopGroup: EventLoopGroup? = null + private var manageEventLoopGroup: Boolean = false + private var bootstrap: (ServerBootstrap.() -> Unit)? = null + private var ssl: (SslContextBuilder.() -> Unit)? = null + + override fun channel(cls: KClass) { + this.channelFactory = ReflectiveChannelFactory(cls.java) + } + + override fun channelFactory(factory: ChannelFactory) { + this.channelFactory = factory + } + + override fun eventLoopGroup(parentGroup: EventLoopGroup, childGroup: EventLoopGroup, manage: Boolean) { + this.parentEventLoopGroup = parentGroup + this.childEventLoopGroup = childGroup + this.manageEventLoopGroup = manage + } + + override fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) { + this.parentEventLoopGroup = group + this.childEventLoopGroup = group + this.manageEventLoopGroup = manage + } + + override fun bootstrap(block: ServerBootstrap.() -> Unit) { + bootstrap = block + } + + override fun ssl(block: SslContextBuilder.() -> Unit) { + ssl = block + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext, target: SocketAddress?): NettyTcpServerTransport { + return build(context).buildTransport(target) + } + + @RSocketTransportApi + override fun buildEngine(context: CoroutineContext): NettyTcpServerTransportEngine { + return build(context).buildEngine() + } + + private fun build(context: CoroutineContext): NettyTcpServerTransportResources { + val parentGroup = parentEventLoopGroup ?: NioEventLoopGroup() + val childGroup = childEventLoopGroup ?: NioEventLoopGroup() + val factory = channelFactory ?: ReflectiveChannelFactory(NioServerSocketChannel::class.java) + + val transportContext = context.supervisorContext() + childGroup.asCoroutineDispatcher() + if (manageEventLoopGroup) CoroutineScope(transportContext).invokeOnCancellation { + childGroup.shutdownGracefully().awaitFuture() + parentGroup.shutdownGracefully().awaitFuture() + } + + val sslContext = ssl?.let { + SslContextBuilder + .forServer(KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())) + .apply(it) + .build() + } + + val bootstrap = ServerBootstrap().apply { + bootstrap?.invoke(this) + group(parentGroup, childGroup) + channelFactory(factory) + } + + return NettyTcpServerTransportResources( + coroutineContext = transportContext, + bootstrap = bootstrap, + sslContext = sslContext + ) + } +} + +private class NettyTcpServerTransportResources( + private val coroutineContext: CoroutineContext, + private val bootstrap: ServerBootstrap, + private val sslContext: SslContext?, +) { + fun buildTransport(address: SocketAddress?): NettyTcpServerTransport { + return NettyTcpServerTransportImpl( + coroutineContext = coroutineContext, + localAddress = address, + bootstrap = bootstrap, + sslContext = sslContext + ) + } + + fun buildEngine(): NettyTcpServerTransportEngine { + return NettyTcpServerTransportEngineImpl( + coroutineContext = coroutineContext, + bootstrap = bootstrap, + sslContext = sslContext + ) + } +} + +private class NettyTcpServerTransportEngineImpl( + override val coroutineContext: CoroutineContext, + private val bootstrap: ServerBootstrap, + private val sslContext: SslContext?, +) : NettyTcpServerTransportEngine { + override fun createTransport(target: SocketAddress?): NettyTcpServerTransport { + return NettyTcpServerTransportImpl( + coroutineContext = coroutineContext.supervisorContext(), + localAddress = target, + bootstrap = bootstrap, + sslContext = sslContext + ) + } +} + +private class NettyTcpServerTransportImpl( + override val coroutineContext: CoroutineContext, + override val localAddress: SocketAddress?, + private val bootstrap: ServerBootstrap, + private val sslContext: SslContext?, +) : NettyTcpServerTransport { + @RSocketTransportApi + override suspend fun startServer(acceptor: RSocketServerAcceptor): NettyTcpServerInstance { + ensureActive() + + val instanceContext = coroutineContext.supervisorContext() + try { + val future = bootstrap.clone().apply { + childHandler(AcceptorChannelHandler(instanceContext, sslContext, acceptor)) + }.bind(localAddress ?: InetSocketAddress(0)) + + try { + future.awaitFuture() + } catch (cause: Throwable) { + instanceContext.job.cancel("Failed to bind", cause) + throw cause + } + + return NettyTcpServerInstanceImpl( + coroutineContext = instanceContext, + channel = future.channel() as ServerChannel + ) + } catch (cause: Throwable) { + instanceContext.job.cancel("Failed to bind", cause) + throw cause + } + } +} + +@RSocketTransportApi +private class AcceptorChannelHandler( + override val coroutineContext: CoroutineContext, + private val sslContext: SslContext?, + private val acceptor: RSocketServerAcceptor, +) : ChannelInitializer(), CoroutineScope { + override fun initChannel(ch: DuplexChannel) { + val handler = NettyTcpChannelHandler( + sslContext = sslContext, + remoteAddress = null + ) + ch.pipeline().addLast(handler) + val connection = handler.connect(coroutineContext.childContext(), ch) + launch { acceptor.acceptSession(connection) } + } +} + +private class NettyTcpServerInstanceImpl( + override val coroutineContext: CoroutineContext, + private val channel: ServerChannel, +) : NettyTcpServerInstance { + override val localAddress: SocketAddress get() = channel.localAddress() + + init { + linkCompletionWith(channel) + } +} diff --git a/rsocket-transport-netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpSession.kt b/rsocket-transport-netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpSession.kt new file mode 100644 index 00000000..5a040b2a --- /dev/null +++ b/rsocket-transport-netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpSession.kt @@ -0,0 +1,44 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.tcp + +import io.ktor.utils.io.core.* +import io.netty.buffer.* +import io.netty.channel.socket.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.channels.* +import kotlin.coroutines.* + +@RSocketTransportApi +internal class NettyTcpSession( + override val coroutineContext: CoroutineContext, + private val channel: DuplexChannel, + private val frames: ReceiveChannel, +) : RSocketTransportSession.Sequential { + + init { + linkCompletionWith(channel) + } + + override suspend fun sendFrame(frame: ByteReadPacket) { + channel.writeAndFlush(Unpooled.wrappedBuffer(frame.readByteBuffer())).awaitFuture() + } + + override suspend fun receiveFrame(): ByteReadPacket { + return frames.receive() + } +} diff --git a/rsocket-transport-netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/utils.kt b/rsocket-transport-netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/utils.kt new file mode 100644 index 00000000..32378a99 --- /dev/null +++ b/rsocket-transport-netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/utils.kt @@ -0,0 +1,45 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.tcp + +import io.netty.channel.* +import io.netty.util.concurrent.* +import io.rsocket.kotlin.internal.io.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +@Suppress("UNCHECKED_CAST") +internal suspend inline fun Future.awaitFuture(): T = suspendCancellableCoroutine { cont -> + addListener { + when { + it.isSuccess -> cont.resume(it.now as T) + else -> cont.resumeWithException(it.cause()) + } + } + cont.invokeOnCancellation { + cancel(true) + } +} + +internal fun CoroutineScope.linkCompletionWith(channel: Channel) { + channel.closeFuture().addListener { + cancel("Netty channel closed", it.cause()) + } + invokeOnCancellation { + channel.close().awaitFuture() + } +} diff --git a/rsocket-transport-netty-tcp/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpTransportTest.kt b/rsocket-transport-netty-tcp/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpTransportTest.kt new file mode 100644 index 00000000..35c55cfd --- /dev/null +++ b/rsocket-transport-netty-tcp/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpTransportTest.kt @@ -0,0 +1,65 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.tcp + +import io.netty.channel.nio.* +import io.netty.handler.ssl.util.* +import io.rsocket.kotlin.transport.tests.* +import kotlin.concurrent.* + +private val eventLoop = NioEventLoopGroup().also { + Runtime.getRuntime().addShutdownHook(thread(start = false) { + it.shutdownGracefully().await(1000) + }) +} +private val certificates = SelfSignedCertificate() + +class NettyTcpTransportTest : TransportTest() { + override suspend fun before() { + val server = startServer( + NettyTcpServerTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + } + ) + client = connectClient( + NettyTcpClientTransport(testContext, server.localAddress) { + eventLoopGroup(eventLoop, manage = false) + } + ) + } +} + +class NettyTcpSslTransportTest : TransportTest() { + override suspend fun before() { + val server = startServer( + NettyTcpServerTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + ssl { + keyManager(certificates.certificate(), certificates.privateKey()) + } + } + ) + client = connectClient( + NettyTcpClientTransport(testContext, server.localAddress) { + eventLoopGroup(eventLoop, manage = false) + ssl { + trustManager(InsecureTrustManagerFactory.INSTANCE) + } + } + ) + } +} diff --git a/rsocket-transport-netty-websocket/api/rsocket-transport-netty-websocket.api b/rsocket-transport-netty-websocket/api/rsocket-transport-netty-websocket.api new file mode 100644 index 00000000..b12f49db --- /dev/null +++ b/rsocket-transport-netty-websocket/api/rsocket-transport-netty-websocket.api @@ -0,0 +1,87 @@ +public abstract interface class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport : io/rsocket/kotlin/transport/RSocketClientTransport { + public static final field Factory Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport$Factory; + public abstract fun getConfig ()Lio/netty/handler/codec/http/websocketx/WebSocketClientProtocolConfig; +} + +public final class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { + public final fun invoke (Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport; + public static synthetic fun invoke$default (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport$Factory;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport; +} + +public abstract interface class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder, io/rsocket/kotlin/transport/RSocketTransportEngineBuilder { + public abstract fun bootstrap (Lkotlin/jvm/functions/Function1;)V + public abstract fun channel (Lkotlin/reflect/KClass;)V + public abstract fun channelFactory (Lio/netty/channel/ChannelFactory;)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun ssl (Lkotlin/jvm/functions/Function1;)V + public abstract fun webSockets (Lkotlin/jvm/functions/Function1;)V +} + +public abstract interface class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransportEngine : io/rsocket/kotlin/transport/RSocketTransportEngine { + public static final field Factory Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransportEngine$Factory; + public abstract fun createTransport (Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport; +} + +public final class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransportEngine$DefaultImpls { + public static fun createTransport (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransportEngine;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport; +} + +public final class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransportEngine$Factory : io/rsocket/kotlin/transport/RSocketTransportEngineFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerInstance : io/rsocket/kotlin/transport/RSocketServerInstance { + public abstract fun getConfig ()Lio/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig; + public abstract fun getLocalAddress ()Ljava/net/SocketAddress; +} + +public final class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTarget { + public fun (Ljava/net/SocketAddress;Lio/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig;)V + public final fun getConfig ()Lio/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig; + public final fun getLocalAddress ()Ljava/net/SocketAddress; +} + +public abstract interface class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport : io/rsocket/kotlin/transport/RSocketServerTransport { + public static final field Factory Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport$Factory; + public abstract fun getConfig ()Lio/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig; + public abstract fun getLocalAddress ()Ljava/net/SocketAddress; +} + +public final class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { + public final fun invoke (Lkotlin/coroutines/CoroutineContext;Ljava/lang/String;ILio/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; + public final fun invoke (Lkotlin/coroutines/CoroutineContext;Ljava/lang/String;ILkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; + public final fun invoke (Lkotlin/coroutines/CoroutineContext;Ljava/net/SocketAddress;Lio/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; + public final fun invoke (Lkotlin/coroutines/CoroutineContext;Ljava/net/SocketAddress;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; + public static synthetic fun invoke$default (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport$Factory;Lkotlin/coroutines/CoroutineContext;Ljava/lang/String;ILio/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; + public static synthetic fun invoke$default (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport$Factory;Lkotlin/coroutines/CoroutineContext;Ljava/lang/String;ILkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; + public static synthetic fun invoke$default (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport$Factory;Lkotlin/coroutines/CoroutineContext;Ljava/net/SocketAddress;Lio/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; + public static synthetic fun invoke$default (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport$Factory;Lkotlin/coroutines/CoroutineContext;Ljava/net/SocketAddress;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; +} + +public abstract interface class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder, io/rsocket/kotlin/transport/RSocketTransportEngineBuilder { + public abstract fun bootstrap (Lkotlin/jvm/functions/Function1;)V + public abstract fun channel (Lkotlin/reflect/KClass;)V + public abstract fun channelFactory (Lio/netty/channel/ChannelFactory;)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun ssl (Lkotlin/jvm/functions/Function1;)V + public abstract fun webSockets (Lkotlin/jvm/functions/Function1;)V +} + +public abstract interface class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransportEngine : io/rsocket/kotlin/transport/RSocketTransportEngine { + public abstract fun createTransport (Ljava/lang/String;ILkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; + public abstract fun createTransport (Lkotlin/coroutines/CoroutineContext;Ljava/lang/String;ILio/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; + public abstract fun createTransport (Lkotlin/coroutines/CoroutineContext;Ljava/net/SocketAddress;Lio/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; + public abstract fun createTransport (Lkotlin/coroutines/CoroutineContext;Ljava/net/SocketAddress;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; +} + +public final class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransportEngine$DefaultImpls { + public static fun createTransport (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransportEngine;Ljava/lang/String;ILkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; + public static fun createTransport (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransportEngine;Lkotlin/coroutines/CoroutineContext;Ljava/lang/String;ILio/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; + public static fun createTransport (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransportEngine;Lkotlin/coroutines/CoroutineContext;Ljava/net/SocketAddress;Lio/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; + public static fun createTransport (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransportEngine;Lkotlin/coroutines/CoroutineContext;Ljava/net/SocketAddress;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; + public static synthetic fun createTransport$default (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransportEngine;Ljava/lang/String;ILkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; + public static synthetic fun createTransport$default (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransportEngine;Lkotlin/coroutines/CoroutineContext;Ljava/lang/String;ILio/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; + public static synthetic fun createTransport$default (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransportEngine;Lkotlin/coroutines/CoroutineContext;Ljava/net/SocketAddress;Lio/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; + public static synthetic fun createTransport$default (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransportEngine;Lkotlin/coroutines/CoroutineContext;Ljava/net/SocketAddress;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport; +} + diff --git a/rsocket-transport-netty-websocket/build.gradle.kts b/rsocket-transport-netty-websocket/build.gradle.kts new file mode 100644 index 00000000..a86a46c2 --- /dev/null +++ b/rsocket-transport-netty-websocket/build.gradle.kts @@ -0,0 +1,45 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import rsocketbuild.* + +plugins { + id("rsocketbuild.template.library") +} + +kotlin { + jvmTarget() + + sourceSets { + jvmMain { + dependencies { + implementation(projects.rsocketInternalIo) + + api(projects.rsocketCore) + api(libs.netty.handler) + api(libs.netty.codec.http) + } + } + jvmTest { + dependencies { + implementation(projects.rsocketTransportTests) + implementation(libs.bouncycastle) + } + } + } +} + +description = "rsocket-kotlin Netty WebSocket transport implementation" diff --git a/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketChannelHandler.kt b/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketChannelHandler.kt new file mode 100644 index 00000000..f277748b --- /dev/null +++ b/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketChannelHandler.kt @@ -0,0 +1,104 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.ktor.utils.io.core.* +import io.netty.channel.* +import io.netty.channel.socket.* +import io.netty.handler.codec.http.* +import io.netty.handler.codec.http.websocketx.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.Channel +import java.net.* +import kotlin.coroutines.* + + +internal class NettyWebSocketChannelHandler( + private val sslContext: SslContext?, + private val remoteAddress: SocketAddress?, + private val httpHandler: ChannelHandler, + private val webSocketHandler: ChannelHandler, +) : ChannelInitializer() { + private val frames = channelForCloseable(Channel.UNLIMITED) + private val handshakeDeferred = CompletableDeferred() + + @RSocketTransportApi + suspend fun connect( + context: CoroutineContext, + channel: DuplexChannel, + ): NettyWebSocketSession { + handshakeDeferred.await() + + return NettyWebSocketSession( + coroutineContext = context.childContext(), + channel = channel, + frames = frames + ) + } + + override fun initChannel(ch: DuplexChannel): Unit = with(ch.pipeline()) { + if (sslContext != null) { + val sslHandler = if ( + remoteAddress is InetSocketAddress && + ch.parent() == null // not server + ) { + sslContext.newHandler(ch.alloc(), remoteAddress.hostName, remoteAddress.port) + } else { + sslContext.newHandler(ch.alloc()) + } + addLast("ssl", sslHandler) + } + addLast("http", httpHandler) + addLast(HttpObjectAggregator(65536)) //TODO size? + addLast("websocket", webSocketHandler) + + addLast( + "rsocket-frame-receiver", + IncomingFramesChannelHandler() + ) + } + + private inner class IncomingFramesChannelHandler : SimpleChannelInboundHandler() { + override fun channelInactive(ctx: ChannelHandlerContext) { + frames.close() //TODO? + super.channelInactive(ctx) + } + + override fun channelRead0(ctx: ChannelHandlerContext, msg: WebSocketFrame) { + if (msg !is BinaryWebSocketFrame && msg !is TextWebSocketFrame) { + error("wrong frame type") + } + + frames.trySend(buildPacket { + writeFully(msg.content().nioBuffer()) + }).getOrThrow() + } + + override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) { + if ( + evt is WebSocketServerProtocolHandler.HandshakeComplete || + evt == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE + ) { + handshakeDeferred.complete(Unit) + } + //TODO: handle timeout - ? + } + } +} diff --git a/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport.kt b/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport.kt new file mode 100644 index 00000000..34191f1d --- /dev/null +++ b/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport.kt @@ -0,0 +1,225 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.netty.bootstrap.* +import io.netty.channel.* +import io.netty.channel.ChannelFactory +import io.netty.channel.nio.* +import io.netty.channel.socket.* +import io.netty.channel.socket.nio.* +import io.netty.handler.codec.http.* +import io.netty.handler.codec.http.websocketx.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import java.net.* +import kotlin.coroutines.* +import kotlin.reflect.* + +public sealed interface NettyWebSocketClientTransport : RSocketClientTransport { + public val config: WebSocketClientProtocolConfig + + public companion object Factory : RSocketTransportFactory< + WebSocketClientProtocolConfig, + NettyWebSocketClientTransport, + NettyWebSocketClientTransportBuilder>(::NettyWebSocketClientTransportBuilderImpl) { + + public operator fun invoke( + context: CoroutineContext, + target: WebSocketClientProtocolConfig.Builder.() -> Unit, + block: NettyWebSocketClientTransportBuilder.() -> Unit = {}, + ): NettyWebSocketClientTransport { + return invoke(context, WebSocketClientProtocolConfig.newBuilder().apply(target).build(), block) + } + } +} + +public sealed interface NettyWebSocketClientTransportEngine : + RSocketTransportEngine { + + public fun createTransport( + target: WebSocketClientProtocolConfig.Builder.() -> Unit, + ): NettyWebSocketClientTransport { + return createTransport(WebSocketClientProtocolConfig.newBuilder().apply(target).build()) + } + + public companion object Factory : RSocketTransportEngineFactory< + WebSocketClientProtocolConfig, + NettyWebSocketClientTransport, + NettyWebSocketClientTransportEngine, + NettyWebSocketClientTransportBuilder>(::NettyWebSocketClientTransportBuilderImpl) +} + +public sealed interface NettyWebSocketClientTransportBuilder : + RSocketTransportBuilder, + RSocketTransportEngineBuilder { + + public fun channel(cls: KClass) + public fun channelFactory(factory: ChannelFactory) + public fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) + + public fun bootstrap(block: Bootstrap.() -> Unit) + public fun ssl(block: SslContextBuilder.() -> Unit) + public fun webSockets(block: WebSocketClientProtocolConfig.Builder.() -> Unit) +} + +private class NettyWebSocketClientTransportBuilderImpl : NettyWebSocketClientTransportBuilder { + private var channelFactory: ChannelFactory? = null + private var eventLoopGroup: EventLoopGroup? = null + private var manageEventLoopGroup: Boolean = false + private var bootstrap: (Bootstrap.() -> Unit)? = null + private var ssl: (SslContextBuilder.() -> Unit)? = null + private var webSockets: (WebSocketClientProtocolConfig.Builder.() -> Unit)? = null + + override fun channel(cls: KClass) { + this.channelFactory = ReflectiveChannelFactory(cls.java) + } + + override fun channelFactory(factory: ChannelFactory) { + this.channelFactory = factory + } + + override fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) { + this.eventLoopGroup = group + this.manageEventLoopGroup = manage + } + + override fun bootstrap(block: Bootstrap.() -> Unit) { + bootstrap = block + } + + override fun ssl(block: SslContextBuilder.() -> Unit) { + ssl = block + } + + override fun webSockets(block: WebSocketClientProtocolConfig.Builder.() -> Unit) { + webSockets = block + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext, target: WebSocketClientProtocolConfig): NettyWebSocketClientTransport { + return build(context).buildTransport(target) + } + + @RSocketTransportApi + override fun buildEngine(context: CoroutineContext): NettyWebSocketClientTransportEngine { + return build(context).buildEngine() + } + + private fun build(context: CoroutineContext): NettyWebSocketClientTransportResources { + val group = eventLoopGroup ?: NioEventLoopGroup() + val factory = channelFactory ?: ReflectiveChannelFactory(NioSocketChannel::class.java) + + val transportContext = context.supervisorContext() + group.asCoroutineDispatcher() + if (manageEventLoopGroup) CoroutineScope(transportContext).invokeOnCancellation { + group.shutdownGracefully().awaitFuture() + } + + val sslContext = ssl?.let { + SslContextBuilder + .forClient() + .apply(it) + .build() + } + + val bootstrap = Bootstrap().apply { + bootstrap?.invoke(this) + group(group) + channelFactory(factory) + } + + return NettyWebSocketClientTransportResources( + coroutineContext = transportContext, + sslContext = sslContext, + bootstrap = bootstrap, + webSocketConfig = webSockets + ) + } +} + +private class NettyWebSocketClientTransportResources( + private val coroutineContext: CoroutineContext, + private val sslContext: SslContext?, + private val bootstrap: Bootstrap, + private val webSocketConfig: (WebSocketClientProtocolConfig.Builder.() -> Unit)?, +) { + fun buildTransport(config: WebSocketClientProtocolConfig): NettyWebSocketClientTransport { + return NettyWebSocketClientTransportImpl( + coroutineContext = coroutineContext, + config = config.toBuilder().apply { + webSocketConfig?.invoke(this) + }.build(), + sslContext = sslContext, + bootstrap = bootstrap, + ) + } + + fun buildEngine(): NettyWebSocketClientTransportEngine { + return NettyWebSocketClientTransportEngineImpl( + coroutineContext = coroutineContext, + sslContext = sslContext, + bootstrap = bootstrap, + webSocketConfig = webSocketConfig, + ) + } +} + +private class NettyWebSocketClientTransportEngineImpl( + override val coroutineContext: CoroutineContext, + private val sslContext: SslContext?, + private val bootstrap: Bootstrap, + private val webSocketConfig: (WebSocketClientProtocolConfig.Builder.() -> Unit)?, +) : NettyWebSocketClientTransportEngine { + override fun createTransport(target: WebSocketClientProtocolConfig): NettyWebSocketClientTransport { + return NettyWebSocketClientTransportImpl( + coroutineContext = coroutineContext.supervisorContext(), + config = target.toBuilder().apply { + webSocketConfig?.invoke(this) + }.build(), + sslContext = sslContext, + bootstrap = bootstrap, + ) + } +} + +private class NettyWebSocketClientTransportImpl( + override val coroutineContext: CoroutineContext, + override val config: WebSocketClientProtocolConfig, + private val sslContext: SslContext?, + private val bootstrap: Bootstrap, +) : NettyWebSocketClientTransport { + + @RSocketTransportApi + override suspend fun createSession(): RSocketTransportSession { + val remoteAddress = InetSocketAddress(config.webSocketUri().host, config.webSocketUri().port) + val handler = NettyWebSocketChannelHandler( + sslContext = sslContext, + remoteAddress = remoteAddress, + httpHandler = HttpClientCodec(), + webSocketHandler = WebSocketClientProtocolHandler(config), + ) + val future = bootstrap.clone().apply { + handler(handler) + }.connect(remoteAddress) + + future.awaitFuture() + + return handler.connect(coroutineContext, future.channel() as DuplexChannel) + } +} diff --git a/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport.kt b/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport.kt new file mode 100644 index 00000000..6aa3f6fc --- /dev/null +++ b/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport.kt @@ -0,0 +1,369 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.netty.bootstrap.* +import io.netty.channel.* +import io.netty.channel.ChannelFactory +import io.netty.channel.nio.* +import io.netty.channel.socket.* +import io.netty.channel.socket.nio.* +import io.netty.handler.codec.http.* +import io.netty.handler.codec.http.websocketx.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import java.net.* +import javax.net.ssl.* +import kotlin.coroutines.* +import kotlin.reflect.* + +public sealed interface NettyWebSocketServerInstance : RSocketServerInstance { + public val localAddress: SocketAddress + public val config: WebSocketServerProtocolConfig +} + +public class NettyWebSocketServerTarget( + public val localAddress: SocketAddress?, + public val config: WebSocketServerProtocolConfig?, +) + +public sealed interface NettyWebSocketServerTransport : RSocketServerTransport { + public val localAddress: SocketAddress? + public val config: WebSocketServerProtocolConfig + + public companion object Factory : RSocketTransportFactory< + NettyWebSocketServerTarget, + NettyWebSocketServerTransport, + NettyWebSocketServerTransportBuilder>(::NettyWebSocketServerTransportBuilderImpl) { + + // address|host+port + config|builder + + public operator fun invoke( + context: CoroutineContext, + localAddress: SocketAddress?, + protocolConfig: WebSocketServerProtocolConfig, + block: NettyWebSocketServerTransportBuilder.() -> Unit = {}, + ): NettyWebSocketServerTransport = invoke( + context = context, + target = NettyWebSocketServerTarget(localAddress, protocolConfig), + block = block + ) + + public operator fun invoke( + context: CoroutineContext, + localAddress: SocketAddress?, + protocolConfig: WebSocketServerProtocolConfig.Builder.() -> Unit = {}, + block: NettyWebSocketServerTransportBuilder.() -> Unit = {}, + ): NettyWebSocketServerTransport = invoke( + context = context, + target = NettyWebSocketServerTarget( + localAddress = localAddress, + config = WebSocketServerProtocolConfig.newBuilder().apply(protocolConfig).build() + ), + block = block + ) + + public operator fun invoke( + context: CoroutineContext, + hostname: String = "0.0.0.0", + port: Int = 0, + protocolConfig: WebSocketServerProtocolConfig, + block: NettyWebSocketServerTransportBuilder.() -> Unit = {}, + ): NettyWebSocketServerTransport = invoke( + context = context, + target = NettyWebSocketServerTarget(InetSocketAddress(hostname, port), protocolConfig), + block = block + ) + + public operator fun invoke( + context: CoroutineContext, + hostname: String = "0.0.0.0", + port: Int = 0, + protocolConfig: WebSocketServerProtocolConfig.Builder.() -> Unit = {}, + block: NettyWebSocketServerTransportBuilder.() -> Unit = {}, + ): NettyWebSocketServerTransport = invoke( + context = context, + target = NettyWebSocketServerTarget( + localAddress = InetSocketAddress(hostname, port), + config = WebSocketServerProtocolConfig.newBuilder().apply(protocolConfig).build() + ), + block = block + ) + } +} + +public sealed interface NettyWebSocketServerTransportEngine : + RSocketTransportEngine { + + public fun createTransport( + context: CoroutineContext, + localAddress: SocketAddress?, + protocolConfig: WebSocketServerProtocolConfig, + block: NettyWebSocketServerTransportBuilder.() -> Unit = {}, + ): NettyWebSocketServerTransport = createTransport( + target = NettyWebSocketServerTarget(localAddress, protocolConfig), + ) + + public fun createTransport( + context: CoroutineContext, + localAddress: SocketAddress?, + protocolConfig: WebSocketServerProtocolConfig.Builder.() -> Unit = {}, + block: NettyWebSocketServerTransportBuilder.() -> Unit = {}, + ): NettyWebSocketServerTransport = createTransport( + target = NettyWebSocketServerTarget( + localAddress = localAddress, + config = WebSocketServerProtocolConfig.newBuilder().apply(protocolConfig).build() + ), + ) + + public fun createTransport( + context: CoroutineContext, + hostname: String = "0.0.0.0", + port: Int = 0, + protocolConfig: WebSocketServerProtocolConfig, + block: NettyWebSocketServerTransportBuilder.() -> Unit = {}, + ): NettyWebSocketServerTransport = createTransport( + target = NettyWebSocketServerTarget(InetSocketAddress(hostname, port), protocolConfig), + ) + + public fun createTransport( + hostname: String = "0.0.0.0", + port: Int = 0, + protocolConfig: WebSocketServerProtocolConfig.Builder.() -> Unit = {}, + ): NettyWebSocketServerTransport = createTransport( + target = NettyWebSocketServerTarget( + localAddress = InetSocketAddress(hostname, port), + config = WebSocketServerProtocolConfig.newBuilder().apply(protocolConfig).build() + ), + ) +} + +public sealed interface NettyWebSocketServerTransportBuilder : + RSocketTransportBuilder, + RSocketTransportEngineBuilder { + + public fun channel(cls: KClass) + public fun channelFactory(factory: ChannelFactory) + public fun eventLoopGroup(parentGroup: EventLoopGroup, childGroup: EventLoopGroup, manage: Boolean) + public fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) + + public fun bootstrap(block: ServerBootstrap.() -> Unit) + public fun ssl(block: SslContextBuilder.() -> Unit) + public fun webSockets(block: WebSocketServerProtocolConfig.Builder.() -> Unit) +} + +private class NettyWebSocketServerTransportBuilderImpl : NettyWebSocketServerTransportBuilder { + private var channelFactory: ChannelFactory? = null + private var parentEventLoopGroup: EventLoopGroup? = null + private var childEventLoopGroup: EventLoopGroup? = null + private var manageEventLoopGroup: Boolean = false + private var bootstrap: (ServerBootstrap.() -> Unit)? = null + private var ssl: (SslContextBuilder.() -> Unit)? = null + private var webSockets: (WebSocketServerProtocolConfig.Builder.() -> Unit)? = null + + override fun channel(cls: KClass) { + this.channelFactory = ReflectiveChannelFactory(cls.java) + } + + override fun channelFactory(factory: ChannelFactory) { + this.channelFactory = factory + } + + override fun eventLoopGroup(parentGroup: EventLoopGroup, childGroup: EventLoopGroup, manage: Boolean) { + this.parentEventLoopGroup = parentGroup + this.childEventLoopGroup = childGroup + this.manageEventLoopGroup = manage + } + + override fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) { + this.parentEventLoopGroup = group + this.childEventLoopGroup = group + this.manageEventLoopGroup = manage + } + + override fun bootstrap(block: ServerBootstrap.() -> Unit) { + bootstrap = block + } + + override fun ssl(block: SslContextBuilder.() -> Unit) { + ssl = block + } + + override fun webSockets(block: WebSocketServerProtocolConfig.Builder.() -> Unit) { + webSockets = block + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext, target: NettyWebSocketServerTarget): NettyWebSocketServerTransport { + return build(context).buildTransport(target) + } + + @RSocketTransportApi + override fun buildEngine(context: CoroutineContext): NettyWebSocketServerTransportEngine { + return build(context).buildEngine() + } + + private fun build(context: CoroutineContext): NettyWebSocketServerTransportResources { + val parentGroup = parentEventLoopGroup ?: NioEventLoopGroup() + val childGroup = childEventLoopGroup ?: NioEventLoopGroup() + val factory = channelFactory ?: ReflectiveChannelFactory(NioServerSocketChannel::class.java) + + val transportContext = context.supervisorContext() + childGroup.asCoroutineDispatcher() + if (manageEventLoopGroup) CoroutineScope(transportContext).invokeOnCancellation { + childGroup.shutdownGracefully().awaitFuture() + parentGroup.shutdownGracefully().awaitFuture() + } + + val sslContext = ssl?.let { + SslContextBuilder + .forServer(KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())) + .apply(it) + .build() + } + + val bootstrap = ServerBootstrap().apply { + bootstrap?.invoke(this) + group(parentGroup, childGroup) + channelFactory(factory) + } + + return NettyWebSocketServerTransportResources( + coroutineContext = transportContext, + bootstrap = bootstrap, + sslContext = sslContext, + webSocketConfig = webSockets + ) + } +} + +private class NettyWebSocketServerTransportResources( + private val coroutineContext: CoroutineContext, + private val bootstrap: ServerBootstrap, + private val sslContext: SslContext?, + private val webSocketConfig: (WebSocketServerProtocolConfig.Builder.() -> Unit)?, +) { + fun buildTransport(target: NettyWebSocketServerTarget): NettyWebSocketServerTransport { + return NettyWebSocketServerTransportImpl( + coroutineContext = coroutineContext, + localAddress = target.localAddress, + config = (target.config?.toBuilder() ?: WebSocketServerProtocolConfig.newBuilder()).apply { + webSocketConfig?.invoke(this) + }.build(), + bootstrap = bootstrap, + sslContext = sslContext + ) + } + + fun buildEngine(): NettyWebSocketServerTransportEngine { + return NettyWebSocketServerTransportEngineImpl( + coroutineContext = coroutineContext, + bootstrap = bootstrap, + sslContext = sslContext, + webSocketConfig = webSocketConfig + ) + } +} + +private class NettyWebSocketServerTransportEngineImpl( + override val coroutineContext: CoroutineContext, + private val bootstrap: ServerBootstrap, + private val sslContext: SslContext?, + private val webSocketConfig: (WebSocketServerProtocolConfig.Builder.() -> Unit)?, +) : NettyWebSocketServerTransportEngine { + override fun createTransport(target: NettyWebSocketServerTarget): NettyWebSocketServerTransport { + return NettyWebSocketServerTransportImpl( + coroutineContext = coroutineContext.supervisorContext(), + localAddress = target.localAddress, + config = (target.config?.toBuilder() ?: WebSocketServerProtocolConfig.newBuilder()).apply { + webSocketConfig?.invoke(this) + }.build(), + bootstrap = bootstrap, + sslContext = sslContext + ) + } +} + +private class NettyWebSocketServerTransportImpl( + override val coroutineContext: CoroutineContext, + override val localAddress: SocketAddress?, + override val config: WebSocketServerProtocolConfig, + private val bootstrap: ServerBootstrap, + private val sslContext: SslContext?, +) : NettyWebSocketServerTransport { + + @RSocketTransportApi + override suspend fun startServer(acceptor: RSocketServerAcceptor): NettyWebSocketServerInstance { + val instanceContext = coroutineContext.supervisorContext() + try { + val future = bootstrap.clone().apply { + localAddress(localAddress ?: InetSocketAddress(0)) + childHandler(AcceptorChannelHandler(instanceContext, sslContext, acceptor, config)) + }.bind() + + try { + future.awaitFuture() + } catch (cause: Throwable) { + instanceContext.job.cancel("Failed to bind", cause) + throw cause + } + + return NettyWebSocketServerInstanceImpl( + coroutineContext = instanceContext, + channel = future.channel() as ServerChannel, + config = config + ) + } catch (cause: Throwable) { + instanceContext.job.cancel("Failed to bind", cause) + throw cause + } + } +} + +@RSocketTransportApi +private class AcceptorChannelHandler( + override val coroutineContext: CoroutineContext, + private val sslContext: SslContext?, + private val acceptor: RSocketServerAcceptor, + private val config: WebSocketServerProtocolConfig, +) : ChannelInitializer(), CoroutineScope { + override fun initChannel(ch: DuplexChannel) { + val handler = NettyWebSocketChannelHandler( + sslContext = sslContext, + remoteAddress = null, + httpHandler = HttpServerCodec(), + webSocketHandler = WebSocketServerProtocolHandler(config) + ) + ch.pipeline().addLast(handler) + launch { + acceptor.acceptSession(handler.connect(coroutineContext, ch)) + } + } +} + +private class NettyWebSocketServerInstanceImpl( + override val coroutineContext: CoroutineContext, + private val channel: ServerChannel, + override val config: WebSocketServerProtocolConfig, +) : NettyWebSocketServerInstance { + override val localAddress: SocketAddress get() = channel.localAddress() + + init { + linkCompletionWith(channel) + } +} diff --git a/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketSession.kt b/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketSession.kt new file mode 100644 index 00000000..0346a501 --- /dev/null +++ b/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketSession.kt @@ -0,0 +1,45 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.ktor.utils.io.core.* +import io.netty.buffer.* +import io.netty.channel.socket.* +import io.netty.handler.codec.http.websocketx.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.channels.* +import kotlin.coroutines.* + +@RSocketTransportApi +internal class NettyWebSocketSession( + override val coroutineContext: CoroutineContext, + private val channel: DuplexChannel, + private val frames: ReceiveChannel, +) : RSocketTransportSession.Sequential { + + init { + linkCompletionWith(channel) + } + + override suspend fun sendFrame(frame: ByteReadPacket) { + channel.writeAndFlush(BinaryWebSocketFrame(Unpooled.wrappedBuffer(frame.readByteBuffer()))).awaitFuture() + } + + override suspend fun receiveFrame(): ByteReadPacket { + return frames.receive() + } +} diff --git a/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/utils.kt b/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/utils.kt new file mode 100644 index 00000000..285cba63 --- /dev/null +++ b/rsocket-transport-netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/utils.kt @@ -0,0 +1,45 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.netty.channel.* +import io.netty.util.concurrent.* +import io.rsocket.kotlin.internal.io.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +@Suppress("UNCHECKED_CAST") +internal suspend inline fun Future.awaitFuture(): T = suspendCancellableCoroutine { cont -> + addListener { + when { + it.isSuccess -> cont.resume(it.now as T) + else -> cont.resumeWithException(it.cause()) + } + } + cont.invokeOnCancellation { + cancel(true) + } +} + +internal fun CoroutineScope.linkCompletionWith(channel: Channel) { + channel.closeFuture().addListener { + cancel("Netty channel closed", it.cause()) + } + invokeOnCancellation { + channel.close().awaitFuture() + } +} diff --git a/rsocket-transport-netty-websocket/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketTransportTest.kt b/rsocket-transport-netty-websocket/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketTransportTest.kt new file mode 100644 index 00000000..7a03dc12 --- /dev/null +++ b/rsocket-transport-netty-websocket/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketTransportTest.kt @@ -0,0 +1,72 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.netty.channel.nio.* +import io.netty.handler.ssl.util.* +import io.rsocket.kotlin.transport.tests.* +import java.net.* +import kotlin.concurrent.* + +private val eventLoop = NioEventLoopGroup().also { + Runtime.getRuntime().addShutdownHook(thread(start = false) { + it.shutdownGracefully().await(1000) + }) +} +private val certificates = SelfSignedCertificate() + +class NettyWebSocketTransportTest : TransportTest() { + override suspend fun before() { + val server = startServer( + NettyWebSocketServerTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + } + ) + client = connectClient( + NettyWebSocketClientTransport(testContext, { + val address = server.localAddress as InetSocketAddress + webSocketUri("ws://localhost:${address.port}") + }) { + eventLoopGroup(eventLoop, manage = false) + } + ) + } +} + +class NettyWebSocketSslTransportTest : TransportTest() { + override suspend fun before() { + val server = startServer( + NettyWebSocketServerTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + ssl { + keyManager(certificates.certificate(), certificates.privateKey()) + } + } + ) + client = connectClient( + NettyWebSocketClientTransport(testContext, { + val address = server.localAddress as InetSocketAddress + webSocketUri("ws://localhost:${address.port}") + }) { + eventLoopGroup(eventLoop, manage = false) + ssl { + trustManager(InsecureTrustManagerFactory.INSTANCE) + } + } + ) + } +} diff --git a/settings.gradle.kts b/settings.gradle.kts index 3b05c558..0ae7c221 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -40,6 +40,8 @@ include( "rsocket-transport-ktor-websocket-client", "rsocket-transport-ktor-websocket-server", "rsocket-transport-nodejs-tcp", + "rsocket-transport-netty-tcp", + "rsocket-transport-netty-websocket", "rsocket-transport-tests" )