diff --git a/core/shared/src/main/scala/cats/effect/IO.scala b/core/shared/src/main/scala/cats/effect/IO.scala index 711c7b8d4c..ee3a9da1fc 100644 --- a/core/shared/src/main/scala/cats/effect/IO.scala +++ b/core/shared/src/main/scala/cats/effect/IO.scala @@ -489,11 +489,8 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] { def guarantee(finalizer: IO[Unit]): IO[A] = // this is a little faster than the default implementation, which helps Resource IO uncancelable { poll => - val handled = finalizer handleErrorWith { t => - IO.executionContext.flatMap(ec => IO(ec.reportFailure(t))) - } - - poll(this).onCancel(finalizer).onError(_ => handled).flatTap(_ => finalizer) + val onError: PartialFunction[Throwable, IO[Unit]] = { case _ => finalizer.reportError } + poll(this).onCancel(finalizer).onError(onError).flatTap(_ => finalizer) } /** @@ -519,12 +516,10 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] { def guaranteeCase(finalizer: OutcomeIO[A @uncheckedVariance] => IO[Unit]): IO[A] = IO.uncancelable { poll => val finalized = poll(this).onCancel(finalizer(Outcome.canceled)) - val handled = finalized.onError { e => - finalizer(Outcome.errored(e)).handleErrorWith { t => - IO.executionContext.flatMap(ec => IO(ec.reportFailure(t))) - } + val onError: PartialFunction[Throwable, IO[Unit]] = { + case e => finalizer(Outcome.errored(e)).reportError } - handled.flatTap(a => finalizer(Outcome.succeeded(IO.pure(a)))) + finalized.onError(onError).flatTap { (a: A) => finalizer(Outcome.succeeded(IO.pure(a))) } } def handleError[B >: A](f: Throwable => B): IO[B] = @@ -588,8 +583,20 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] { def onCancel(fin: IO[Unit]): IO[A] = IO.OnCancel(this, fin) - def onError(f: Throwable => IO[Unit]): IO[A] = - handleErrorWith(t => f(t).voidError *> IO.raiseError(t)) + @deprecated("Use onError with PartialFunction argument", "3.6.0") + def onError(f: Throwable => IO[Unit]): IO[A] = { + val pf: PartialFunction[Throwable, IO[Unit]] = { case t => f(t).reportError } + onError(pf) + } + + /** + * Execute a callback on certain errors, then rethrow them. Any non matching error is rethrown + * as well. + * + * Implements `ApplicativeError.onError`. + */ + def onError(pf: PartialFunction[Throwable, IO[Unit]]): IO[A] = + handleErrorWith(t => pf.applyOrElse(t, (_: Throwable) => IO.unit) *> IO.raiseError(t)) /** * Like `Parallel.parProductL` @@ -928,6 +935,19 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] { def void: IO[Unit] = map(_ => ()) + /** + * Similar to [[IO.voidError]], but also reports the error. + */ + private[effect] def reportError(implicit ev: A <:< Unit): IO[Unit] = { + val _ = ev + asInstanceOf[IO[Unit]].handleErrorWith { t => + IO.executionContext.flatMap(ec => IO(ec.reportFailure(t))) + } + } + + /** + * Discard any error raised by the source. + */ def voidError(implicit ev: A <:< Unit): IO[Unit] = { val _ = ev asInstanceOf[IO[Unit]].handleError(_ => ()) @@ -1975,6 +1995,9 @@ object IO extends IOCompanionPlatform with IOLowPriorityImplicits with TuplePara override def handleError[A](fa: IO[A])(f: Throwable => A): IO[A] = fa.handleError(f) + override def onError[A](fa: IO[A])(pf: PartialFunction[Throwable, IO[Unit]]): IO[A] = + fa.onError(pf) + override def timeout[A](fa: IO[A], duration: FiniteDuration)( implicit ev: TimeoutException <:< Throwable): IO[A] = { fa.timeout(duration)