From 9a321a956d635eded6415c983932bcc1a6f48d97 Mon Sep 17 00:00:00 2001 From: Lorenzo Gabriele Date: Mon, 9 Oct 2023 10:17:28 +0200 Subject: [PATCH] WIP: Support cats-effect 3.6 --- build.sc | 5 +- snunit-http4s/src/snunit/Http4sApp.scala | 3 +- .../snunit/http4s/CEAsyncServerBuilder.scala | 142 ++++++++++++++++++ snunit-http4s/src/snunit/http4s/Impl.scala | 2 +- snunit/src/snunit/AsyncServerBuilder.scala | 8 +- 5 files changed, 153 insertions(+), 7 deletions(-) create mode 100644 snunit-http4s/src/snunit/http4s/CEAsyncServerBuilder.scala diff --git a/build.sc b/build.sc index b8acc31..8fc3ca8 100644 --- a/build.sc +++ b/build.sc @@ -163,7 +163,10 @@ trait SNUnitHttp4s extends Common.Cross with Cross.Module2[String, String] with case s"1.$_" => "1" } def artifactName = s"snunit-http4s$http4sBinaryVersion" - def ivyDeps = super.ivyDeps() ++ Agg(ivy"org.http4s::http4s-server::$http4sVersion") + def ivyDeps = super.ivyDeps() ++ Agg( + ivy"org.http4s::http4s-server::$http4sVersion", + ivy"org.typelevel::cats-effect::3.6-0142603", + ) def sources = T.sources { super.sources() ++ Agg(PathRef(millSourcePath / s"http4s-$http4sBinaryVersion" / "src")) } diff --git a/snunit-http4s/src/snunit/Http4sApp.scala b/snunit-http4s/src/snunit/Http4sApp.scala index e33d3f8..432814b 100644 --- a/snunit-http4s/src/snunit/Http4sApp.scala +++ b/snunit-http4s/src/snunit/Http4sApp.scala @@ -1,11 +1,12 @@ package snunit import cats.effect.IO +import cats.effect.IOApp import cats.effect.Resource import org.http4s.HttpApp import snunit.http4s.SNUnitServerBuilder -trait Http4sApp extends epollcat.EpollApp.Simple { +trait Http4sApp extends IOApp.Simple { def routes: Resource[IO, HttpApp[IO]] override def run = routes.use { r => diff --git a/snunit-http4s/src/snunit/http4s/CEAsyncServerBuilder.scala b/snunit-http4s/src/snunit/http4s/CEAsyncServerBuilder.scala new file mode 100644 index 0000000..ae1bc53 --- /dev/null +++ b/snunit-http4s/src/snunit/http4s/CEAsyncServerBuilder.scala @@ -0,0 +1,142 @@ +package snunit + +import snunit.unsafe.{*, given} + +import cats.effect.* + +import scala.annotation.tailrec +import scala.scalanative.libc.errno.errno +import scala.scalanative.libc.string.strerror +import scala.scalanative.posix.fcntl._ +import scala.scalanative.posix.fcntlOps._ +import scala.scalanative.posix.sys.ioctl._ +import scala.scalanative.runtime.Intrinsics +import scala.scalanative.runtime.fromRawPtr +import scala.scalanative.runtime.toRawPtr +import scala.scalanative.unsafe.* +import scala.util.control.NonFatal + +object CEAsyncServerBuilder { + private val initArray: Array[Byte] = new Array[Byte](sizeof[nxt_unit_init_t].toInt) + private val init: nxt_unit_init_t_* = initArray.at(0).asInstanceOf[nxt_unit_init_t_*] + def setRequestHandler(requestHandler: RequestHandler): this.type = { + ServerBuilder.setRequestHandler(requestHandler) + this + } + def setWebsocketHandler(websocketHandler: WebsocketHandler): this.type = { + ServerBuilder.setWebsocketHandler(websocketHandler) + this + } + private var shutdownHandler: (() => Unit) => Unit = shutdown => shutdown() + def setShutdownHandler(shutdownHandler: (() => Unit) => Unit): this.type = { + this.shutdownHandler = shutdownHandler + this + } + def build(): IO[Unit] = { + ServerBuilder.setBaseHandlers(init) + init.callbacks.add_port = CEAsyncServerBuilder.add_port + init.callbacks.remove_port = CEAsyncServerBuilder.remove_port + init.callbacks.quit = CEAsyncServerBuilder.quit + val ctx: nxt_unit_ctx_t_* = nxt_unit_init(init) + if (ctx.isNull) { + throw new Exception("Failed to create Unit object") + } + IO { () } + } + + private val add_port: add_port_t = add_port_t { (ctx: nxt_unit_ctx_t_*, port: nxt_unit_port_t_*) => + { + if (port.in_fd != -1) { + var result = NXT_UNIT_OK + locally { + val res = fcntl(port.in_fd, F_SETFL, O_NONBLOCK) + if (res == -1) { + nxt_unit_warn(ctx, s"fcntl(${port.in_fd}, O_NONBLOCK) failed: ${fromCString(strerror(errno))}, $errno)") + result = -1 + } + } + if (result == NXT_UNIT_OK) { + try { + PortData.register(ctx, port) + NXT_UNIT_OK + } catch { + case NonFatal(e @ _) => + nxt_unit_warn(ctx, s"Polling failed: ${fromCString(strerror(errno))}, $errno)") + NXT_UNIT_ERROR + } + } else result + } else NXT_UNIT_OK + } + } + + private val remove_port: remove_port_t = + remove_port_t { (_: nxt_unit_t_*, ctx: nxt_unit_ctx_t_*, port: nxt_unit_port_t_*) => + { + if (port.data != null && !ctx.isNull) { + PortData.fromPort(port).stop() + } + } + } + + private val quit: quit_t = quit_t { (ctx: nxt_unit_ctx_t_*) => + shutdownHandler { () => + nxt_unit_done(ctx) + } + } + + private class PortData private ( + val ctx: nxt_unit_ctx_t_*, + val port: nxt_unit_port_t_* + ) { + private var stopped: Boolean = false + + // ideally this shouldn't be needed. + // in theory rc == NXT_UNIT_AGAIN + // would mean that there aren't any messages + // to read. In practice if we stop at rc == NXT_UNIT_AGAIN + // there are some unprocessed messages which effect in + // epollcat (which uses edge-triggering) to hang on close + // since one port to remain open and one callback registered + def continueReading: Boolean = { + val bytesAvailable = stackalloc[Int]() + ioctl(port.in_fd, FIONREAD, bytesAvailable.asInstanceOf[Ptr[Byte]]) + !bytesAvailable > 0 + } + + (for + pollers <- Resource.eval(IO.pollers) + poller = pollers.head.asInstanceOf[FileDescriptorPoller] + handle <- poller.registerFileDescriptor(port.in_fd, monitorReadReady = true, monitorWriteReady = false) + res <- Resource.eval(handle.pollReadRec(()) { _ => + IO { + // process messages until we are blocked + while (nxt_unit_process_port_msg(ctx, port) == NXT_UNIT_OK || continueReading) {} + // suspend until more data is available on the socket, then we will be invoked again + Left(()) + } + }) + yield res).useForever.unsafeRunAndForget()(cats.effect.unsafe.implicits.global) + + def stop(): Unit = { + stopped = true + PortData.stopped.put(this, ()) + } + } + + private object PortData { + private[this] val references = new java.util.IdentityHashMap[PortData, Unit] + + private[this] val stopped = new java.util.IdentityHashMap[PortData, Unit] + + def isLastFDStopped: Boolean = references == stopped + + def register(ctx: nxt_unit_ctx_t_*, port: nxt_unit_port_t_*): Unit = + val portData = new PortData(ctx, port) + references.put(portData, ()) + port.data = fromRawPtr(Intrinsics.castObjectToRawPtr(portData)) + + def fromPort(port: nxt_unit_port_t_*): PortData = { + Intrinsics.castRawPtrToObject(toRawPtr(port.data)).asInstanceOf[PortData] + } + } +} diff --git a/snunit-http4s/src/snunit/http4s/Impl.scala b/snunit-http4s/src/snunit/http4s/Impl.scala index 9915eab..a54080b 100644 --- a/snunit-http4s/src/snunit/http4s/Impl.scala +++ b/snunit-http4s/src/snunit/http4s/Impl.scala @@ -23,7 +23,7 @@ private[http4s] object Impl { .parallel[F](await = true) .use { dispatcher => Async[F].delay( - snunit.AsyncServerBuilder + snunit.CEAsyncServerBuilder .setRequestHandler(new snunit.RequestHandler { def handleRequest(req: snunit.Request): Unit = { val run = httpApp diff --git a/snunit/src/snunit/AsyncServerBuilder.scala b/snunit/src/snunit/AsyncServerBuilder.scala index 7837a44..27c6340 100644 --- a/snunit/src/snunit/AsyncServerBuilder.scala +++ b/snunit/src/snunit/AsyncServerBuilder.scala @@ -67,14 +67,14 @@ object AsyncServerBuilder { } } - private val remove_port: remove_port_t = remove_port_t { - (_: nxt_unit_t_*, ctx: nxt_unit_ctx_t_*, port: nxt_unit_port_t_*) => + private val remove_port: remove_port_t = + remove_port_t { (_: nxt_unit_t_*, ctx: nxt_unit_ctx_t_*, port: nxt_unit_port_t_*) => { if (port.data != null && !ctx.isNull) { PortData.fromPort(port).stop() } } - } + } private val quit: quit_t = quit_t { (ctx: nxt_unit_ctx_t_*) => shutdownHandler { () => @@ -104,7 +104,7 @@ object AsyncServerBuilder { !bytesAvailable > 0 } - val continue = rc == NXT_UNIT_OK || continueReading + val continue = rc == NXT_UNIT_OK // || continueReading if (stopped && PortData.isLastFDStopped) { shutdownHandler { () =>