From cfb6c9ce5afba7b811f3c1356194ebc892a07b64 Mon Sep 17 00:00:00 2001 From: t-bast Date: Mon, 10 Jun 2024 17:16:55 +0200 Subject: [PATCH 01/13] Reorder functions in `NodeRelay.scala` This commit doesn't contain any logical change, we just move code to align with the FSM flow. It makes it easier to follow the progress of the state machine to always scroll down when advancing states. --- .../eclair/payment/relay/NodeRelay.scala | 140 +++++++++--------- 1 file changed, 69 insertions(+), 71 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index 49471f82b3..abbaad81bb 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -263,60 +263,6 @@ class NodeRelay private(nodeParams: NodeParams, relay(upstream, nextPayload, nextPacket_opt, confidence) } - /** - * Once the payment is forwarded, we're waiting for fail/fulfill responses from downstream nodes. - * - * @param upstream complete HTLC set received. - * @param nextPayload relay instructions. - * @param fulfilledUpstream true if we already fulfilled the payment upstream. - */ - private def sending(upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay, startedAt: TimestampMilli, fulfilledUpstream: Boolean): Behavior[Command] = - Behaviors.receiveMessagePartial { - rejectExtraHtlcPartialFunction orElse { - // this is the fulfill that arrives from downstream channels - case WrappedPreimageReceived(PreimageReceived(_, paymentPreimage)) => - if (!fulfilledUpstream) { - // We want to fulfill upstream as soon as we receive the preimage (even if not all HTLCs have fulfilled downstream). - context.log.debug("got preimage from downstream") - fulfillPayment(upstream, paymentPreimage) - sending(upstream, nextPayload, startedAt, fulfilledUpstream = true) - } else { - // we don't want to fulfill multiple times - Behaviors.same - } - case WrappedPaymentSent(paymentSent) => - context.log.debug("trampoline payment fully resolved downstream") - success(upstream, fulfilledUpstream, paymentSent) - recordRelayDuration(startedAt, isSuccess = true) - stopping() - case WrappedPaymentFailed(PaymentFailed(_, _, failures, _)) => - context.log.debug(s"trampoline payment failed downstream") - if (!fulfilledUpstream) { - rejectPayment(upstream, translateError(nodeParams, failures, upstream, nextPayload)) - } - recordRelayDuration(startedAt, isSuccess = fulfilledUpstream) - stopping() - } - } - - /** - * Once the downstream payment is settled (fulfilled or failed), we reject new upstream payments while we wait for our parent to stop us. - */ - private def stopping(): Behavior[Command] = { - parent ! NodeRelayer.RelayComplete(context.self, paymentHash, paymentSecret) - Behaviors.receiveMessagePartial { - rejectExtraHtlcPartialFunction orElse { - case Stop => Behaviors.stopped - } - } - } - - private val payFsmAdapters = { - context.messageAdapter[PreimageReceived](WrappedPreimageReceived) - context.messageAdapter[PaymentSent](WrappedPaymentSent) - context.messageAdapter[PaymentFailed](WrappedPaymentFailed) - }.toClassic - private def relay(upstream: Upstream.Hot.Trampoline, payloadOut: IntermediatePayload.NodeRelay, packetOut_opt: Option[OnionRoutingPacket], confidence: Double): Behavior[Command] = { val displayNodeId = payloadOut match { case payloadOut: IntermediatePayload.NodeRelay.Standard => payloadOut.outgoingNodeId @@ -346,23 +292,6 @@ class NodeRelay private(nodeParams: NodeParams, } } - private def relayToRecipient(upstream: Upstream.Hot.Trampoline, - payloadOut: IntermediatePayload.NodeRelay, - recipient: Recipient, - paymentCfg: SendPaymentConfig, - routeParams: RouteParams, - useMultiPart: Boolean): Behavior[Command] = { - val payment = - if (useMultiPart) { - SendMultiPartPayment(payFsmAdapters, recipient, nodeParams.maxPaymentAttempts, routeParams) - } else { - SendPaymentToNode(payFsmAdapters, recipient, nodeParams.maxPaymentAttempts, routeParams) - } - val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, useMultiPart) - payFSM ! payment - sending(upstream, payloadOut, TimestampMilli.now(), fulfilledUpstream = false) - } - /** * Blinded paths in Bolt 12 invoices may encode the introduction node with an scid and a direction: we need to resolve * that to a nodeId in order to reach that introduction node and use the blinded path. @@ -385,6 +314,75 @@ class NodeRelay private(nodeParams: NodeParams, relayToRecipient(upstream, payloadOut, recipient, paymentCfg, routeParams, features.hasFeature(Features.BasicMultiPartPayment)) } + private def relayToRecipient(upstream: Upstream.Hot.Trampoline, + payloadOut: IntermediatePayload.NodeRelay, + recipient: Recipient, + paymentCfg: SendPaymentConfig, + routeParams: RouteParams, + useMultiPart: Boolean): Behavior[Command] = { + val payFsmAdapters = { + context.messageAdapter[PreimageReceived](WrappedPreimageReceived) + context.messageAdapter[PaymentSent](WrappedPaymentSent) + context.messageAdapter[PaymentFailed](WrappedPaymentFailed) + }.toClassic + val payment = if (useMultiPart) { + SendMultiPartPayment(payFsmAdapters, recipient, nodeParams.maxPaymentAttempts, routeParams) + } else { + SendPaymentToNode(payFsmAdapters, recipient, nodeParams.maxPaymentAttempts, routeParams) + } + val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, useMultiPart) + payFSM ! payment + sending(upstream, payloadOut, TimestampMilli.now(), fulfilledUpstream = false) + } + + /** + * Once the payment is forwarded, we're waiting for fail/fulfill responses from downstream nodes. + * + * @param upstream complete HTLC set received. + * @param nextPayload relay instructions. + * @param fulfilledUpstream true if we already fulfilled the payment upstream. + */ + private def sending(upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay, startedAt: TimestampMilli, fulfilledUpstream: Boolean): Behavior[Command] = + Behaviors.receiveMessagePartial { + rejectExtraHtlcPartialFunction orElse { + // this is the fulfill that arrives from downstream channels + case WrappedPreimageReceived(PreimageReceived(_, paymentPreimage)) => + if (!fulfilledUpstream) { + // We want to fulfill upstream as soon as we receive the preimage (even if not all HTLCs have fulfilled downstream). + context.log.debug("got preimage from downstream") + fulfillPayment(upstream, paymentPreimage) + sending(upstream, nextPayload, startedAt, fulfilledUpstream = true) + } else { + // we don't want to fulfill multiple times + Behaviors.same + } + case WrappedPaymentSent(paymentSent) => + context.log.debug("trampoline payment fully resolved downstream") + success(upstream, fulfilledUpstream, paymentSent) + recordRelayDuration(startedAt, isSuccess = true) + stopping() + case WrappedPaymentFailed(PaymentFailed(_, _, failures, _)) => + context.log.debug(s"trampoline payment failed downstream") + if (!fulfilledUpstream) { + rejectPayment(upstream, translateError(nodeParams, failures, upstream, nextPayload)) + } + recordRelayDuration(startedAt, isSuccess = fulfilledUpstream) + stopping() + } + } + + /** + * Once the downstream payment is settled (fulfilled or failed), we reject new upstream payments while we wait for our parent to stop us. + */ + private def stopping(): Behavior[Command] = { + parent ! NodeRelayer.RelayComplete(context.self, paymentHash, paymentSecret) + Behaviors.receiveMessagePartial { + rejectExtraHtlcPartialFunction orElse { + case Stop => Behaviors.stopped + } + } + } + private def rejectExtraHtlcPartialFunction: PartialFunction[Command, Behavior[Command]] = { case Relay(nodeRelayPacket, _) => rejectExtraHtlc(nodeRelayPacket.add) From 54185b0cb1c92aee74531dda10f8875e28fc7cd8 Mon Sep 17 00:00:00 2001 From: t-bast Date: Tue, 11 Jun 2024 15:07:05 +0200 Subject: [PATCH 02/13] Rework node relay FSM flow We refactor `NodeRelay.scala` to re-order some steps, without making meaningful functional changes. The steps are: 1. Fully receive the incoming payment 2. Resolve the next node (unwrap blinded paths if needed) 3. Wake-up the next node if necessary (mobile wallet) 4. Relay outgoing payment Note that we introduce a wake-up step, that will be enriched in future commits and can be extended to include mobile notifications. The file is now also easier to follow, as steps are done linearly by simply scrolling down. --- .../main/scala/fr/acinq/eclair/Setup.scala | 2 +- .../acinq/eclair/io/PeerReadyNotifier.scala | 6 +- .../eclair/payment/relay/NodeRelay.scala | 175 +++++++++--------- .../eclair/payment/relay/NodeRelayer.scala | 9 +- .../acinq/eclair/payment/relay/Relayer.scala | 8 +- .../fr/acinq/eclair/channel/FuzzySpec.scala | 4 +- .../basic/fixtures/MinimalNodeFixture.scala | 3 +- .../payment/PostRestartHtlcCleanerSpec.scala | 2 +- .../payment/relay/NodeRelayerSpec.scala | 132 +------------ .../eclair/payment/relay/RelayerSpec.scala | 3 +- 10 files changed, 116 insertions(+), 228 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala index b6ca12c5e5..63587b7b33 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala @@ -360,7 +360,7 @@ class Setup(val datadir: File, offerManager = system.spawn(Behaviors.supervise(OfferManager(nodeParams, router, paymentTimeout = 1 minute)).onFailure(typed.SupervisorStrategy.resume), name = "offer-manager") paymentHandler = system.actorOf(SimpleSupervisor.props(PaymentHandler.props(nodeParams, register, offerManager), "payment-handler", SupervisorStrategy.Resume)) triggerer = system.spawn(Behaviors.supervise(AsyncPaymentTriggerer()).onFailure(typed.SupervisorStrategy.resume), name = "async-payment-triggerer") - relayer = system.actorOf(SimpleSupervisor.props(Relayer.props(nodeParams, router, register, paymentHandler, triggerer, Some(postRestartCleanUpInitialized)), "relayer", SupervisorStrategy.Resume)) + relayer = system.actorOf(SimpleSupervisor.props(Relayer.props(nodeParams, router, register, paymentHandler, Some(postRestartCleanUpInitialized)), "relayer", SupervisorStrategy.Resume)) _ = relayer ! PostRestartHtlcCleaner.Init(channels) // Before initializing the switchboard (which re-connects us to the network) and the user-facing parts of the system, // we want to make sure the handler for post-restart broken HTLCs has finished initializing. diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala index 81d6c71b5c..4fad93e55e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala @@ -88,7 +88,11 @@ object PeerReadyNotifier { context.log.error("no switchboard found") replyTo ! PeerUnavailable(remoteNodeId) Behaviors.stopped - } + } + case Timeout => + context.log.info("timed out finding switchboard actor") + replyTo ! PeerUnavailable(remoteNodeId) + Behaviors.stopped } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index abbaad81bb..9c79098391 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -16,16 +16,17 @@ package fr.acinq.eclair.payment.relay -import akka.actor.typed.Behavior import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.adapter.{TypedActorContextOps, TypedActorRefOps} import akka.actor.typed.scaladsl.{ActorContext, Behaviors} +import akka.actor.typed.{Behavior, SupervisorStrategy} import akka.actor.{ActorRef, typed} import com.softwaremill.quicklens.ModifyPimp import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Upstream} import fr.acinq.eclair.db.PendingCommandsDb +import fr.acinq.eclair.io.PeerReadyNotifier import fr.acinq.eclair.payment.IncomingPaymentPacket.NodeRelayPacket import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment._ @@ -40,11 +41,12 @@ import fr.acinq.eclair.router.Router.RouteParams import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound} import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, Features, Logs, MilliSatoshi, NodeParams, TimestampMilli, UInt64, nodeFee, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiry, Features, Logs, MilliSatoshi, NodeParams, TimestampMilli, UInt64, nodeFee, randomBytes32} import java.util.UUID import java.util.concurrent.TimeUnit import scala.collection.immutable.Queue +import scala.concurrent.duration.DurationInt /** * It [[NodeRelay]] aggregates incoming HTLCs (in case multi-part was used upstream) and then forwards the requested amount (using the @@ -62,7 +64,7 @@ object NodeRelay { private case class WrappedPreimageReceived(preimageReceived: PreimageReceived) extends Command private case class WrappedPaymentSent(paymentSent: PaymentSent) extends Command private case class WrappedPaymentFailed(paymentFailed: PaymentFailed) extends Command - private[relay] case class WrappedPeerReadyResult(result: AsyncPaymentTriggerer.Result) extends Command + private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command private case class WrappedResolvedPaths(resolved: Seq[ResolvedPath]) extends Command // @formatter:on @@ -88,7 +90,6 @@ object NodeRelay { relayId: UUID, nodeRelayPacket: NodeRelayPacket, outgoingPaymentFactory: OutgoingPaymentFactory, - triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command], router: ActorRef): Behavior[Command] = Behaviors.setup { context => val paymentHash = nodeRelayPacket.add.paymentHash @@ -108,7 +109,7 @@ object NodeRelay { case IncomingPaymentPacket.RelayToTrampolinePacket(_, _, _, nextPacket) => Some(nextPacket) case _: IncomingPaymentPacket.RelayToBlindedPathsPacket => None } - new NodeRelay(nodeParams, parent, register, relayId, paymentHash, nodeRelayPacket.outerPayload.paymentSecret, context, outgoingPaymentFactory, triggerer, router) + new NodeRelay(nodeParams, parent, register, relayId, paymentHash, nodeRelayPacket.outerPayload.paymentSecret, context, outgoingPaymentFactory, router) .receiving(Queue.empty, nodeRelayPacket.innerPayload, nextPacket_opt, incomingPaymentHandler) } } @@ -125,18 +126,29 @@ object NodeRelay { Some(InvalidOnionPayload(UInt64(2), 0)) } else { payloadOut match { - case payloadOut: IntermediatePayload.NodeRelay.Standard => - if (payloadOut.invoiceFeatures.isDefined && payloadOut.paymentSecret.isEmpty) { - Some(InvalidOnionPayload(UInt64(8), 0)) // payment secret field is missing - } else { - None - } - case _: IntermediatePayload.NodeRelay.ToBlindedPaths => - None + // If we're relaying a standard payment to a non-trampoline recipient, we need the payment secret. + case payloadOut: IntermediatePayload.NodeRelay.Standard if payloadOut.invoiceFeatures.isDefined && payloadOut.paymentSecret.isEmpty => Some(InvalidOnionPayload(UInt64(8), 0)) + case _: IntermediatePayload.NodeRelay.Standard => None + case _: IntermediatePayload.NodeRelay.ToBlindedPaths => None } } } + private def shouldWakeUpNextNode(nodeParams: NodeParams, recipient: Recipient): Boolean = { + false + } + + /** When we have identified that the next node is one of our peers, return their (real) nodeId. */ + private def nodeIdToWakeUp(recipient: Recipient): PublicKey = { + recipient match { + case r: ClearRecipient => r.nodeId + case r: SpontaneousRecipient => r.nodeId + case r: TrampolineRecipient => r.nodeId + // When using blinded paths, the recipient nodeId is blinded. The actual node is the introduction of the path. + case r: BlindedRecipient => r.blindedHops.head.nodeId + } + } + /** Compute route params that honor our fee and cltv requirements. */ private def computeRouteParams(nodeParams: NodeParams, amountIn: MilliSatoshi, expiryIn: CltvExpiry, amountOut: MilliSatoshi, expiryOut: CltvExpiry): RouteParams = { nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams @@ -188,7 +200,6 @@ class NodeRelay private(nodeParams: NodeParams, paymentSecret: ByteVector32, context: ActorContext[NodeRelay.Command], outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory, - triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command], router: ActorRef) { import NodeRelay._ @@ -223,54 +234,13 @@ class NodeRelay private(nodeParams: NodeParams, rejectPayment(upstream, Some(failure)) stopping() case None => - nextPayload match { - // TODO: async payments are not currently supported for blinded recipients. We should update the AsyncPaymentTriggerer to decrypt the blinded path. - case nextPayload: IntermediatePayload.NodeRelay.Standard if nextPayload.isAsyncPayment && nodeParams.features.hasFeature(Features.AsyncPaymentPrototype) => - waitForTrigger(upstream, nextPayload, nextPacket_opt) - case _ => - doSend(upstream, nextPayload, nextPacket_opt) - } + resolveNextNode(upstream, nextPayload, nextPacket_opt) } } - private def waitForTrigger(upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay.Standard, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { - context.log.info(s"waiting for async payment to trigger before relaying trampoline payment (amountIn=${upstream.amountIn} expiryIn=${upstream.expiryIn} amountOut=${nextPayload.amountToForward} expiryOut=${nextPayload.outgoingCltv}, asyncPaymentsParams=${nodeParams.relayParams.asyncPaymentsParams})") - val timeoutBlock = nodeParams.currentBlockHeight + nodeParams.relayParams.asyncPaymentsParams.holdTimeoutBlocks - val safetyBlock = (upstream.expiryIn - nodeParams.relayParams.asyncPaymentsParams.cancelSafetyBeforeTimeout).blockHeight - // wait for notification until which ever occurs first: the hold timeout block or the safety block - val notifierTimeout = Seq(timeoutBlock, safetyBlock).min - val peerReadyResultAdapter = context.messageAdapter[AsyncPaymentTriggerer.Result](WrappedPeerReadyResult) - - triggerer ! AsyncPaymentTriggerer.Watch(peerReadyResultAdapter, nextPayload.outgoingNodeId, paymentHash, notifierTimeout) - context.system.eventStream ! EventStream.Publish(WaitingToRelayPayment(nextPayload.outgoingNodeId, paymentHash)) - Behaviors.receiveMessagePartial { - case WrappedPeerReadyResult(AsyncPaymentTriggerer.AsyncPaymentTimeout) => - context.log.warn("rejecting async payment; was not triggered before block {}", notifierTimeout) - rejectPayment(upstream, Some(TemporaryNodeFailure())) // TODO: replace failure type when async payment spec is finalized - stopping() - case WrappedPeerReadyResult(AsyncPaymentTriggerer.AsyncPaymentCanceled) => - context.log.warn(s"payment sender canceled a waiting async payment") - rejectPayment(upstream, Some(TemporaryNodeFailure())) // TODO: replace failure type when async payment spec is finalized - stopping() - case WrappedPeerReadyResult(AsyncPaymentTriggerer.AsyncPaymentTriggered) => - doSend(upstream, nextPayload, nextPacket_opt) - } - } - - private def doSend(upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { - context.log.debug(s"relaying trampoline payment (amountIn=${upstream.amountIn} expiryIn=${upstream.expiryIn} amountOut=${nextPayload.amountToForward} expiryOut=${nextPayload.outgoingCltv})") - val confidence = (upstream.received.map(_.add.endorsement).min + 0.5) / 8 - relay(upstream, nextPayload, nextPacket_opt, confidence) - } - - private def relay(upstream: Upstream.Hot.Trampoline, payloadOut: IntermediatePayload.NodeRelay, packetOut_opt: Option[OnionRoutingPacket], confidence: Double): Behavior[Command] = { - val displayNodeId = payloadOut match { - case payloadOut: IntermediatePayload.NodeRelay.Standard => payloadOut.outgoingNodeId - case _: IntermediatePayload.NodeRelay.ToBlindedPaths => randomKey().publicKey - } - val paymentCfg = SendPaymentConfig(relayId, relayId, None, paymentHash, displayNodeId, upstream, None, None, storeInDb = false, publishEvent = false, recordPathFindingMetrics = true, confidence) - val routeParams = computeRouteParams(nodeParams, upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv) - payloadOut match { + /** Once we've fully received the incoming HTLC set, we must identify the next node before forwarding the payment. */ + private def resolveNextNode(upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { + nextPayload match { case payloadOut: IntermediatePayload.NodeRelay.Standard => // If invoice features are provided in the onion, the sender is asking us to relay to a non-trampoline recipient. payloadOut.invoiceFeatures match { @@ -278,48 +248,77 @@ class NodeRelay private(nodeParams: NodeParams, val extraEdges = payloadOut.invoiceRoutingInfo.getOrElse(Nil).flatMap(Bolt11Invoice.toExtraEdges(_, payloadOut.outgoingNodeId)) val paymentSecret = payloadOut.paymentSecret.get // NB: we've verified that there was a payment secret in validateRelay val recipient = ClearRecipient(payloadOut.outgoingNodeId, Features(features).invoiceFeatures(), payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, extraEdges, payloadOut.paymentMetadata) - context.log.debug("sending the payment to non-trampoline recipient (MPP={})", recipient.features.hasFeature(Features.BasicMultiPartPayment)) - relayToRecipient(upstream, payloadOut, recipient, paymentCfg, routeParams, useMultiPart = recipient.features.hasFeature(Features.BasicMultiPartPayment)) + context.log.debug("forwarding payment to non-trampoline recipient {}", recipient.nodeId) + ensureRecipientReady(upstream, recipient, nextPayload, None) case None => - context.log.debug("sending the payment to the next trampoline node") val paymentSecret = randomBytes32() // we generate a new secret to protect against probing attacks - val recipient = ClearRecipient(payloadOut.outgoingNodeId, Features.empty, payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, nextTrampolineOnion_opt = packetOut_opt) - relayToRecipient(upstream, payloadOut, recipient, paymentCfg, routeParams, useMultiPart = true) + val recipient = ClearRecipient(payloadOut.outgoingNodeId, Features.empty, payloadOut.amountToForward, payloadOut.outgoingCltv, paymentSecret, nextTrampolineOnion_opt = nextPacket_opt) + context.log.debug("forwarding payment to the next trampoline node {}", recipient.nodeId) + ensureRecipientReady(upstream, recipient, nextPayload, nextPacket_opt) } case payloadOut: IntermediatePayload.NodeRelay.ToBlindedPaths => + // Blinded paths in Bolt 12 invoices may encode the introduction node with an scid and a direction: we need to + // resolve that to a nodeId in order to reach that introduction node and use the blinded path. + // If we are the introduction node ourselves, we'll also need to decrypt the onion and identify the next node. context.spawnAnonymous(BlindedPathsResolver(nodeParams, paymentHash, router, register)) ! Resolve(context.messageAdapter[Seq[ResolvedPath]](WrappedResolvedPaths), payloadOut.outgoingBlindedPaths) - waitForResolvedPaths(upstream, payloadOut, paymentCfg, routeParams) + Behaviors.receiveMessagePartial { + rejectExtraHtlcPartialFunction orElse { + case WrappedResolvedPaths(resolved) if resolved.isEmpty => + context.log.warn("rejecting trampoline payment to blinded paths: no usable blinded path") + rejectPayment(upstream, Some(UnknownNextPeer())) + stopping() + case WrappedResolvedPaths(resolved) => + // We don't have access to the invoice: we use the only node_id that somewhat makes sense for the recipient. + val blindedNodeId = resolved.head.route.blindedNodeIds.last + val recipient = BlindedRecipient.fromPaths(blindedNodeId, Features(payloadOut.invoiceFeatures).invoiceFeatures(), payloadOut.amountToForward, payloadOut.outgoingCltv, resolved, Set.empty) + context.log.debug("forwarding payment to blinded recipient {}", recipient.nodeId) + ensureRecipientReady(upstream, recipient, nextPayload, nextPacket_opt) + } + } } } /** - * Blinded paths in Bolt 12 invoices may encode the introduction node with an scid and a direction: we need to resolve - * that to a nodeId in order to reach that introduction node and use the blinded path. + * The next node may be a mobile wallet directly connected to us: in that case, we'll need to wake them up before + * relaying the payment. */ - private def waitForResolvedPaths(upstream: Upstream.Hot.Trampoline, - payloadOut: IntermediatePayload.NodeRelay.ToBlindedPaths, - paymentCfg: SendPaymentConfig, - routeParams: RouteParams): Behavior[Command] = + private def ensureRecipientReady(upstream: Upstream.Hot.Trampoline, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { + if (shouldWakeUpNextNode(nodeParams, recipient)) { + waitForPeerReady(upstream, recipient, nextPayload, nextPacket_opt) + } else { + relay(upstream, recipient, nextPayload, nextPacket_opt) + } + } + + /** + * The next node is the payment recipient. They are directly connected to us and may be offline. We try to wake them + * up and will relay the payment once they're connected and channels are reestablished. + */ + private def waitForPeerReady(upstream: Upstream.Hot.Trampoline, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { + val nextNodeId = nodeIdToWakeUp(recipient) + context.log.info("trying to wake up next peer (nodeId={})", nextNodeId) + val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(nextNodeId, timeout_opt = Some(Left(30 seconds)))).onFailure(SupervisorStrategy.stop)) + notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult)) Behaviors.receiveMessagePartial { - case WrappedResolvedPaths(resolved) if resolved.isEmpty => - context.log.warn(s"rejecting trampoline payment to blinded paths: no usable blinded path") - rejectPayment(upstream, Some(UnknownNextPeer())) - stopping() - case WrappedResolvedPaths(resolved) => - val features = Features(payloadOut.invoiceFeatures).invoiceFeatures() - // We don't have access to the invoice: we use the only node_id that somewhat makes sense for the recipient. - val blindedNodeId = resolved.head.route.blindedNodeIds.last - val recipient = BlindedRecipient.fromPaths(blindedNodeId, features, payloadOut.amountToForward, payloadOut.outgoingCltv, resolved, Set.empty) - context.log.debug("sending the payment to blinded recipient, useMultiPart={}", features.hasFeature(Features.BasicMultiPartPayment)) - relayToRecipient(upstream, payloadOut, recipient, paymentCfg, routeParams, features.hasFeature(Features.BasicMultiPartPayment)) + rejectExtraHtlcPartialFunction orElse { + case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) => + context.log.warn("rejecting payment: failed to wake-up remote peer") + rejectPayment(upstream, Some(UnknownNextPeer())) + stopping() + case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerReady) => + relay(upstream, recipient, nextPayload, nextPacket_opt) + } } + } - private def relayToRecipient(upstream: Upstream.Hot.Trampoline, - payloadOut: IntermediatePayload.NodeRelay, - recipient: Recipient, - paymentCfg: SendPaymentConfig, - routeParams: RouteParams, - useMultiPart: Boolean): Behavior[Command] = { + /** Relay the payment to the next identified node: this is similar to sending an outgoing payment. */ + private def relay(upstream: Upstream.Hot.Trampoline, recipient: Recipient, payloadOut: IntermediatePayload.NodeRelay, packetOut_opt: Option[OnionRoutingPacket]): Behavior[Command] = { + context.log.debug("relaying trampoline payment (amountIn={} expiryIn={} amountOut={} expiryOut={})", upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv) + val confidence = (upstream.received.map(_.add.endorsement).min + 0.5) / 8 + val paymentCfg = SendPaymentConfig(relayId, relayId, None, paymentHash, recipient.nodeId, upstream, None, None, storeInDb = false, publishEvent = false, recordPathFindingMetrics = true, confidence) + val routeParams = computeRouteParams(nodeParams, upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv) + // If the next node is using trampoline, we assume that they support MPP. + val useMultiPart = recipient.features.hasFeature(Features.BasicMultiPartPayment) || packetOut_opt.nonEmpty val payFsmAdapters = { context.messageAdapter[PreimageReceived](WrappedPreimageReceived) context.messageAdapter[PaymentSent](WrappedPaymentSent) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala index 20d65b1991..75bb545c89 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala @@ -16,7 +16,6 @@ package fr.acinq.eclair.payment.relay -import akka.actor.typed import akka.actor.typed.scaladsl.Behaviors import akka.actor.typed.{ActorRef, Behavior} import fr.acinq.bitcoin.scalacompat.ByteVector32 @@ -58,7 +57,7 @@ object NodeRelayer { * NB: the payment secret used here is different from the invoice's payment secret and ensures we can * group together HTLCs that the previous trampoline node sent in the same MPP. */ - def apply(nodeParams: NodeParams, register: akka.actor.ActorRef, outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory, triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command], router: akka.actor.ActorRef, children: Map[PaymentKey, ActorRef[NodeRelay.Command]] = Map.empty): Behavior[Command] = + def apply(nodeParams: NodeParams, register: akka.actor.ActorRef, outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory, router: akka.actor.ActorRef, children: Map[PaymentKey, ActorRef[NodeRelay.Command]] = Map.empty): Behavior[Command] = Behaviors.setup { context => Behaviors.withMdc(Logs.mdc(category_opt = Some(Logs.LogCategory.PAYMENT)), mdc) { Behaviors.receiveMessage { @@ -73,15 +72,15 @@ object NodeRelayer { case None => val relayId = UUID.randomUUID() context.log.debug(s"spawning a new handler with relayId=$relayId") - val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, nodeRelayPacket, outgoingPaymentFactory, triggerer, router), relayId.toString) + val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, nodeRelayPacket, outgoingPaymentFactory, router), relayId.toString) context.log.debug("forwarding incoming htlc #{} from channel {} to new handler", htlcIn.id, htlcIn.channelId) handler ! NodeRelay.Relay(nodeRelayPacket, originNode) - apply(nodeParams, register, outgoingPaymentFactory, triggerer, router, children + (childKey -> handler)) + apply(nodeParams, register, outgoingPaymentFactory, router, children + (childKey -> handler)) } case RelayComplete(childHandler, paymentHash, paymentSecret) => // we do a back-and-forth between parent and child before stopping the child to prevent a race condition childHandler ! NodeRelay.Stop - apply(nodeParams, register, outgoingPaymentFactory, triggerer, router, children - PaymentKey(paymentHash, paymentSecret)) + apply(nodeParams, register, outgoingPaymentFactory, router, children - PaymentKey(paymentHash, paymentSecret)) case GetPendingPayments(replyTo) => replyTo ! children Behaviors.same diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala index f9f5c0039b..d85f9876ac 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala @@ -49,7 +49,7 @@ import scala.util.Random * It also receives channel HTLC events (fulfill / failed) and relays those to the appropriate handlers. * It also maintains an up-to-date view of local channel balances. */ -class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paymentHandler: ActorRef, triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command], initialized: Option[Promise[Done]] = None) extends Actor with DiagnosticActorLogging { +class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paymentHandler: ActorRef, initialized: Option[Promise[Done]] = None) extends Actor with DiagnosticActorLogging { import Relayer._ @@ -58,7 +58,7 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym private val postRestartCleaner = context.actorOf(PostRestartHtlcCleaner.props(nodeParams, register, initialized), "post-restart-htlc-cleaner") private val channelRelayer = context.spawn(Behaviors.supervise(ChannelRelayer(nodeParams, register)).onFailure(SupervisorStrategy.resume), "channel-relayer") - private val nodeRelayer = context.spawn(Behaviors.supervise(NodeRelayer(nodeParams, register, NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router, register), triggerer, router)).onFailure(SupervisorStrategy.resume), name = "node-relayer") + private val nodeRelayer = context.spawn(Behaviors.supervise(NodeRelayer(nodeParams, register, NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router, register), router)).onFailure(SupervisorStrategy.resume), name = "node-relayer") def receive: Receive = { case init: PostRestartHtlcCleaner.Init => postRestartCleaner forward init @@ -120,8 +120,8 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym object Relayer extends Logging { - def props(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paymentHandler: ActorRef, triggerer: typed.ActorRef[AsyncPaymentTriggerer.Command], initialized: Option[Promise[Done]] = None): Props = - Props(new Relayer(nodeParams, router, register, paymentHandler, triggerer, initialized)) + def props(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paymentHandler: ActorRef, initialized: Option[Promise[Done]] = None): Props = + Props(new Relayer(nodeParams, router, register, paymentHandler, initialized)) // @formatter:off case class RelayFees(feeBase: MilliSatoshi, feeProportionalMillionths: Long) { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala index dc25ecedc6..d3f7f47dac 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/FuzzySpec.scala @@ -66,8 +66,8 @@ class FuzzySpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with Channe val bobRegister = system.actorOf(Props(new TestRegister())) val alicePaymentHandler = system.actorOf(Props(new PaymentHandler(aliceParams, aliceRegister, TestProbe().ref))) val bobPaymentHandler = system.actorOf(Props(new PaymentHandler(bobParams, bobRegister, TestProbe().ref))) - val aliceRelayer = system.actorOf(Relayer.props(aliceParams, TestProbe().ref, aliceRegister, alicePaymentHandler, TestProbe().ref)) - val bobRelayer = system.actorOf(Relayer.props(bobParams, TestProbe().ref, bobRegister, bobPaymentHandler, TestProbe().ref)) + val aliceRelayer = system.actorOf(Relayer.props(aliceParams, TestProbe().ref, aliceRegister, alicePaymentHandler)) + val bobRelayer = system.actorOf(Relayer.props(bobParams, TestProbe().ref, bobRegister, bobPaymentHandler)) val wallet = new DummyOnChainWallet() val alice: TestFSMRef[ChannelState, ChannelData, Channel] = TestFSMRef(new Channel(aliceParams, wallet, bobParams.nodeId, alice2blockchain.ref, aliceRelayer, FakeTxPublisherFactory(alice2blockchain)), alicePeer.ref) val bob: TestFSMRef[ChannelState, ChannelData, Channel] = TestFSMRef(new Channel(bobParams, wallet, aliceParams.nodeId, bob2blockchain.ref, bobRelayer, FakeTxPublisherFactory(bob2blockchain)), bobPeer.ref) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala index 1fcbada4a6..bbb153a276 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala @@ -90,13 +90,12 @@ object MinimalNodeFixture extends Assertions with Eventually with IntegrationPat val bitcoinClient = new TestBitcoinCoreClient() val wallet = new SingleKeyOnChainWallet() val watcher = TestProbe("watcher") - val triggerer = TestProbe("payment-triggerer") val watcherTyped = watcher.ref.toTyped[ZmqWatcher.Command] val register = system.actorOf(Register.props(), "register") val router = system.actorOf(Router.props(nodeParams, watcherTyped), "router") val offerManager = system.spawn(OfferManager(nodeParams, router, 1 minute), "offer-manager") val paymentHandler = system.actorOf(PaymentHandler.props(nodeParams, register, offerManager), "payment-handler") - val relayer = system.actorOf(Relayer.props(nodeParams, router, register, paymentHandler, triggerer.ref.toTyped), "relayer") + val relayer = system.actorOf(Relayer.props(nodeParams, router, register, paymentHandler), "relayer") val txPublisherFactory = Channel.SimpleTxPublisherFactory(nodeParams, watcherTyped, bitcoinClient) val channelFactory = Peer.SimpleChannelFactory(nodeParams, watcherTyped, relayer, wallet, txPublisherFactory) val pendingChannelsRateLimiter = system.spawnAnonymous(Behaviors.supervise(PendingChannelsRateLimiter(nodeParams, router.toTyped, Seq())).onFailure(typed.SupervisorStrategy.resume)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala index cd45aa1c7e..91dac119f2 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala @@ -57,7 +57,7 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit case class FixtureParam(nodeParams: NodeParams, register: TestProbe, sender: TestProbe, eventListener: TestProbe) { def createRelayer(nodeParams1: NodeParams): (ActorRef, ActorRef) = { - val relayer = system.actorOf(Relayer.props(nodeParams1, TestProbe().ref, register.ref, TestProbe().ref, TestProbe().ref.toTyped)) + val relayer = system.actorOf(Relayer.props(nodeParams1, TestProbe().ref, register.ref, TestProbe().ref)) // we need ensure the post-htlc-restart child actor is initialized sender.send(relayer, Relayer.GetChildActors(sender.ref)) (relayer, sender.expectMsgType[Relayer.ChildActors].postRestartCleaner) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala index 64b270f4cb..d8ce8d9cef 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala @@ -26,16 +26,15 @@ import com.softwaremill.quicklens.ModifyPimp import com.typesafe.config.ConfigFactory import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{Block, BlockHash, ByteVector32, Crypto} +import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.Features.{AsyncPaymentPrototype, BasicMultiPartPayment, PaymentSecret, VariableLengthOnion} -import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Register, Upstream} import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop import fr.acinq.eclair.payment.IncomingPaymentPacket.{RelayToBlindedPathsPacket, RelayToTrampolinePacket} import fr.acinq.eclair.payment.Invoice.ExtraEdge import fr.acinq.eclair.payment._ -import fr.acinq.eclair.payment.relay.AsyncPaymentTriggerer.{AsyncPaymentCanceled, AsyncPaymentTimeout, AsyncPaymentTriggered, Watch} import fr.acinq.eclair.payment.relay.NodeRelayer.PaymentKey import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment} import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig @@ -49,8 +48,8 @@ import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.blindedRou import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{AllowedFeatures, PathId, PaymentConstraints} import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{BlockHeight, Bolt11Feature, CltvExpiry, CltvExpiryDelta, FeatureSupport, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, RealShortChannelId, ShortChannelId, TestConstants, TimestampMilli, UInt64, randomBytes, randomBytes32, randomKey} +import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike -import org.scalatest.{Outcome, Tag} import scodec.bits.{ByteVector, HexStringSyntax} import java.util.UUID @@ -65,11 +64,11 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl import NodeRelayerSpec._ - case class FixtureParam(nodeParams: NodeParams, router: TestProbe[Any], register: TestProbe[Any], mockPayFSM: TestProbe[Any], eventListener: TestProbe[PaymentEvent], triggerer: TestProbe[AsyncPaymentTriggerer.Command]) { + case class FixtureParam(nodeParams: NodeParams, router: TestProbe[Any], register: TestProbe[Any], mockPayFSM: TestProbe[Any], eventListener: TestProbe[PaymentEvent]) { def createNodeRelay(packetIn: IncomingPaymentPacket.NodeRelayPacket, useRealPaymentFactory: Boolean = false): (ActorRef[NodeRelay.Command], TestProbe[NodeRelayer.Command]) = { val parent = TestProbe[NodeRelayer.Command]("parent-relayer") val outgoingPaymentFactory = if (useRealPaymentFactory) RealOutgoingPaymentFactory(this) else FakeOutgoingPaymentFactory(this) - val nodeRelay = testKit.spawn(NodeRelay(nodeParams, parent.ref, register.ref.toClassic, relayId, packetIn, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic)) + val nodeRelay = testKit.spawn(NodeRelay(nodeParams, parent.ref, register.ref.toClassic, relayId, packetIn, outgoingPaymentFactory, router.ref.toClassic)) (nodeRelay, parent) } } @@ -92,21 +91,19 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl override def withFixture(test: OneArgTest): Outcome = { val nodeParams = TestConstants.Bob.nodeParams .modify(_.multiPartPaymentExpiry).setTo(5 seconds) - .modify(_.features).setToIf(test.tags.contains("async_payments"))(Features(AsyncPaymentPrototype -> Optional)) .modify(_.relayParams.asyncPaymentsParams.holdTimeoutBlocks).setToIf(test.tags.contains("long_hold_timeout"))(200000) // timeout after payment expires val router = TestProbe[Any]("router") val register = TestProbe[Any]("register") val eventListener = TestProbe[PaymentEvent]("event-listener") system.eventStream ! EventStream.Subscribe(eventListener.ref) val mockPayFSM = TestProbe[Any]("pay-fsm") - val triggerer = TestProbe[AsyncPaymentTriggerer.Command]("payment-triggerer") - withFixture(test.toNoArgTest(FixtureParam(nodeParams, router, register, mockPayFSM, eventListener, triggerer))) + withFixture(test.toNoArgTest(FixtureParam(nodeParams, router, register, mockPayFSM, eventListener))) } test("create child handlers for new payments") { f => import f._ val probe = TestProbe[Any]() - val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, FakeOutgoingPaymentFactory(f), triggerer.ref, router.ref.toClassic)) + val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, FakeOutgoingPaymentFactory(f), router.ref.toClassic)) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) probe.expectMessage(Map.empty) @@ -145,7 +142,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val outgoingPaymentFactory = FakeOutgoingPaymentFactory(f) { - val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic)) + val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, router.ref.toClassic)) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) probe.expectMessage(Map.empty) } @@ -153,7 +150,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val (paymentHash1, paymentSecret1, child1) = (randomBytes32(), randomBytes32(), TestProbe[NodeRelay.Command]()) val (paymentHash2, paymentSecret2, child2) = (randomBytes32(), randomBytes32(), TestProbe[NodeRelay.Command]()) val children = Map(PaymentKey(paymentHash1, paymentSecret1) -> child1.ref, PaymentKey(paymentHash2, paymentSecret2) -> child2.ref) - val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic, children)) + val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, router.ref.toClassic, children)) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) probe.expectMessage(children) @@ -169,7 +166,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val (paymentSecret1, child1) = (randomBytes32(), TestProbe[NodeRelay.Command]()) val (paymentSecret2, child2) = (randomBytes32(), TestProbe[NodeRelay.Command]()) val children = Map(PaymentKey(paymentHash, paymentSecret1) -> child1.ref, PaymentKey(paymentHash, paymentSecret2) -> child2.ref) - val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic, children)) + val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, router.ref.toClassic, children)) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) probe.expectMessage(children) @@ -179,7 +176,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl probe.expectMessage(Map(PaymentKey(paymentHash, paymentSecret2) -> child2.ref)) } { - val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, triggerer.ref, router.ref.toClassic)) + val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, register.ref.toClassic, outgoingPaymentFactory, router.ref.toClassic)) parentRelayer ! NodeRelayer.Relay(incomingMultiPart.head, randomKey().publicKey) parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic) val pending1 = probe.expectMessageType[Map[PaymentKey, ActorRef[NodeRelay.Command]]] @@ -335,115 +332,6 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl register.expectNoMessage(100 millis) } - test("fail to relay when not triggered before the hold timeout", Tag("async_payments")) { f => - import f._ - - val (nodeRelayer, _) = createNodeRelay(incomingAsyncPayment.head) - incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) - - // wait until the NodeRelay is waiting for the trigger - eventListener.expectMessageType[WaitingToRelayPayment] - mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a trigger - - // publish notification that peer is unavailable at the timeout height - val peerWatch = triggerer.expectMessageType[Watch] - assert(asyncTimeoutHeight(nodeParams) < asyncSafetyHeight(incomingAsyncPayment, nodeParams)) - assert(peerWatch.timeout == asyncTimeoutHeight(nodeParams)) - peerWatch.replyTo ! AsyncPaymentTimeout - - incomingAsyncPayment.foreach { p => - val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] - assert(fwd.channelId == p.add.channelId) - assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(TemporaryNodeFailure()), commit = true)) - } - register.expectNoMessage(100 millis) - } - - test("relay the payment when triggered while waiting", Tag("async_payments"), Tag("long_hold_timeout")) { f => - import f._ - - val (nodeRelayer, parent) = createNodeRelay(incomingAsyncPayment.head) - incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) - - // wait until the NodeRelay is waiting for the trigger - eventListener.expectMessageType[WaitingToRelayPayment] - mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a trigger - - // publish notification that peer is ready before the safety interval before the current incoming payment expires (and before the timeout height) - val peerWatch = triggerer.expectMessageType[Watch] - assert(asyncTimeoutHeight(nodeParams) > asyncSafetyHeight(incomingAsyncPayment, nodeParams)) - assert(peerWatch.timeout == asyncSafetyHeight(incomingAsyncPayment, nodeParams)) - peerWatch.replyTo ! AsyncPaymentTriggered - - // upstream payment relayed - val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingAsyncPayment.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5) - val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] - validateOutgoingPayment(outgoingPayment) - // those are adapters for pay-fsm messages - val nodeRelayerAdapters = outgoingPayment.replyTo - - // A first downstream HTLC is fulfilled: we should immediately forward the fulfill upstream. - nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage) - incomingAsyncPayment.foreach { p => - val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]] - assert(fwd.channelId == p.add.channelId) - assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)) - } - - // Once all the downstream payments have settled, we should emit the relayed event. - nodeRelayerAdapters ! createSuccessEvent() - val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed] - validateRelayEvent(relayEvent) - assert(relayEvent.incoming.map(p => (p.amount, p.channelId)).toSet == incomingAsyncPayment.map(i => (i.add.amountMsat, i.add.channelId)).toSet) - assert(relayEvent.outgoing.nonEmpty) - parent.expectMessageType[NodeRelayer.RelayComplete] - register.expectNoMessage(100 millis) - } - - test("fail to relay when not triggered before the incoming expiry safety timeout", Tag("async_payments"), Tag("long_hold_timeout")) { f => - import f._ - - val (nodeRelayer, _) = createNodeRelay(incomingAsyncPayment.head) - incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) - mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a complete upstream payment - - // publish notification that peer is unavailable at the cancel-safety-before-timeout-block threshold before the current incoming payment expires (and before the timeout height) - val peerWatch = triggerer.expectMessageType[Watch] - assert(asyncTimeoutHeight(nodeParams) > asyncSafetyHeight(incomingAsyncPayment, nodeParams)) - assert(peerWatch.timeout == asyncSafetyHeight(incomingAsyncPayment, nodeParams)) - peerWatch.replyTo ! AsyncPaymentTimeout - - incomingAsyncPayment.foreach { p => - val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] - assert(fwd.channelId == p.add.channelId) - assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(TemporaryNodeFailure()), commit = true)) - } - - register.expectNoMessage(100 millis) - } - - test("fail to relay payment when canceled by sender before timeout", Tag("async_payments")) { f => - import f._ - - val (nodeRelayer, _) = createNodeRelay(incomingAsyncPayment.head) - incomingAsyncPayment.foreach(p => nodeRelayer ! NodeRelay.Relay(p, randomKey().publicKey)) - - // wait until the NodeRelay is waiting for the trigger - eventListener.expectMessageType[WaitingToRelayPayment] - mockPayFSM.expectNoMessage(100 millis) // we should NOT trigger a downstream payment before we received a trigger - - // fail the payment if waiting when payment sender sends cancel message - nodeRelayer ! NodeRelay.WrappedPeerReadyResult(AsyncPaymentCanceled) - - incomingAsyncPayment.foreach { p => - val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] - assert(fwd.channelId == p.add.channelId) - assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(TemporaryNodeFailure()), commit = true)) - } - register.expectNoMessage(100 millis) - } - test("relay the payment immediately when the async payment feature is disabled") { f => import f._ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala index 8e68b899e3..2c53d7e430 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/RelayerSpec.scala @@ -55,11 +55,10 @@ class RelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val router = TestProbe[Any]("router") val register = TestProbe[Any]("register") val paymentHandler = TestProbe[Any]("payment-handler") - val triggerer = TestProbe[AsyncPaymentTriggerer.Command]("payment-triggerer") val probe = TestProbe[Any]() // we can't spawn top-level actors with akka typed testKit.spawn(Behaviors.setup[Any] { context => - val relayer = context.toClassic.actorOf(Relayer.props(nodeParams, router.ref.toClassic, register.ref.toClassic, paymentHandler.ref.toClassic, triggerer.ref)) + val relayer = context.toClassic.actorOf(Relayer.props(nodeParams, router.ref.toClassic, register.ref.toClassic, paymentHandler.ref.toClassic)) probe.ref ! relayer Behaviors.empty[Any] }) From f4c13548056fa865cc4be73bad3553932b58c7e4 Mon Sep 17 00:00:00 2001 From: t-bast Date: Tue, 11 Jun 2024 18:02:49 +0200 Subject: [PATCH 03/13] Add wake-up step to channel and message relay We allow relaying data to contain a wallet `node_id` instead of an scid. When that's the case, we start by waking up that wallet node before we try relaying onion messages or payments. --- .../scala/fr/acinq/eclair/NodeParams.scala | 6 +- .../fr/acinq/eclair/io/MessageRelay.scala | 106 +++++---- .../acinq/eclair/io/PeerReadyNotifier.scala | 2 +- .../fr/acinq/eclair/payment/Monitoring.scala | 1 + .../acinq/eclair/payment/PaymentPacket.scala | 2 +- .../eclair/payment/relay/ChannelRelay.scala | 178 +++++++++------ .../eclair/payment/relay/ChannelRelayer.scala | 11 +- .../eclair/payment/relay/NodeRelay.scala | 43 ++-- .../payment/send/BlindedPathsResolver.scala | 62 ++++-- .../eclair/router/BlindedRouteCreation.scala | 19 +- .../eclair/wire/protocol/PaymentOnion.scala | 13 +- .../eclair/wire/protocol/RouteBlinding.scala | 10 +- .../scala/fr/acinq/eclair/TestConstants.scala | 6 +- .../fr/acinq/eclair/crypto/SphinxSpec.scala | 6 +- .../fr/acinq/eclair/io/MessageRelaySpec.scala | 34 ++- .../eclair/payment/PaymentPacketSpec.scala | 18 +- .../payment/relay/ChannelRelayerSpec.scala | 68 +++++- .../payment/relay/NodeRelayerSpec.scala | 205 ++++++++++++------ .../send/BlindedPathsResolverSpec.scala | 34 ++- .../wire/protocol/PaymentOnionSpec.scala | 20 +- 20 files changed, 577 insertions(+), 267 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala index 1c66625b53..97c7c6ca5d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala @@ -87,7 +87,8 @@ case class NodeParams(nodeKeyManager: NodeKeyManager, blockchainWatchdogSources: Seq[String], onionMessageConfig: OnionMessageConfig, purgeInvoicesInterval: Option[FiniteDuration], - revokedHtlcInfoCleanerConfig: RevokedHtlcInfoCleaner.Config) { + revokedHtlcInfoCleanerConfig: RevokedHtlcInfoCleaner.Config, + wakeUpTimeout: FiniteDuration) { val privateKey: Crypto.PrivateKey = nodeKeyManager.nodeKey.privateKey val nodeId: PublicKey = nodeKeyManager.nodeId @@ -611,7 +612,8 @@ object NodeParams extends Logging { revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config( batchSize = config.getInt("db.revoked-htlc-info-cleaner.batch-size"), interval = FiniteDuration(config.getDuration("db.revoked-htlc-info-cleaner.interval").getSeconds, TimeUnit.SECONDS) - ) + ), + wakeUpTimeout = 30 seconds, ) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala index a25a166913..7571362736 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala @@ -16,10 +16,10 @@ package fr.acinq.eclair.io -import akka.actor.typed.Behavior import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import akka.actor.typed.scaladsl.{ActorContext, Behaviors} +import akka.actor.typed.{Behavior, SupervisorStrategy} import akka.actor.{ActorRef, typed} import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey @@ -34,6 +34,8 @@ import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.OnionMessage import fr.acinq.eclair.{EncodedNodeId, Logs, NodeParams, ShortChannelId} +import scala.concurrent.duration.DurationInt + object MessageRelay { // @formatter:off sealed trait Command @@ -44,29 +46,18 @@ object MessageRelay { policy: RelayPolicy, replyTo_opt: Option[typed.ActorRef[Status]]) extends Command case class WrappedPeerInfo(peerInfo: PeerInfoResponse) extends Command - case class WrappedConnectionResult(result: PeerConnection.ConnectionResult) extends Command - case class WrappedOptionalNodeId(nodeId_opt: Option[PublicKey]) extends Command + private case class WrappedConnectionResult(result: PeerConnection.ConnectionResult) extends Command + private case class WrappedOptionalNodeId(nodeId_opt: Option[PublicKey]) extends Command + private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command - sealed trait Status { - val messageId: ByteVector32 - } + sealed trait Status { val messageId: ByteVector32 } case class Sent(messageId: ByteVector32) extends Status sealed trait Failure extends Status - case class AgainstPolicy(messageId: ByteVector32, policy: RelayPolicy) extends Failure { - override def toString: String = s"Relay prevented by policy $policy" - } - case class ConnectionFailure(messageId: ByteVector32, failure: PeerConnection.ConnectionResult.Failure) extends Failure { - override def toString: String = s"Can't connect to peer: ${failure.toString}" - } - case class Disconnected(messageId: ByteVector32) extends Failure { - override def toString: String = "Peer is not connected" - } - case class UnknownChannel(messageId: ByteVector32, channelId: ShortChannelId) extends Failure { - override def toString: String = s"Unknown channel: $channelId" - } - case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure { - override def toString: String = s"Message dropped: $reason" - } + case class AgainstPolicy(messageId: ByteVector32, policy: RelayPolicy) extends Failure { override def toString: String = s"Relay prevented by policy $policy" } + case class ConnectionFailure(messageId: ByteVector32, failure: PeerConnection.ConnectionResult.Failure) extends Failure { override def toString: String = s"Can't connect to peer: ${failure.toString}" } + case class Disconnected(messageId: ByteVector32) extends Failure { override def toString: String = "Peer is not connected" } + case class UnknownChannel(messageId: ByteVector32, channelId: ShortChannelId) extends Failure { override def toString: String = s"Unknown channel: $channelId" } + case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure { override def toString: String = s"Message dropped: $reason" } sealed trait RelayPolicy case object RelayChannelsOnly extends RelayPolicy @@ -106,7 +97,7 @@ private class MessageRelay(nodeParams: NodeParams, def queryNextNodeId(msg: OnionMessage, nextNode: Either[ShortChannelId, EncodedNodeId]): Behavior[Command] = { nextNode match { case Left(outgoingChannelId) if outgoingChannelId == ShortChannelId.toSelf => - withNextNodeId(msg, nodeParams.nodeId) + withNextNodeId(msg, EncodedNodeId.WithPublicKey.Plain(nodeParams.nodeId)) case Left(outgoingChannelId) => register ! Register.GetNextNodeId(context.messageAdapter(WrappedOptionalNodeId), outgoingChannelId) waitForNextNodeId(msg, outgoingChannelId) @@ -114,7 +105,7 @@ private class MessageRelay(nodeParams: NodeParams, router ! Router.GetNodeId(context.messageAdapter(WrappedOptionalNodeId), scid, isNode1) waitForNextNodeId(msg, scid) case Right(encodedNodeId: EncodedNodeId.WithPublicKey) => - withNextNodeId(msg, encodedNodeId.publicKey) + withNextNodeId(msg, encodedNodeId) } } @@ -127,34 +118,39 @@ private class MessageRelay(nodeParams: NodeParams, Behaviors.stopped case WrappedOptionalNodeId(Some(nextNodeId)) => log.info("found outgoing node {} for channel {}", nextNodeId, channelId) - withNextNodeId(msg, nextNodeId) + withNextNodeId(msg, EncodedNodeId.WithPublicKey.Plain(nextNodeId)) } } - private def withNextNodeId(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = { - if (nextNodeId == nodeParams.nodeId) { - OnionMessages.process(nodeParams.privateKey, msg) match { - case OnionMessages.DropMessage(reason) => - Metrics.OnionMessagesNotRelayed.withTag(Tags.Reason, reason.getClass.getSimpleName).increment() - replyTo_opt.foreach(_ ! DroppedMessage(messageId, reason)) - Behaviors.stopped - case OnionMessages.SendMessage(nextNode, nextMessage) => - // We need to repeat the process until we identify the (real) next node, or find out that we're the recipient. - queryNextNodeId(nextMessage, nextNode) - case received: OnionMessages.ReceiveMessage => - context.system.eventStream ! EventStream.Publish(received) - replyTo_opt.foreach(_ ! Sent(messageId)) - Behaviors.stopped - } - } else { - policy match { - case RelayChannelsOnly => - switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId) - waitForPreviousPeerForPolicyCheck(msg, nextNodeId) - case RelayAll => - switchboard ! Peer.Connect(nextNodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false) - waitForConnection(msg, nextNodeId) - } + private def withNextNodeId(msg: OnionMessage, nextNodeId: EncodedNodeId.WithPublicKey): Behavior[Command] = { + nextNodeId match { + case EncodedNodeId.WithPublicKey.Plain(nodeId) if nodeId == nodeParams.nodeId => + OnionMessages.process(nodeParams.privateKey, msg) match { + case OnionMessages.DropMessage(reason) => + Metrics.OnionMessagesNotRelayed.withTag(Tags.Reason, reason.getClass.getSimpleName).increment() + replyTo_opt.foreach(_ ! DroppedMessage(messageId, reason)) + Behaviors.stopped + case OnionMessages.SendMessage(nextNode, nextMessage) => + // We need to repeat the process until we identify the (real) next node, or find out that we're the recipient. + queryNextNodeId(nextMessage, nextNode) + case received: OnionMessages.ReceiveMessage => + context.system.eventStream ! EventStream.Publish(received) + replyTo_opt.foreach(_ ! Sent(messageId)) + Behaviors.stopped + } + case EncodedNodeId.WithPublicKey.Plain(nodeId) => + policy match { + case RelayChannelsOnly => + switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId) + waitForPreviousPeerForPolicyCheck(msg, nodeId) + case RelayAll => + switchboard ! Peer.Connect(nodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false) + waitForConnection(msg, nodeId) + } + case EncodedNodeId.WithPublicKey.Wallet(nodeId) => + val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(nodeId, timeout_opt = Some(Left(nodeParams.wakeUpTimeout)))).onFailure(SupervisorStrategy.stop)) + notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult)) + waitForWalletNodeUp(msg, nodeId) } } @@ -197,4 +193,18 @@ private class MessageRelay(nodeParams: NodeParams, Behaviors.stopped } } + + private def waitForWalletNodeUp(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = { + Behaviors.receiveMessagePartial { + case WrappedPeerReadyResult(r: PeerReadyNotifier.PeerReady) => + log.info("successfully woke up {}: relaying onion message", nextNodeId) + r.peer ! Peer.RelayOnionMessage(messageId, msg, replyTo_opt) + Behaviors.stopped + case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) => + Metrics.OnionMessagesNotRelayed.withTag(Tags.Reason, Tags.Reasons.ConnectionFailure).increment() + log.info("could not wake up {}: onion message cannot be relayed", nextNodeId) + replyTo_opt.foreach(_ ! Disconnected(messageId)) + Behaviors.stopped + } + } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala index 4fad93e55e..fbcd09c645 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala @@ -83,7 +83,7 @@ object PeerReadyNotifier { case WrappedListing(Switchboard.SwitchboardServiceKey.Listing(listings)) => listings.headOption match { case Some(switchboard) => - waitForPeerConnected(replyTo, remoteNodeId, switchboard, context, timers) + waitForPeerConnected(replyTo, remoteNodeId, switchboard, context, timers) case None => context.log.error("no switchboard found") replyTo ! PeerUnavailable(remoteNodeId) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala index d9e1c424af..085fa9bc2b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala @@ -119,6 +119,7 @@ object Monitoring { val Failure = "failure" object FailureType { + val WakeUp = "WakeUp" val Remote = "Remote" val Malformed = "MalformedHtlc" diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala index 3f17db19c2..4793153076 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala @@ -126,7 +126,7 @@ object IncomingPaymentPacket { decryptEncryptedRecipientData(add, privateKey, payload, encrypted.data).flatMap { case DecodedEncryptedRecipientData(blindedPayload, nextBlinding) => validateBlindedChannelRelayPayload(add, payload, blindedPayload, nextBlinding, nextPacket).flatMap { - case ChannelRelayPacket(_, payload, nextPacket) if payload.outgoingChannelId == ShortChannelId.toSelf => + case ChannelRelayPacket(_, payload, nextPacket) if payload.outgoing == Right(ShortChannelId.toSelf) => decrypt(add.copy(onionRoutingPacket = nextPacket, tlvStream = add.tlvStream.copy(records = Set(UpdateAddHtlcTlv.BlindingPoint(nextBlinding)))), privateKey, features) case relayPacket => Right(relayPacket) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala index 25b6cbe2c9..f025da43c4 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala @@ -16,23 +16,24 @@ package fr.acinq.eclair.payment.relay -import akka.actor.typed.Behavior +import akka.actor.ActorRef import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import akka.actor.typed.scaladsl.{ActorContext, Behaviors} -import akka.actor.{ActorRef, typed} +import akka.actor.typed.{Behavior, SupervisorStrategy} import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.db.PendingCommandsDb +import fr.acinq.eclair.io.PeerReadyNotifier import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment.relay.Relayer.{OutgoingChannel, OutgoingChannelParams} import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket} import fr.acinq.eclair.wire.protocol.FailureMessageCodecs.createBadOnionFailure import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{Logs, NodeParams, TimestampMilli, TimestampSecond, channel, nodeFee} +import fr.acinq.eclair.{Logs, NodeParams, ShortChannelId, TimestampMilli, TimestampSecond, channel, nodeFee} import java.util.UUID import java.util.concurrent.TimeUnit @@ -44,6 +45,7 @@ object ChannelRelay { // @formatter:off sealed trait Command private case object DoRelay extends Command + private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command private case class WrappedForwardFailure(failure: Register.ForwardFailure[CMD_ADD_HTLC]) extends Command private case class WrappedAddResponse(res: CommandResponse[CMD_ADD_HTLC]) extends Command // @formatter:on @@ -57,7 +59,7 @@ object ChannelRelay { def apply(nodeParams: NodeParams, register: ActorRef, channels: Map[ByteVector32, Relayer.OutgoingChannel], - originNode:PublicKey, + originNode: PublicKey, relayId: UUID, r: IncomingPaymentPacket.ChannelRelayPacket): Behavior[Command] = Behaviors.setup { context => @@ -67,9 +69,8 @@ object ChannelRelay { paymentHash_opt = Some(r.add.paymentHash), nodeAlias_opt = Some(nodeParams.alias))) { val upstream = Upstream.Hot.Channel(r.add.removeUnknownTlvs(), TimestampMilli.now(), originNode) - context.self ! DoRelay val confidence = (r.add.endorsement + 0.5) / 8 - new ChannelRelay(nodeParams, register, channels, r, upstream, confidence, context).relay(Seq.empty) + new ChannelRelay(nodeParams, register, channels, r, upstream, confidence, context).start() } } @@ -77,7 +78,7 @@ object ChannelRelay { * This helper method translates relaying errors (returned by the downstream outgoing channel) to BOLT 4 standard * errors that we should return upstream. */ - def translateLocalError(error: Throwable, channelUpdate_opt: Option[ChannelUpdate]): FailureMessage = { + private def translateLocalError(error: Throwable, channelUpdate_opt: Option[ChannelUpdate]): FailureMessage = { (error, channelUpdate_opt) match { case (_: ExpiryTooSmall, Some(channelUpdate)) => ExpiryTooSoon(Some(channelUpdate)) case (_: ExpiryTooBig, _) => ExpiryTooFar() @@ -121,16 +122,55 @@ class ChannelRelay private(nodeParams: NodeParams, private val forwardFailureAdapter = context.messageAdapter[Register.ForwardFailure[CMD_ADD_HTLC]](WrappedForwardFailure) private val addResponseAdapter = context.messageAdapter[CommandResponse[CMD_ADD_HTLC]](WrappedAddResponse) + private val nextBlindingKey_opt = r.payload match { + case payload: IntermediatePayload.ChannelRelay.Blinded => Some(payload.nextBlinding) + case _: IntermediatePayload.ChannelRelay.Standard => None + } + + /** Channel id explicitly requested in the onion payload. */ + private val requestedChannelId_opt = r.payload.outgoing match { + case Left(_) => None + case Right(outgoingChannelId) => channels.collectFirst { + case (channelId, channel) if channel.shortIds.localAlias == outgoingChannelId => channelId + case (channelId, channel) if channel.shortIds.real.toOption.contains(outgoingChannelId) => channelId + } + } + private case class PreviouslyTried(channelId: ByteVector32, failure: RES_ADD_FAILED[ChannelException]) - def relay(previousFailures: Seq[PreviouslyTried]): Behavior[Command] = { + def start(): Behavior[Command] = { + r.payload.outgoing match { + case Left(walletNodeId) => wakeUp(walletNodeId) + case Right(requestedShortChannelId) => + context.self ! DoRelay + relay(Some(requestedShortChannelId), Seq.empty) + } + } + + private def wakeUp(walletNodeId: PublicKey): Behavior[Command] = { + context.log.info("trying to wake up channel peer (nodeId={})", walletNodeId) + val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.wakeUpTimeout)))).onFailure(SupervisorStrategy.stop)) + notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult)) + Behaviors.receiveMessagePartial { + case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) => + Metrics.recordPaymentRelayFailed(Tags.FailureType.WakeUp, Tags.RelayType.Channel) + context.log.info("rejecting htlc: failed to wake-up remote peer") + safeSendAndStop(r.add.channelId, CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true)) + case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerReady) => + context.self ! DoRelay + relay(None, Seq.empty) + } + } + + def relay(requestedShortChannelId_opt: Option[ShortChannelId], previousFailures: Seq[PreviouslyTried]): Behavior[Command] = { Behaviors.receiveMessagePartial { case DoRelay => if (previousFailures.isEmpty) { - context.log.info("relaying htlc #{} from channelId={} to requestedShortChannelId={} nextNode={}", r.add.id, r.add.channelId, r.payload.outgoingChannelId, nextNodeId_opt.getOrElse("")) + val nextNodeId_opt = channels.headOption.map(_._2.nextNodeId) + context.log.info("relaying htlc #{} from channelId={} to requestedShortChannelId={} nextNode={}", r.add.id, r.add.channelId, requestedShortChannelId_opt, nextNodeId_opt.getOrElse("")) } context.log.debug("attempting relay previousAttempts={}", previousFailures.size) - handleRelay(previousFailures) match { + handleRelay(requestedShortChannelId_opt, previousFailures) match { case RelayFailure(cmdFail) => Metrics.recordPaymentRelayFailed(Tags.FailureType(cmdFail), Tags.RelayType.Channel) context.log.info("rejecting htlc reason={}", cmdFail.reason) @@ -138,12 +178,12 @@ class ChannelRelay private(nodeParams: NodeParams, case RelaySuccess(selectedChannelId, cmdAdd) => context.log.info("forwarding htlc #{} from channelId={} to channelId={}", r.add.id, r.add.channelId, selectedChannelId) register ! Register.Forward(forwardFailureAdapter, selectedChannelId, cmdAdd) - waitForAddResponse(selectedChannelId, previousFailures) + waitForAddResponse(selectedChannelId, requestedShortChannelId_opt, previousFailures) } } } - def waitForAddResponse(selectedChannelId: ByteVector32, previousFailures: Seq[PreviouslyTried]): Behavior[Command] = + private def waitForAddResponse(selectedChannelId: ByteVector32, requestedShortChannelId_opt: Option[ShortChannelId], previousFailures: Seq[PreviouslyTried]): Behavior[Command] = Behaviors.receiveMessagePartial { case WrappedForwardFailure(Register.ForwardFailure(Register.Forward(_, channelId, _))) => context.log.warn(s"couldn't resolve downstream channel $channelId, failing htlc #${upstream.add.id}") @@ -154,25 +194,25 @@ class ChannelRelay private(nodeParams: NodeParams, case WrappedAddResponse(addFailed: RES_ADD_FAILED[_]) => context.log.info("attempt failed with reason={}", addFailed.t.getClass.getSimpleName) context.self ! DoRelay - relay(previousFailures :+ PreviouslyTried(selectedChannelId, addFailed)) + relay(requestedShortChannelId_opt, previousFailures :+ PreviouslyTried(selectedChannelId, addFailed)) - case WrappedAddResponse(r: RES_SUCCESS[_]) => + case WrappedAddResponse(_: RES_SUCCESS[_]) => context.log.debug("sent htlc to the downstream channel") - waitForAddSettled(r.channelId) + waitForAddSettled() } - def waitForAddSettled(channelId: ByteVector32): Behavior[Command] = + private def waitForAddSettled(): Behavior[Command] = Behaviors.receiveMessagePartial { case WrappedAddResponse(RES_ADD_SETTLED(_, htlc, fulfill: HtlcResult.Fulfill)) => - context.log.info("relaying fulfill to upstream, startedAt={}, endedAt={}, confidence={}, originNode={}, outgoingChannel={}", upstream.receivedAt, TimestampMilli.now(), confidence, upstream.receivedFrom, channelId) + context.log.info("relaying fulfill to upstream, startedAt={}, endedAt={}, confidence={}, originNode={}, outgoingChannel={}", upstream.receivedAt, TimestampMilli.now(), confidence, upstream.receivedFrom, htlc.channelId) Metrics.relayFulfill(confidence) val cmd = CMD_FULFILL_HTLC(upstream.add.id, fulfill.paymentPreimage, commit = true) context.system.eventStream ! EventStream.Publish(ChannelPaymentRelayed(upstream.amountIn, htlc.amountMsat, htlc.paymentHash, upstream.add.channelId, htlc.channelId, upstream.receivedAt, TimestampMilli.now())) recordRelayDuration(isSuccess = true) safeSendAndStop(upstream.add.channelId, cmd) - case WrappedAddResponse(RES_ADD_SETTLED(_, _, fail: HtlcResult.Fail)) => - context.log.info("relaying fail to upstream, startedAt={}, endedAt={}, confidence={}, originNode={}, outgoingChannel={}", upstream.receivedAt, TimestampMilli.now(), confidence, upstream.receivedFrom, channelId) + case WrappedAddResponse(RES_ADD_SETTLED(_, htlc, fail: HtlcResult.Fail)) => + context.log.info("relaying fail to upstream, startedAt={}, endedAt={}, confidence={}, originNode={}, outgoingChannel={}", upstream.receivedAt, TimestampMilli.now(), confidence, upstream.receivedFrom, htlc.channelId) Metrics.relayFail(confidence) Metrics.recordPaymentRelayFailed(Tags.FailureType.Remote, Tags.RelayType.Channel) val cmd = translateRelayFailure(upstream.add.id, fail) @@ -180,7 +220,7 @@ class ChannelRelay private(nodeParams: NodeParams, safeSendAndStop(upstream.add.channelId, cmd) } - def safeSendAndStop(channelId: ByteVector32, cmd: channel.HtlcSettlementCommand): Behavior[Command] = { + private def safeSendAndStop(channelId: ByteVector32, cmd: channel.HtlcSettlementCommand): Behavior[Command] = { val toSend = cmd match { case _: CMD_FULFILL_HTLC => cmd case _: CMD_FAIL_HTLC | _: CMD_FAIL_MALFORMED_HTLC => r.payload match { @@ -211,49 +251,44 @@ class ChannelRelay private(nodeParams: NodeParams, * - a CMD_FAIL_HTLC to be sent back upstream * - a CMD_ADD_HTLC to propagate downstream */ - def handleRelay(previousFailures: Seq[PreviouslyTried]): RelayResult = { + private def handleRelay(requestedShortChannelId_opt: Option[ShortChannelId], previousFailures: Seq[PreviouslyTried]): RelayResult = { val alreadyTried = previousFailures.map(_.channelId) - selectPreferredChannel(alreadyTried) match { - case None if previousFailures.nonEmpty => - // no more channels to try - val error = previousFailures - // we return the error for the initially requested channel if it exists - .find(failure => requestedChannelId_opt.contains(failure.channelId)) - // otherwise we return the error for the first channel tried - .getOrElse(previousFailures.head) - .failure - RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(translateLocalError(error.t, error.channelUpdate)), commit = true)) - case outgoingChannel_opt => - relayOrFail(outgoingChannel_opt) + selectPreferredChannel(requestedShortChannelId_opt, alreadyTried) match { + case Some(outgoingChannel) => relayOrFail(outgoingChannel) + case None => + // No more channels to try. + val cmdFail = if (previousFailures.nonEmpty) { + val error = previousFailures + // We return the error for the initially requested channel if it exists. + .find(failure => requestedChannelId_opt.contains(failure.channelId)) + // Otherwise we return the error for the first channel tried. + .getOrElse(previousFailures.head) + .failure + CMD_FAIL_HTLC(r.add.id, Right(translateLocalError(error.t, error.channelUpdate)), commit = true) + } else { + CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true) + } + RelayFailure(cmdFail) } } - /** all the channels point to the same next node, we take the first one */ - private val nextNodeId_opt = channels.headOption.map(_._2.nextNodeId) - - /** channel id explicitly requested in the onion payload */ - private val requestedChannelId_opt = channels.collectFirst { - case (channelId, channel) if channel.shortIds.localAlias == r.payload.outgoingChannelId => channelId - case (channelId, channel) if channel.shortIds.real.toOption.contains(r.payload.outgoingChannelId) => channelId - } - /** * Select a channel to the same node to relay the payment to, that has the lowest capacity and balance and is * compatible in terms of fees, expiry_delta, etc. * * If no suitable channel is found we default to the originally requested channel. */ - def selectPreferredChannel(alreadyTried: Seq[ByteVector32]): Option[OutgoingChannel] = { - val requestedShortChannelId = r.payload.outgoingChannelId - context.log.debug("selecting next channel with requestedShortChannelId={}", requestedShortChannelId) + private def selectPreferredChannel(requestedShortChannelId_opt: Option[ShortChannelId], alreadyTried: Seq[ByteVector32]): Option[OutgoingChannel] = { + context.log.debug("selecting next channel with requestedShortChannelId={}", requestedShortChannelId_opt) // we filter out channels that we have already tried val candidateChannels: Map[ByteVector32, OutgoingChannel] = channels -- alreadyTried // and we filter again to keep the ones that are compatible with this payment (mainly fees, expiry delta) candidateChannels .values .map { channel => - val relayResult = relayOrFail(Some(channel)) - context.log.debug(s"candidate channel: channelId=${channel.channelId} availableForSend={} capacity={} channelUpdate={} result={}", + val relayResult = relayOrFail(channel) + context.log.debug("candidate channel: channelId={} availableForSend={} capacity={} channelUpdate={} result={}", + channel.channelId, channel.commitments.availableBalanceForSend, channel.commitments.latest.capacity, channel.channelUpdate, @@ -279,7 +314,7 @@ class ChannelRelay private(nodeParams: NodeParams, context.log.debug("requested short channel id is our preferred channel") Some(channel) } else { - context.log.debug("replacing requestedShortChannelId={} by preferredShortChannelId={} with availableBalanceMsat={}", requestedShortChannelId, channel.channelUpdate.shortChannelId, channel.commitments.availableBalanceForSend) + context.log.debug("replacing requestedShortChannelId={} by preferredShortChannelId={} with availableBalanceMsat={}", requestedShortChannelId_opt, channel.channelUpdate.shortChannelId, channel.commitments.availableBalanceForSend) Some(channel) } case None => @@ -300,28 +335,35 @@ class ChannelRelay private(nodeParams: NodeParams, * channel, because some parameters don't match with our settings for that channel. In that case we directly fail the * htlc. */ - def relayOrFail(outgoingChannel_opt: Option[OutgoingChannelParams]): RelayResult = { - outgoingChannel_opt match { + private def relayOrFail(outgoingChannel: OutgoingChannelParams): RelayResult = { + val update = outgoingChannel.channelUpdate + validateRelayParams(outgoingChannel) match { + case Some(fail) => + RelayFailure(fail) + case None if !update.channelFlags.isEnabled => + RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(ChannelDisabled(update.messageFlags, update.channelFlags, Some(update))), commit = true)) case None => - RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true)) - case Some(c) if !c.channelUpdate.channelFlags.isEnabled => - RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(ChannelDisabled(c.channelUpdate.messageFlags, c.channelUpdate.channelFlags, Some(c.channelUpdate))), commit = true)) - case Some(c) if r.amountToForward < c.channelUpdate.htlcMinimumMsat => - RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(AmountBelowMinimum(r.amountToForward, Some(c.channelUpdate))), commit = true)) - case Some(c) if r.expiryDelta < c.channelUpdate.cltvExpiryDelta => - RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(IncorrectCltvExpiry(r.outgoingCltv, Some(c.channelUpdate))), commit = true)) - case Some(c) if r.relayFeeMsat < nodeFee(c.channelUpdate.relayFees, r.amountToForward) && - // fees also do not satisfy the previous channel update for `enforcementDelay` seconds after current update - (TimestampSecond.now() - c.channelUpdate.timestamp > nodeParams.relayParams.enforcementDelay || - outgoingChannel_opt.flatMap(_.prevChannelUpdate).forall(c => r.relayFeeMsat < nodeFee(c.relayFees, r.amountToForward))) => - RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(FeeInsufficient(r.add.amountMsat, Some(c.channelUpdate))), commit = true)) - case Some(c: OutgoingChannel) => val origin = Origin.Hot(addResponseAdapter.toClassic, upstream) - val nextBlindingKey_opt = r.payload match { - case payload: IntermediatePayload.ChannelRelay.Blinded => Some(payload.nextBlinding) - case _: IntermediatePayload.ChannelRelay.Standard => None - } - RelaySuccess(c.channelId, CMD_ADD_HTLC(addResponseAdapter.toClassic, r.amountToForward, r.add.paymentHash, r.outgoingCltv, r.nextPacket, nextBlindingKey_opt, confidence, origin, commit = true)) + RelaySuccess(outgoingChannel.channelId, CMD_ADD_HTLC(addResponseAdapter.toClassic, r.amountToForward, r.add.paymentHash, r.outgoingCltv, r.nextPacket, nextBlindingKey_opt, confidence, origin, commit = true)) + } + } + + private def validateRelayParams(outgoingChannel: OutgoingChannelParams): Option[CMD_FAIL_HTLC] = { + val update = outgoingChannel.channelUpdate + // If our current channel update was recently created, we accept payments that used our previous channel update. + val allowPreviousUpdate = TimestampSecond.now() - update.timestamp <= nodeParams.relayParams.enforcementDelay + val prevUpdate_opt = if (allowPreviousUpdate) outgoingChannel.prevChannelUpdate else None + val htlcMinimumOk = update.htlcMinimumMsat <= r.amountToForward || prevUpdate_opt.exists(_.htlcMinimumMsat <= r.amountToForward) + val expiryDeltaOk = update.cltvExpiryDelta <= r.expiryDelta || prevUpdate_opt.exists(_.cltvExpiryDelta <= r.expiryDelta) + val feesOk = nodeFee(update.relayFees, r.amountToForward) <= r.relayFeeMsat || prevUpdate_opt.exists(u => nodeFee(u.relayFees, r.amountToForward) <= r.relayFeeMsat) + if (!htlcMinimumOk) { + Some(CMD_FAIL_HTLC(r.add.id, Right(AmountBelowMinimum(r.amountToForward, Some(update))), commit = true)) + } else if (!expiryDeltaOk) { + Some(CMD_FAIL_HTLC(r.add.id, Right(IncorrectCltvExpiry(r.outgoingCltv, Some(update))), commit = true)) + } else if (!feesOk) { + Some(CMD_FAIL_HTLC(r.add.id, Right(FeeInsufficient(r.add.amountMsat, Some(update))), commit = true)) + } else { + None } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala index 39d61a22c4..8c635df706 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala @@ -24,7 +24,7 @@ import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.channel._ import fr.acinq.eclair.payment.IncomingPaymentPacket -import fr.acinq.eclair.{SubscriptionsComplete, Logs, NodeParams, ShortChannelId} +import fr.acinq.eclair.{Logs, NodeParams, ShortChannelId, SubscriptionsComplete} import java.util.UUID import scala.collection.mutable @@ -70,9 +70,12 @@ object ChannelRelayer { Behaviors.receiveMessage { case Relay(channelRelayPacket, originNode) => val relayId = UUID.randomUUID() - val nextNodeId_opt: Option[PublicKey] = scid2channels.get(channelRelayPacket.payload.outgoingChannelId) match { - case Some(channelId) => channels.get(channelId).map(_.nextNodeId) - case None => None + val nextNodeId_opt: Option[PublicKey] = channelRelayPacket.payload.outgoing match { + case Left(walletNodeId) => Some(walletNodeId) + case Right(outgoingChannelId) => scid2channels.get(outgoingChannelId) match { + case Some(channelId) => channels.get(channelId).map(_.nextNodeId) + case None => None + } } val nextChannels: Map[ByteVector32, Relayer.OutgoingChannel] = nextNodeId_opt match { case Some(nextNodeId) => node2channels.get(nextNodeId).flatMap(channels.get).map(c => c.channelId -> c).toMap diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index 9c79098391..00bd63e004 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -41,12 +41,11 @@ import fr.acinq.eclair.router.Router.RouteParams import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound} import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, Features, Logs, MilliSatoshi, NodeParams, TimestampMilli, UInt64, nodeFee, randomBytes32} +import fr.acinq.eclair.{CltvExpiry, EncodedNodeId, Features, Logs, MilliSatoshi, NodeParams, TimestampMilli, UInt64, nodeFee, randomBytes32} import java.util.UUID import java.util.concurrent.TimeUnit import scala.collection.immutable.Queue -import scala.concurrent.duration.DurationInt /** * It [[NodeRelay]] aggregates incoming HTLCs (in case multi-part was used upstream) and then forwards the requested amount (using the @@ -134,18 +133,22 @@ object NodeRelay { } } - private def shouldWakeUpNextNode(nodeParams: NodeParams, recipient: Recipient): Boolean = { - false - } - - /** When we have identified that the next node is one of our peers, return their (real) nodeId. */ - private def nodeIdToWakeUp(recipient: Recipient): PublicKey = { + /** This function identifies whether the next node is a wallet node directly connected to us, and returns its node_id. */ + private def nextWalletNodeId(nodeParams: NodeParams, recipient: Recipient): Option[PublicKey] = { recipient match { - case r: ClearRecipient => r.nodeId - case r: SpontaneousRecipient => r.nodeId - case r: TrampolineRecipient => r.nodeId - // When using blinded paths, the recipient nodeId is blinded. The actual node is the introduction of the path. - case r: BlindedRecipient => r.blindedHops.head.nodeId + // These two recipients are only used when we're the payment initiator. + case _: SpontaneousRecipient => None + case _: TrampolineRecipient => None + // When relaying to a trampoline node, the next node may be a wallet node directly connected to us, but we don't + // want to have false positives. Feature branches should check an internal DB/cache to confirm. + case r: ClearRecipient if r.nextTrampolineOnion_opt.nonEmpty => None + // If we're relaying to a non-trampoline recipient, it's never a wallet node. + case _: ClearRecipient => None + // When using blinded paths, we may be the introduction node for a wallet node directly connected to us. + case r: BlindedRecipient => r.blindedHops.head.resolved.route match { + case BlindedPathsResolver.PartialBlindedRoute(walletNodeId: EncodedNodeId.WithPublicKey.Wallet, _, _) => Some(walletNodeId.publicKey) + case _ => None + } } } @@ -283,10 +286,9 @@ class NodeRelay private(nodeParams: NodeParams, * relaying the payment. */ private def ensureRecipientReady(upstream: Upstream.Hot.Trampoline, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { - if (shouldWakeUpNextNode(nodeParams, recipient)) { - waitForPeerReady(upstream, recipient, nextPayload, nextPacket_opt) - } else { - relay(upstream, recipient, nextPayload, nextPacket_opt) + nextWalletNodeId(nodeParams, recipient) match { + case Some(walletNodeId) => waitForPeerReady(upstream, walletNodeId, recipient, nextPayload, nextPacket_opt) + case None => relay(upstream, recipient, nextPayload, nextPacket_opt) } } @@ -294,10 +296,9 @@ class NodeRelay private(nodeParams: NodeParams, * The next node is the payment recipient. They are directly connected to us and may be offline. We try to wake them * up and will relay the payment once they're connected and channels are reestablished. */ - private def waitForPeerReady(upstream: Upstream.Hot.Trampoline, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { - val nextNodeId = nodeIdToWakeUp(recipient) - context.log.info("trying to wake up next peer (nodeId={})", nextNodeId) - val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(nextNodeId, timeout_opt = Some(Left(30 seconds)))).onFailure(SupervisorStrategy.stop)) + private def waitForPeerReady(upstream: Upstream.Hot.Trampoline, walletNodeId: PublicKey, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { + context.log.info("trying to wake up next peer (nodeId={})", walletNodeId) + val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.wakeUpTimeout)))).onFailure(SupervisorStrategy.stop)) notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult)) Behaviors.receiveMessagePartial { rejectExtraHtlcPartialFunction orElse { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala index a12799d8d2..12ee84f463 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala @@ -14,7 +14,7 @@ import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.RouteBlindingDecryptedData import fr.acinq.eclair.wire.protocol.{BlindedRouteData, OfferTypes, RouteBlindingEncryptedDataCodecs} -import fr.acinq.eclair.{EncodedNodeId, Logs, NodeParams} +import fr.acinq.eclair.{EncodedNodeId, Logs, MilliSatoshiLong, NodeParams, ShortChannelId} import scodec.bits.ByteVector import scala.annotation.tailrec @@ -45,8 +45,8 @@ object BlindedPathsResolver { override val firstNodeId: PublicKey = introductionNodeId } /** A partially unwrapped blinded route that started at our node: it only contains the part of the route after our node. */ - case class PartialBlindedRoute(nextNodeId: PublicKey, nextBlinding: PublicKey, blindedNodes: Seq[BlindedNode]) extends ResolvedBlindedRoute { - override val firstNodeId: PublicKey = nextNodeId + case class PartialBlindedRoute(nextNodeId: EncodedNodeId.WithPublicKey, nextBlinding: PublicKey, blindedNodes: Seq[BlindedNode]) extends ResolvedBlindedRoute { + override val firstNodeId: PublicKey = nextNodeId.publicKey } // @formatter:on @@ -111,8 +111,14 @@ private class BlindedPathsResolver(nodeParams: NodeParams, feeProportionalMillionths = nextFeeProportionalMillionths, cltvExpiryDelta = nextCltvExpiryDelta ) - register ! Register.GetNextNodeId(context.messageAdapter(WrappedNodeId), paymentRelayData.outgoingChannelId) - waitForNextNodeId(nextPaymentInfo, paymentRelayData, nextBlinding, paymentRoute.route.subsequentNodes, toResolve.tail, resolved) + paymentRelayData.outgoing match { + case Left(outgoingNodeId) => + // The next node seems to be a wallet node directly connected to us. + validateRelay(EncodedNodeId.WithPublicKey.Wallet(outgoingNodeId), nextPaymentInfo, paymentRelayData, nextBlinding, paymentRoute.route.subsequentNodes, toResolve.tail, resolved) + case Right(outgoingChannelId) => + register ! Register.GetNextNodeId(context.messageAdapter(WrappedNodeId), outgoingChannelId) + waitForNextNodeId(outgoingChannelId, nextPaymentInfo, paymentRelayData, nextBlinding, paymentRoute.route.subsequentNodes, toResolve.tail, resolved) + } } } case encodedNodeId: EncodedNodeId.WithPublicKey => @@ -129,7 +135,8 @@ private class BlindedPathsResolver(nodeParams: NodeParams, } /** Resolve the next node in the blinded path when we are the introduction node. */ - private def waitForNextNodeId(nextPaymentInfo: OfferTypes.PaymentInfo, + private def waitForNextNodeId(outgoingChannelId: ShortChannelId, + nextPaymentInfo: OfferTypes.PaymentInfo, paymentRelayData: BlindedRouteData.PaymentRelayData, nextBlinding: PublicKey, nextBlindedNodes: Seq[RouteBlinding.BlindedNode], @@ -137,28 +144,41 @@ private class BlindedPathsResolver(nodeParams: NodeParams, resolved: Seq[ResolvedPath]): Behavior[Command] = Behaviors.receiveMessagePartial { case WrappedNodeId(None) => - context.log.warn("ignoring blinded path starting at our node: could not resolve outgoingChannelId={}", paymentRelayData.outgoingChannelId) + context.log.warn("ignoring blinded path starting at our node: could not resolve outgoingChannelId={}", outgoingChannelId) resolveBlindedPaths(toResolve, resolved) case WrappedNodeId(Some(nodeId)) if nodeId == nodeParams.nodeId => // The next node in the route is also our node: this is fishy, there is not reason to include us in the route twice. context.log.warn("ignoring blinded path starting at our node relaying to ourselves") resolveBlindedPaths(toResolve, resolved) case WrappedNodeId(Some(nodeId)) => - // Note that we default to private fees if we don't have a channel yet with that node. - // The announceChannel parameter is ignored if we already have a channel. - val relayFees = getRelayFees(nodeParams, nodeId, announceChannel = false) - val shouldRelay = paymentRelayData.paymentRelay.feeBase >= relayFees.feeBase && - paymentRelayData.paymentRelay.feeProportionalMillionths >= relayFees.feeProportionalMillionths && - paymentRelayData.paymentRelay.cltvExpiryDelta >= nodeParams.channelConf.expiryDelta - if (shouldRelay) { - context.log.debug("unwrapped blinded path starting at our node: next_node={}", nodeId) - val path = ResolvedPath(PartialBlindedRoute(nodeId, nextBlinding, nextBlindedNodes), nextPaymentInfo) - resolveBlindedPaths(toResolve, resolved :+ path) - } else { - context.log.warn("ignoring blinded path starting at our node: allocated fees are too low (base={}, proportional={}, expiryDelta={})", paymentRelayData.paymentRelay.feeBase, paymentRelayData.paymentRelay.feeProportionalMillionths, paymentRelayData.paymentRelay.cltvExpiryDelta) - resolveBlindedPaths(toResolve, resolved) - } + validateRelay(EncodedNodeId.WithPublicKey.Plain(nodeId), nextPaymentInfo, paymentRelayData, nextBlinding, nextBlindedNodes, toResolve, resolved) + } + + private def validateRelay(nextNodeId: EncodedNodeId.WithPublicKey, + nextPaymentInfo: OfferTypes.PaymentInfo, + paymentRelayData: BlindedRouteData.PaymentRelayData, + nextBlinding: PublicKey, + nextBlindedNodes: Seq[RouteBlinding.BlindedNode], + toResolve: Seq[PaymentBlindedRoute], + resolved: Seq[ResolvedPath]): Behavior[Command] = { + // Note that we default to private fees if we don't have a channel yet with that node. + // The announceChannel parameter is ignored if we already have a channel. + val relayFees = getRelayFees(nodeParams, nextNodeId.publicKey, announceChannel = false) + val shouldRelay = paymentRelayData.paymentRelay.feeBase >= relayFees.feeBase && + paymentRelayData.paymentRelay.feeProportionalMillionths >= relayFees.feeProportionalMillionths && + paymentRelayData.paymentRelay.cltvExpiryDelta >= nodeParams.channelConf.expiryDelta && + nextPaymentInfo.feeBase >= 0.msat && + nextPaymentInfo.feeProportionalMillionths >= 0 && + nextPaymentInfo.cltvExpiryDelta.toInt >= 0 + if (shouldRelay) { + context.log.debug("unwrapped blinded path starting at our node: next_node={}", nextNodeId.publicKey) + val path = ResolvedPath(PartialBlindedRoute(nextNodeId, nextBlinding, nextBlindedNodes), nextPaymentInfo) + resolveBlindedPaths(toResolve, resolved :+ path) + } else { + context.log.warn("ignoring blinded path starting at our node: allocated fees are too low (base={}, proportional={}, expiryDelta={})", paymentRelayData.paymentRelay.feeBase, paymentRelayData.paymentRelay.feeProportionalMillionths, paymentRelayData.paymentRelay.cltvExpiryDelta) + resolveBlindedPaths(toResolve, resolved) } + } /** Resolve the introduction node's [[EncodedNodeId.ShortChannelIdDir]] to the corresponding [[EncodedNodeId.WithPublicKey]]. */ private def waitForNodeId(paymentRoute: PaymentBlindedRoute, toResolve: Seq[PaymentBlindedRoute], resolved: Seq[ResolvedPath]): Behavior[Command] = diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala index 81605c5cc1..37e31c7ff8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala @@ -21,7 +21,7 @@ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.router.Router.ChannelHop import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo import fr.acinq.eclair.wire.protocol.{RouteBlindingEncryptedDataCodecs, RouteBlindingEncryptedDataTlv, TlvStream} -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, EncodedNodeId, Features, MilliSatoshi, MilliSatoshiLong, randomKey} import scodec.bits.ByteVector object BlindedRouteCreation { @@ -77,7 +77,7 @@ object BlindedRouteCreation { Total: 24 to 36 bytes */ val targetLength = 36 - val paddedPayloads = payloads.map(tlvs =>{ + val paddedPayloads = payloads.map(tlvs => { val payloadLength = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(tlvs).require.bytes.length tlvs.copy(records = tlvs.records + RouteBlindingEncryptedDataTlv.Padding(ByteVector.fill(targetLength - payloadLength)(0))) }) @@ -95,4 +95,19 @@ object BlindedRouteCreation { Sphinx.RouteBlinding.create(randomKey(), Seq(nodeId), Seq(finalPayload)) } + /** Create a blinded route where the recipient is a wallet node. */ + def createBlindedRouteToWallet(hop: Router.ChannelHop, pathId: ByteVector, minAmount: MilliSatoshi, routeFinalExpiry: CltvExpiry): Sphinx.RouteBlinding.BlindedRouteDetails = { + val routeExpiry = routeFinalExpiry + hop.cltvExpiryDelta + val finalPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( + RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, minAmount), + RouteBlindingEncryptedDataTlv.PathId(pathId), + )).require.bytes + val intermediatePayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream[RouteBlindingEncryptedDataTlv]( + RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Wallet(hop.nextNodeId)), + RouteBlindingEncryptedDataTlv.PaymentRelay(hop.cltvExpiryDelta, hop.params.relayFees.feeProportionalMillionths, hop.params.relayFees.feeBase), + RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, minAmount), + )).require.bytes + Sphinx.RouteBlinding.create(randomKey(), Seq(hop.nodeId, hop.nextNodeId), Seq(intermediatePayload, finalPayload)) + } + } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala index 10bd99f443..4468ed7172 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala @@ -23,7 +23,7 @@ import fr.acinq.eclair.wire.protocol.BlindedRouteData.PaymentRelayData import fr.acinq.eclair.wire.protocol.CommonCodecs._ import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.TlvCodecs._ -import fr.acinq.eclair.{CltvExpiry, Features, MilliSatoshi, MilliSatoshiLong, ShortChannelId, UInt64} +import fr.acinq.eclair.{CltvExpiry, Features, MilliSatoshi, ShortChannelId, UInt64} import scodec.bits.{BitVector, ByteVector} /** @@ -227,7 +227,8 @@ object PaymentOnion { object IntermediatePayload { sealed trait ChannelRelay extends IntermediatePayload { // @formatter:off - def outgoingChannelId: ShortChannelId + /** The outgoing channel, or the nodeId of one of our peers. */ + def outgoing: Either[PublicKey, ShortChannelId] def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry // @formatter:on @@ -238,7 +239,7 @@ object PaymentOnion { // @formatter:off val amountOut = records.get[AmountToForward].get.amount val cltvOut = records.get[OutgoingCltv].get.cltv - override val outgoingChannelId = records.get[OutgoingChannelId].get.shortChannelId + override val outgoing = Right(records.get[OutgoingChannelId].get.shortChannelId) override def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi = amountOut override def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry = cltvOut // @formatter:on @@ -258,12 +259,12 @@ object PaymentOnion { } /** - * @param blindedRecords decrypted tlv stream from the encrypted_recipient_data tlv. - * @param nextBlinding blinding point that must be forwarded to the next hop. + * @param paymentRelayData decrypted relaying data from the encrypted_recipient_data tlv. + * @param nextBlinding blinding point that must be forwarded to the next hop. */ case class Blinded(records: TlvStream[OnionPaymentPayloadTlv], paymentRelayData: PaymentRelayData, nextBlinding: PublicKey) extends ChannelRelay { // @formatter:off - override val outgoingChannelId = paymentRelayData.outgoingChannelId + override val outgoing = paymentRelayData.outgoing override def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi = paymentRelayData.amountToForward(incomingAmount) override def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry = paymentRelayData.outgoingCltv(incomingCltv) // @formatter:on diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala index 096c1f8edf..be53f4aab0 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala @@ -98,7 +98,11 @@ object BlindedRouteData { } case class PaymentRelayData(records: TlvStream[RouteBlindingEncryptedDataTlv]) { - val outgoingChannelId: ShortChannelId = records.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId].get.shortChannelId + // This is usually a channel, unless the next node is a mobile wallet connected to our node. + val outgoing: Either[PublicKey, ShortChannelId] = records.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId] match { + case Some(r) => Right(r.shortChannelId) + case None => Left(records.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId].get.nodeId.asInstanceOf[EncodedNodeId.WithPublicKey.Wallet].publicKey) + } val paymentRelay: PaymentRelay = records.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get val paymentConstraints: PaymentConstraints = records.get[RouteBlindingEncryptedDataTlv.PaymentConstraints].get val allowedFeatures: Features[Feature] = records.get[RouteBlindingEncryptedDataTlv.AllowedFeatures].map(_.features).getOrElse(Features.empty) @@ -110,7 +114,9 @@ object BlindedRouteData { } def validatePaymentRelayData(records: TlvStream[RouteBlindingEncryptedDataTlv]): Either[InvalidTlvPayload, PaymentRelayData] = { - if (records.get[OutgoingChannelId].isEmpty) return Left(MissingRequiredTlv(UInt64(2))) + // Note that the BOLTs require using an OutgoingChannelId, but we optionally support a wallet node_id. + if (records.get[OutgoingChannelId].isEmpty && records.get[OutgoingNodeId].isEmpty) return Left(MissingRequiredTlv(UInt64(2))) + if (records.get[OutgoingNodeId].nonEmpty && !records.get[OutgoingNodeId].get.nodeId.isInstanceOf[EncodedNodeId.WithPublicKey.Wallet]) return Left(ForbiddenTlv(UInt64(4))) if (records.get[PaymentRelay].isEmpty) return Left(MissingRequiredTlv(UInt64(10))) if (records.get[PaymentConstraints].isEmpty) return Left(MissingRequiredTlv(UInt64(12))) if (records.get[PathId].nonEmpty) return Left(ForbiddenTlv(UInt64(6))) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala index 0b0483b437..963a984609 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala @@ -231,7 +231,8 @@ object TestConstants { maxAttempts = 2, ), purgeInvoicesInterval = None, - revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis) + revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis), + wakeUpTimeout = 30 seconds, ) def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams( @@ -401,7 +402,8 @@ object TestConstants { maxAttempts = 2, ), purgeInvoicesInterval = None, - revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis) + revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis), + wakeUpTimeout = 30 seconds, ) def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams( diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala index e4f064ee3e..57a93308a1 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala @@ -505,7 +505,7 @@ class SphinxSpec extends AnyFunSuite { val Right(decryptedPayloadBob) = RouteBlindingEncryptedDataCodecs.decode(bob, blinding, tlvsBob.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) val blindingEphemeralKeyForCarol = decryptedPayloadBob.nextBlinding val Right(payloadBob) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsBob, decryptedPayloadBob.tlvs, blindingEphemeralKeyForCarol) - assert(payloadBob.outgoingChannelId == ShortChannelId(1)) + assert(payloadBob.outgoing.contains(ShortChannelId(1))) assert(payloadBob.amountToForward(110_125 msat) == 100_125.msat) assert(payloadBob.outgoingCltv(CltvExpiry(749150)) == CltvExpiry(749100)) assert(payloadBob.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(50), 0, 10_000 msat)) @@ -523,7 +523,7 @@ class SphinxSpec extends AnyFunSuite { val Right(decryptedPayloadCarol) = RouteBlindingEncryptedDataCodecs.decode(carol, blindingEphemeralKeyForCarol, tlvsCarol.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) val blindingEphemeralKeyForDave = decryptedPayloadCarol.nextBlinding val Right(payloadCarol) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsCarol, decryptedPayloadCarol.tlvs, blindingEphemeralKeyForDave) - assert(payloadCarol.outgoingChannelId == ShortChannelId(2)) + assert(payloadCarol.outgoing.contains(ShortChannelId(2))) assert(payloadCarol.amountToForward(100_125 msat) == 100_010.msat) assert(payloadCarol.outgoingCltv(CltvExpiry(749100)) == CltvExpiry(749025)) assert(payloadCarol.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(75), 150, 100 msat)) @@ -544,7 +544,7 @@ class SphinxSpec extends AnyFunSuite { val Right(decryptedPayloadDave) = RouteBlindingEncryptedDataCodecs.decode(dave, blindingOverride, tlvsDave.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) val blindingEphemeralKeyForEve = decryptedPayloadDave.nextBlinding val Right(payloadDave) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsDave, decryptedPayloadDave.tlvs, blindingEphemeralKeyForEve) - assert(payloadDave.outgoingChannelId == ShortChannelId(3)) + assert(payloadDave.outgoing.contains(ShortChannelId(3))) assert(payloadDave.amountToForward(100_010 msat) == 100_000.msat) assert(payloadDave.outgoingCltv(CltvExpiry(749025)) == CltvExpiry(749000)) assert(payloadDave.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(25), 100, 0 msat)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala index 50da4b3c03..63ee150f2f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala @@ -19,7 +19,8 @@ package fr.acinq.eclair.io import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe => TypedProbe} import akka.actor.typed.ActorRef import akka.actor.typed.eventstream.EventStream -import akka.actor.typed.scaladsl.adapter.TypedActorRefOps +import akka.actor.typed.receptionist.Receptionist +import akka.actor.typed.scaladsl.adapter.{ClassicActorRefOps, TypedActorRefOps} import akka.testkit.TestProbe import com.typesafe.config.ConfigFactory import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey @@ -33,8 +34,8 @@ import fr.acinq.eclair.message.OnionMessages.{IntermediateNode, Recipient} import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.{GenericTlv, OnionMessagePayloadTlv, TlvStream} import fr.acinq.eclair.{EncodedNodeId, RealShortChannelId, ShortChannelId, UInt64, randomBytes32, randomKey} -import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike +import org.scalatest.{Outcome, Tag} import scodec.bits.HexStringSyntax import scala.concurrent.duration.DurationInt @@ -43,19 +44,24 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val aliceId: PublicKey = Alice.nodeParams.nodeId val bobId: PublicKey = Bob.nodeParams.nodeId + val wakeUpTimeout = "wake_up_timeout" + case class FixtureParam(relay: ActorRef[Command], switchboard: TestProbe, register: TestProbe, router: TypedProbe[Router.GetNodeId], peerConnection: TypedProbe[Nothing], peer: TypedProbe[Peer.RelayOnionMessage], probe: TypedProbe[Status]) override def withFixture(test: OneArgTest): Outcome = { val switchboard = TestProbe("switchboard")(system.classicSystem) + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref.toTyped) val register = TestProbe("register")(system.classicSystem) val router = TypedProbe[Router.GetNodeId]("router") val peerConnection = TypedProbe[Nothing]("peerConnection") val peer = TypedProbe[Peer.RelayOnionMessage]("peer") val probe = TypedProbe[Status]("probe") - val relay = testKit.spawn(MessageRelay(Alice.nodeParams, switchboard.ref, register.ref, router.ref)) + val nodeParams = if (test.tags.contains(wakeUpTimeout)) Alice.nodeParams.copy(wakeUpTimeout = 100 millis) else Alice.nodeParams + val relay = testKit.spawn(MessageRelay(nodeParams, switchboard.ref, register.ref, router.ref)) try { withFixture(test.toNoArgTest(FixtureParam(relay, switchboard, register, router, peerConnection, peer, probe))) } finally { + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref.toTyped) testKit.stop(relay) } } @@ -86,6 +92,19 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app assert(peer.expectMessageType[Peer.RelayOnionMessage].msg == message) } + test("relay after waking up next node") { f => + import f._ + + val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) + val messageId = randomBytes32() + relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId.WithPublicKey.Wallet(bobId)), message, RelayChannelsOnly, None) + + val request = switchboard.expectMsgType[GetPeerInfo] + assert(request.remoteNodeId == bobId) + request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, bobId, Peer.CONNECTED, None, Set.empty) + assert(peer.expectMessageType[Peer.RelayOnionMessage].msg == message) + } + test("can't open new connection") { f => import f._ @@ -99,6 +118,15 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app probe.expectMessage(ConnectionFailure(messageId, PeerConnection.ConnectionResult.NoAddressFound)) } + test("can't wake up next node", Tag(wakeUpTimeout)) { f => + import f._ + + val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) + val messageId = randomBytes32() + relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId.WithPublicKey.Wallet(bobId)), message, RelayChannelsOnly, Some(probe.ref)) + probe.expectMessage(Disconnected(messageId)) + } + test("no channel with previous node") { f => import f._ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala index 1aab2b2c9e..6cca2f3c86 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala @@ -85,7 +85,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_b.amountToForward == amount_bc) assert(relay_b.outgoingCltv == expiry_bc) - assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId) + assert(payload_b.outgoing.contains(channelUpdate_bc.shortChannelId)) assert(relay_b.relayFeeMsat == fee_b) assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) @@ -95,7 +95,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_d.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_c.amountToForward == amount_cd) assert(relay_c.outgoingCltv == expiry_cd) - assert(payload_c.outgoingChannelId == channelUpdate_cd.shortChannelId) + assert(payload_c.outgoing.contains(channelUpdate_cd.shortChannelId)) assert(relay_c.relayFeeMsat == fee_c) assert(relay_c.expiryDelta == channelUpdate_cd.cltvExpiryDelta) @@ -105,7 +105,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_e.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_d.amountToForward == amount_de) assert(relay_d.outgoingCltv == expiry_de) - assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId)) assert(relay_d.relayFeeMsat == fee_d) assert(relay_d.expiryDelta == channelUpdate_de.cltvExpiryDelta) @@ -175,7 +175,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_b.amountToForward >= amount_bc) assert(relay_b.outgoingCltv == expiry_bc) - assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId) + assert(payload_b.outgoing.contains(channelUpdate_bc.shortChannelId)) assert(relay_b.relayFeeMsat == fee_b) assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) assert(payload_b.isInstanceOf[IntermediatePayload.ChannelRelay.Standard]) @@ -185,7 +185,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_d.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_c.amountToForward >= amount_cd) assert(relay_c.outgoingCltv == expiry_cd) - assert(payload_c.outgoingChannelId == channelUpdate_cd.shortChannelId) + assert(payload_c.outgoing.contains(channelUpdate_cd.shortChannelId)) assert(relay_c.relayFeeMsat == fee_c) assert(relay_c.expiryDelta == channelUpdate_cd.cltvExpiryDelta) assert(payload_c.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) @@ -196,7 +196,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_e.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_d.amountToForward >= amount_de) assert(relay_d.outgoingCltv == expiry_de) - assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId)) assert(relay_d.relayFeeMsat == fee_d) assert(relay_d.expiryDelta == channelUpdate_de.cltvExpiryDelta) assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) @@ -238,7 +238,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_b.amountToForward >= amount_bc) assert(relay_b.outgoingCltv == expiry_bc) - assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId) + assert(payload_b.outgoing.contains(channelUpdate_bc.shortChannelId)) assert(relay_b.relayFeeMsat == fee_b) assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) assert(payload_b.isInstanceOf[IntermediatePayload.ChannelRelay.Standard]) @@ -547,7 +547,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // A smaller amount is sent to d, who doesn't know that it's invalid. val add_d = UpdateAddHtlc(randomBytes32(), 0, amount_de, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt, 1.0) val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional)) - assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId)) assert(relay_d.amountToForward < amount_de) assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding @@ -569,7 +569,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val invalidExpiry = payment.cmd.cltvExpiry - Channel.MIN_CLTV_EXPIRY_DELTA - CltvExpiryDelta(1) val add_d = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, paymentHash, invalidExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt, 1.0) val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional)) - assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId)) assert(relay_d.outgoingCltv < CltvExpiry(currentBlockCount)) assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala index a25391031c..20a86e2f0e 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala @@ -19,6 +19,7 @@ package fr.acinq.eclair.payment.relay import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe} import akka.actor.typed import akka.actor.typed.eventstream.EventStream +import akka.actor.typed.receptionist.Receptionist import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import com.softwaremill.quicklens.ModifyPimp import com.typesafe.config.ConfigFactory @@ -29,6 +30,7 @@ import fr.acinq.eclair.TestConstants.emptyOnionPacket import fr.acinq.eclair.blockchain.fee.FeeratePerKw import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.io.{Peer, Switchboard} import fr.acinq.eclair.payment.IncomingPaymentPacket.ChannelRelayPacket import fr.acinq.eclair.payment.relay.ChannelRelayer._ import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket, PaymentPacketSpec} @@ -39,19 +41,23 @@ import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload.ChannelRel import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{CltvExpiry, NodeParams, RealShortChannelId, TestConstants, randomBytes32, _} import org.scalatest.Inside.inside -import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike +import org.scalatest.{Outcome, Tag} import scodec.bits.HexStringSyntax +import scala.concurrent.duration.DurationInt + class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike { import ChannelRelayerSpec._ + val wakeUpTimeout = "wake_up_timeout" + case class FixtureParam(nodeParams: NodeParams, channelRelayer: typed.ActorRef[ChannelRelayer.Command], register: TestProbe[Any]) override def withFixture(test: OneArgTest): Outcome = { // we are node B in the route A -> B -> C -> .... - val nodeParams = TestConstants.Bob.nodeParams + val nodeParams = if (test.tags.contains(wakeUpTimeout)) TestConstants.Bob.nodeParams.copy(wakeUpTimeout = 100 millis) else TestConstants.Bob.nodeParams val register = TestProbe[Any]("register") val channelRelayer = testKit.spawn(ChannelRelayer.apply(nodeParams, register.ref.toClassic)) try { @@ -157,7 +163,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a import f._ val u = createLocalUpdate(channelId1, feeBaseMsat = 2500 msat, feeProportionalMillionths = 0) - val payload = createBlindedPayload(u.channelUpdate, isIntroduction = false) + val payload = createBlindedPayload(Right(u.channelUpdate.shortChannelId), u.channelUpdate, isIntroduction = false) val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) channelRelayer ! WrappedLocalChannelUpdate(u) @@ -166,6 +172,30 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry, 7) } + test("relay blinded payment (wake up wallet node)") { f => + import f._ + + val switchboard = TestProbe[Switchboard.GetPeerInfo]() + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val u = createLocalUpdate(channelId1, feeBaseMsat = 2500 msat, feeProportionalMillionths = 0) + Seq(true, false).foreach(isIntroduction => { + val payload = createBlindedPayload(Left(outgoingNodeId), u.channelUpdate, isIntroduction) + val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) + + channelRelayer ! WrappedLocalChannelUpdate(u) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) + + // We try to wake-up the next node. + val wakeUp = switchboard.expectMessageType[Switchboard.GetPeerInfo] + assert(wakeUp.remoteNodeId == outgoingNodeId) + wakeUp.replyTo ! Peer.PeerInfo(TestProbe[Any]().ref.toClassic, outgoingNodeId, Peer.CONNECTED, None, Set.empty) + expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry, 7) + }) + + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) + } + test("relay with retries") { f => import f._ @@ -270,7 +300,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a Seq(true, false).foreach { isIntroduction => // The outgoing channel is disabled, so we won't be able to relay the payment. val u = createLocalUpdate(channelId1, feeBaseMsat = 5000 msat, feeProportionalMillionths = 0, enabled = false) - val r = createValidIncomingPacket(createBlindedPayload(u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) + val r = createValidIncomingPacket(createBlindedPayload(Right(u.channelUpdate.shortChannelId), u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) channelRelayer ! WrappedLocalChannelUpdate(u) channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) @@ -293,6 +323,27 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a } } + test("fail to relay blinded payment (cannot wake up remote node)", Tag(wakeUpTimeout)) { f => + import f._ + + val switchboard = TestProbe[Switchboard.GetPeerInfo]() + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val u = createLocalUpdate(channelId1, feeBaseMsat = 2500 msat, feeProportionalMillionths = 0) + val payload = createBlindedPayload(Left(outgoingNodeId), u.channelUpdate, isIntroduction = true) + val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) + + channelRelayer ! WrappedLocalChannelUpdate(u) + channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) + + // We try to wake-up the next node, but we timeout before they connect. + assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == outgoingNodeId) + val fail = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] + assert(fail.message.reason.contains(InvalidOnionBlinding(Sphinx.hash(r.add.onionRoutingPacket)))) + + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) + } + test("relay when expiry larger than our requirements") { f => import f._ @@ -519,7 +570,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a Seq(true, false).foreach { isIntroduction => testCases.foreach { htlcResult => - val r = createValidIncomingPacket(createBlindedPayload(u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta, endorsementIn = 0) + val r = createValidIncomingPacket(createBlindedPayload(Right(u.channelUpdate.shortChannelId), u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta, endorsementIn = 0) channelRelayer ! WrappedLocalChannelUpdate(u) channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) val fwd = expectFwdAdd(register, channelId1, outgoingAmount, outgoingExpiry, 0) @@ -653,13 +704,16 @@ object ChannelRelayerSpec { localAlias2 -> channelId2, ) - def createBlindedPayload(update: ChannelUpdate, isIntroduction: Boolean): ChannelRelay.Blinded = { + def createBlindedPayload(outgoing: Either[PublicKey, ShortChannelId], update: ChannelUpdate, isIntroduction: Boolean): ChannelRelay.Blinded = { val tlvs = TlvStream[OnionPaymentPayloadTlv](Set( Some(OnionPaymentPayloadTlv.EncryptedRecipientData(hex"2a")), if (isIntroduction) Some(OnionPaymentPayloadTlv.BlindingPoint(randomKey().publicKey)) else None, ).flatten[OnionPaymentPayloadTlv]) val blindedTlvs = TlvStream[RouteBlindingEncryptedDataTlv]( - RouteBlindingEncryptedDataTlv.OutgoingChannelId(update.shortChannelId), + outgoing match { + case Left(nodeId) => RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Wallet(nodeId)) + case Right(scid) => RouteBlindingEncryptedDataTlv.OutgoingChannelId(scid) + }, RouteBlindingEncryptedDataTlv.PaymentRelay(update.cltvExpiryDelta, update.feeProportionalMillionths, update.feeBaseMsat), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(500_000), 0 msat), ) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala index d8ce8d9cef..b20df1f7d1 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala @@ -20,36 +20,39 @@ import akka.actor.Status import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe} import akka.actor.typed.ActorRef import akka.actor.typed.eventstream.EventStream +import akka.actor.typed.receptionist.Receptionist import akka.actor.typed.scaladsl.ActorContext import akka.actor.typed.scaladsl.adapter._ import com.softwaremill.quicklens.ModifyPimp import com.typesafe.config.ConfigFactory import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} -import fr.acinq.bitcoin.scalacompat.{Block, BlockHash, ByteVector32, Crypto} +import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, Crypto} import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.Features.{AsyncPaymentPrototype, BasicMultiPartPayment, PaymentSecret, VariableLengthOnion} import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Register, Upstream} import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.io.{Peer, Switchboard} import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop import fr.acinq.eclair.payment.IncomingPaymentPacket.{RelayToBlindedPathsPacket, RelayToTrampolinePacket} import fr.acinq.eclair.payment.Invoice.ExtraEdge +import fr.acinq.eclair.payment.OutgoingPaymentPacket.NodePayload import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.relay.NodeRelayer.PaymentKey import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment} import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToNode import fr.acinq.eclair.payment.send.{BlindedRecipient, ClearRecipient} -import fr.acinq.eclair.router.Router.RouteRequest -import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound, Router} +import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams, RouteRequest} +import fr.acinq.eclair.router.{BalanceTooLow, BlindedRouteCreation, RouteNotFound, Router} import fr.acinq.eclair.wire.protocol.OfferTypes._ import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload} import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{AllowedFeatures, PathId, PaymentConstraints} import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{BlockHeight, Bolt11Feature, CltvExpiry, CltvExpiryDelta, FeatureSupport, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, RealShortChannelId, ShortChannelId, TestConstants, TimestampMilli, UInt64, randomBytes, randomBytes32, randomKey} -import org.scalatest.Outcome +import fr.acinq.eclair.{Alias, BlockHeight, Bolt11Feature, Bolt12Feature, CltvExpiry, CltvExpiryDelta, EncodedNodeId, FeatureSupport, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, RealShortChannelId, ShortChannelId, TestConstants, TimestampMilli, UInt64, randomBytes32, randomKey} import org.scalatest.funsuite.FixtureAnyFunSuiteLike +import org.scalatest.{Outcome, Tag} import scodec.bits.{ByteVector, HexStringSyntax} import java.util.UUID @@ -64,6 +67,8 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl import NodeRelayerSpec._ + val wakeUpTimeout = "wake_up_timeout" + case class FixtureParam(nodeParams: NodeParams, router: TestProbe[Any], register: TestProbe[Any], mockPayFSM: TestProbe[Any], eventListener: TestProbe[PaymentEvent]) { def createNodeRelay(packetIn: IncomingPaymentPacket.NodeRelayPacket, useRealPaymentFactory: Boolean = false): (ActorRef[NodeRelay.Command], TestProbe[NodeRelayer.Command]) = { val parent = TestProbe[NodeRelayer.Command]("parent-relayer") @@ -92,6 +97,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val nodeParams = TestConstants.Bob.nodeParams .modify(_.multiPartPaymentExpiry).setTo(5 seconds) .modify(_.relayParams.asyncPaymentsParams.holdTimeoutBlocks).setToIf(test.tags.contains("long_hold_timeout"))(200000) // timeout after payment expires + .modify(_.wakeUpTimeout).setToIf(test.tags.contains(wakeUpTimeout))(100 millis) val router = TestProbe[Any]("router") val register = TestProbe[Any]("register") val eventListener = TestProbe[PaymentEvent]("event-listener") @@ -225,7 +231,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl UpdateAddHtlc(randomBytes32(), Random.nextInt(100), 1000 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket, None, 1.0), FinalPayload.Standard.createPayload(1000 msat, incomingAmount, CltvExpiry(499990), incomingSecret, None), IntermediatePayload.NodeRelay.Standard(outgoingAmount, outgoingExpiry, outgoingNodeId), - nextTrampolinePacket) + createTrampolinePacket(outgoingAmount, outgoingExpiry)) nodeRelayer ! NodeRelay.Relay(extra, randomKey().publicKey) // the extra payment will be rejected @@ -254,7 +260,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl UpdateAddHtlc(randomBytes32(), Random.nextInt(100), 1000 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket, None, 1.0), FinalPayload.Standard.createPayload(1000 msat, incomingAmount, CltvExpiry(499990), incomingSecret, None), IntermediatePayload.NodeRelay.Standard(outgoingAmount, outgoingExpiry, outgoingNodeId), - nextTrampolinePacket) + createTrampolinePacket(outgoingAmount, outgoingExpiry)) nodeRelayer ! NodeRelay.Relay(i1, randomKey().publicKey) val fwd1 = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] @@ -267,7 +273,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl UpdateAddHtlc(randomBytes32(), Random.nextInt(100), 1500 msat, paymentHash, CltvExpiry(499990), TestConstants.emptyOnionPacket, None, 1.0), PaymentOnion.FinalPayload.Standard.createPayload(1500 msat, 1500 msat, CltvExpiry(499990), incomingSecret, None), IntermediatePayload.NodeRelay.Standard(1250 msat, outgoingExpiry, outgoingNodeId), - nextTrampolinePacket) + createTrampolinePacket(outgoingAmount, outgoingExpiry)) nodeRelayer ! NodeRelay.Relay(i2, randomKey().publicKey) val fwd2 = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] @@ -715,26 +721,15 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl } } - def createPaymentBlindedRoute(nodeId: PublicKey, sessionKey: PrivateKey = randomKey(), pathId: ByteVector = randomBytes32()): PaymentBlindedRoute = { - val selfPayload = blindedRouteDataCodec.encode(TlvStream(PathId(pathId), PaymentConstraints(CltvExpiry(1234567), 0 msat), AllowedFeatures(Features.empty))).require.bytes - PaymentBlindedRoute(Sphinx.RouteBlinding.create(sessionKey, Seq(nodeId), Seq(selfPayload)).route, PaymentInfo(1 msat, 2, CltvExpiryDelta(3), 4 msat, 5 msat, Features.empty)) - } - test("relay to blinded paths without multi-part") { f => import f._ - val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32())) - val offer = Offer(None, Some("test offer"), outgoingNodeId, Features.empty, chain) - val request = InvoiceRequest(offer, outgoingAmount, 1, Features.empty, payerKey, chain) - val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features.empty, Seq(createPaymentBlindedRoute(outgoingNodeId))) - val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths( - incoming.innerPayload.amountToForward, outgoingExpiry, invoice - ))) + val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features.empty, None) val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true) val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode] assert(outgoingPayment.amount == outgoingAmount) assert(outgoingPayment.recipient.expiry == outgoingExpiry) @@ -744,7 +739,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val nodeRelayerAdapters = outgoingPayment.replyTo nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage) - incomingMultiPart.foreach { p => + incomingPayments.foreach { p => val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]] assert(fwd.channelId == p.add.channelId) assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)) @@ -753,7 +748,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl nodeRelayerAdapters ! createSuccessEvent() val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed] validateRelayEvent(relayEvent) - assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingMultiPart.map(i => (i.add.amountMsat, i.add.channelId))) + assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId))) assert(relayEvent.outgoing.length == 1) parent.expectMessageType[NodeRelayer.RelayComplete] register.expectNoMessage(100 millis) @@ -762,18 +757,12 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl test("relay to blinded paths with multi-part") { f => import f._ - val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32())) - val offer = Offer(None, Some("test offer"), outgoingNodeId, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), chain) - val request = InvoiceRequest(offer, outgoingAmount, 1, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), payerKey, chain) - val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), Seq(createPaymentBlindedRoute(outgoingNodeId))) - val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths( - incoming.innerPayload.amountToForward, outgoingExpiry, invoice - ))) + val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), None) val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true) val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] assert(outgoingPayment.recipient.totalAmount == outgoingAmount) assert(outgoingPayment.recipient.expiry == outgoingExpiry) @@ -783,7 +772,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val nodeRelayerAdapters = outgoingPayment.replyTo nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage) - incomingMultiPart.foreach { p => + incomingPayments.foreach { p => val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]] assert(fwd.channelId == p.add.channelId) assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)) @@ -792,25 +781,81 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl nodeRelayerAdapters ! createSuccessEvent() val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed] validateRelayEvent(relayEvent) - assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingMultiPart.map(i => (i.add.amountMsat, i.add.channelId))) + assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId))) + assert(relayEvent.outgoing.length == 1) + parent.expectMessageType[NodeRelayer.RelayComplete] + register.expectNoMessage(100 millis) + } + + test("relay to blinded path with wake-up") { f => + import f._ + + val switchboard = TestProbe[Switchboard.GetPeerInfo]() + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val incomingPayments = createIncomingPaymentsToWalletBlindedPath(nodeParams) + val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) + incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) + + // The remote node is a wallet node: we try to wake them up before relaying the payment. + val wakeUp = switchboard.expectMessageType[Switchboard.GetPeerInfo] + assert(wakeUp.remoteNodeId == outgoingNodeId) + wakeUp.replyTo ! Peer.PeerInfo(TestProbe[Any]().ref.toClassic, outgoingNodeId, Peer.CONNECTED, None, Set.empty) + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true) + val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] + assert(outgoingPayment.recipient.totalAmount == outgoingAmount) + assert(outgoingPayment.recipient.expiry == outgoingExpiry) + assert(outgoingPayment.recipient.isInstanceOf[BlindedRecipient]) + + // those are adapters for pay-fsm messages + val nodeRelayerAdapters = outgoingPayment.replyTo + + nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage) + incomingPayments.foreach { p => + val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]] + assert(fwd.channelId == p.add.channelId) + assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)) + } + + nodeRelayerAdapters ! createSuccessEvent() + val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed] + validateRelayEvent(relayEvent) + assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId))) assert(relayEvent.outgoing.length == 1) parent.expectMessageType[NodeRelayer.RelayComplete] register.expectNoMessage(100 millis) } + test("fail to relay to blinded path when wake-up fails", Tag(wakeUpTimeout)) { f => + import f._ + + val switchboard = TestProbe[Switchboard.GetPeerInfo]() + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val incomingPayments = createIncomingPaymentsToWalletBlindedPath(nodeParams) + val (nodeRelayer, _) = f.createNodeRelay(incomingPayments.head) + incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) + + // The remote node is a wallet node: we try to wake them up before relaying the payment, but it times out. + assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == outgoingNodeId) + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) + mockPayFSM.expectNoMessage(100 millis) + + incomingPayments.foreach { p => + val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] + assert(fwd.channelId == p.add.channelId) + assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(UnknownNextPeer()), commit = true)) + } + } + test("relay to compact blinded paths") { f => import f._ - val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32())) - val offer = Offer(None, Some("test offer"), outgoingNodeId, Features.empty, chain) - val request = InvoiceRequest(offer, outgoingAmount, 1, Features.empty, payerKey, chain) - val paymentBlindedRoute = createPaymentBlindedRoute(outgoingNodeId) val scidDir = ShortChannelIdDir(isNode1 = true, RealShortChannelId(123456L)) - val compactPaymentBlindedRoute = paymentBlindedRoute.copy(route = paymentBlindedRoute.route.copy(introductionNodeId = scidDir)) - val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features.empty, Seq(compactPaymentBlindedRoute)) - val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths( - incoming.innerPayload.amountToForward, outgoingExpiry, invoice - ))) + val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features.empty, Some(scidDir)) val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) @@ -820,7 +865,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl getNodeId.replyTo ! Some(outgoingNodeId) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingMultiPart.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true) + validateOutgoingCfg(outgoingCfg, Upstream.Hot.Trampoline(incomingPayments.map(p => Upstream.Hot.Channel(p.add, TimestampMilli.now(), randomKey().publicKey))), 5, ignoreNodeId = true) val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode] assert(outgoingPayment.amount == outgoingAmount) assert(outgoingPayment.recipient.expiry == outgoingExpiry) @@ -830,7 +875,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val nodeRelayerAdapters = outgoingPayment.replyTo nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage) - incomingMultiPart.foreach { p => + incomingPayments.foreach { p => val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]] assert(fwd.channelId == p.add.channelId) assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)) @@ -839,7 +884,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl nodeRelayerAdapters ! createSuccessEvent() val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed] validateRelayEvent(relayEvent) - assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingMultiPart.map(i => (i.add.amountMsat, i.add.channelId))) + assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId))) assert(relayEvent.outgoing.length == 1) parent.expectMessageType[NodeRelayer.RelayComplete] register.expectNoMessage(100 millis) @@ -848,16 +893,8 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl test("fail to relay to compact blinded paths with unknown scid") { f => import f._ - val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32())) - val offer = Offer(None, Some("test offer"), outgoingNodeId, Features.empty, chain) - val request = InvoiceRequest(offer, outgoingAmount, 1, Features.empty, payerKey, chain) - val paymentBlindedRoute = createPaymentBlindedRoute(outgoingNodeId) val scidDir = ShortChannelIdDir(isNode1 = true, RealShortChannelId(123456L)) - val compactPaymentBlindedRoute = paymentBlindedRoute.copy(route = paymentBlindedRoute.route.copy(introductionNodeId = scidDir)) - val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features.empty, Seq(compactPaymentBlindedRoute)) - val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths( - incoming.innerPayload.amountToForward, outgoingExpiry, invoice - ))) + val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features.empty, Some(scidDir)) val (nodeRelayer, _) = f.createNodeRelay(incomingPayments.head) incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) @@ -868,7 +905,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl mockPayFSM.expectNoMessage(100 millis) - incomingMultiPart.foreach { p => + incomingPayments.foreach { p => val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] assert(fwd.channelId == p.add.channelId) assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(UnknownNextPeer()), commit = true)) @@ -896,7 +933,9 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl assert(outgoingPayment.recipient.isInstanceOf[ClearRecipient]) val recipient = outgoingPayment.recipient.asInstanceOf[ClearRecipient] assert(recipient.paymentSecret !== incomingSecret) // we should generate a new outgoing secret - assert(recipient.nextTrampolineOnion_opt.contains(nextTrampolinePacket)) + assert(recipient.nextTrampolineOnion_opt.nonEmpty) + // The recipient is able to decrypt the trampoline onion. + recipient.nextTrampolineOnion_opt.foreach(onion => assert(IncomingPaymentPacket.decryptOnion(paymentHash, outgoingNodeKey, onion).isRight)) } def validateRelayEvent(e: TrampolinePaymentRelayed): Unit = { @@ -913,10 +952,7 @@ object NodeRelayerSpec { val paymentPreimage = randomBytes32() val paymentHash = Crypto.sha256(paymentPreimage) - - // This is the result of decrypting the incoming trampoline onion packet. - // It should be forwarded to the next trampoline node. - val nextTrampolinePacket = OnionRoutingPacket(0, hex"02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619", randomBytes(400), randomBytes32()) + val paymentSecret = randomBytes32() val outgoingAmount = 40_000_000 msat val outgoingExpiry = CltvExpiry(490000) @@ -942,6 +978,12 @@ object NodeRelayerSpec { def createSuccessEvent(): PaymentSent = PaymentSent(relayId, paymentHash, paymentPreimage, outgoingAmount, outgoingNodeId, Seq(PaymentSent.PartialPayment(UUID.randomUUID(), outgoingAmount, 10 msat, randomBytes32(), None))) + def createTrampolinePacket(amount: MilliSatoshi, expiry: CltvExpiry): OnionRoutingPacket = { + val payload = NodePayload(outgoingNodeId, FinalPayload.Standard.createPayload(amount, amount, expiry, paymentSecret)) + val Right(onion) = OutgoingPaymentPacket.buildOnion(Seq(payload), paymentHash, None) + onion.packet + } + def createValidIncomingPacket(amountIn: MilliSatoshi, totalAmountIn: MilliSatoshi, expiryIn: CltvExpiry, amountOut: MilliSatoshi, expiryOut: CltvExpiry, endorsementIn: Int = 7): RelayToTrampolinePacket = { val outerPayload = FinalPayload.Standard.createPayload(amountIn, totalAmountIn, expiryIn, incomingSecret, None) val tlvs = TlvStream[UpdateAddHtlcTlv](UpdateAddHtlcTlv.Endorsement(endorsementIn)) @@ -949,7 +991,7 @@ object NodeRelayerSpec { UpdateAddHtlc(randomBytes32(), Random.nextInt(100), amountIn, paymentHash, expiryIn, TestConstants.emptyOnionPacket, tlvs), outerPayload, IntermediatePayload.NodeRelay.Standard(amountOut, expiryOut, outgoingNodeId), - nextTrampolinePacket) + createTrampolinePacket(amountOut, expiryOut)) } def createPartialIncomingPacket(paymentHash: ByteVector32, paymentSecret: ByteVector32): RelayToTrampolinePacket = { @@ -959,7 +1001,46 @@ object NodeRelayerSpec { UpdateAddHtlc(randomBytes32(), Random.nextInt(100), amountIn, paymentHash, expiryIn, TestConstants.emptyOnionPacket, None, 1.0), FinalPayload.Standard.createPayload(amountIn, incomingAmount, expiryIn, paymentSecret, None), IntermediatePayload.NodeRelay.Standard(outgoingAmount, expiryOut, outgoingNodeId), - nextTrampolinePacket) + createTrampolinePacket(outgoingAmount, expiryOut)) + } + + def createPaymentBlindedRoute(nodeId: PublicKey, sessionKey: PrivateKey = randomKey(), pathId: ByteVector = randomBytes32()): PaymentBlindedRoute = { + val selfPayload = blindedRouteDataCodec.encode(TlvStream(PathId(pathId), PaymentConstraints(CltvExpiry(1234567), 0 msat), AllowedFeatures(Features.empty))).require.bytes + PaymentBlindedRoute(Sphinx.RouteBlinding.create(sessionKey, Seq(nodeId), Seq(selfPayload)).route, PaymentInfo(1 msat, 2, CltvExpiryDelta(3), 4 msat, 5 msat, Features.empty)) + } + + /** Create payments to a blinded path that starts at a remote node. */ + def createIncomingPaymentsToRemoteBlindedPath(features: Features[Bolt12Feature], scidDir_opt: Option[EncodedNodeId.ShortChannelIdDir]): Seq[RelayToBlindedPathsPacket] = { + val offer = Offer(None, Some("test offer"), outgoingNodeId, features, Block.RegtestGenesisBlock.hash) + val request = InvoiceRequest(offer, outgoingAmount, 1, features, randomKey(), Block.RegtestGenesisBlock.hash) + val paymentBlindedRoute = scidDir_opt match { + case Some(scidDir) => + val nonCompact = createPaymentBlindedRoute(outgoingNodeId) + nonCompact.copy(route = nonCompact.route.copy(introductionNodeId = scidDir)) + case None => + createPaymentBlindedRoute(outgoingNodeId) + } + val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, features, Seq(paymentBlindedRoute)) + incomingMultiPart.map(incoming => { + val innerPayload = IntermediatePayload.NodeRelay.ToBlindedPaths(incoming.innerPayload.amountToForward, outgoingExpiry, invoice) + RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, innerPayload) + }) + } + + /** Create payments to a blinded path that starts at our node and relays to a wallet node. */ + def createIncomingPaymentsToWalletBlindedPath(nodeParams: NodeParams): Seq[RelayToBlindedPathsPacket] = { + val features: Features[Bolt12Feature] = Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional) + val offer = Offer(None, Some("test offer"), outgoingNodeId, features, Block.RegtestGenesisBlock.hash) + val request = InvoiceRequest(offer, outgoingAmount, 1, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), randomKey(), Block.RegtestGenesisBlock.hash) + val edge = ExtraEdge(nodeParams.nodeId, outgoingNodeId, Alias(561), 2_000_000 msat, 250, CltvExpiryDelta(144), 1 msat, None) + val hop = ChannelHop(edge.shortChannelId, nodeParams.nodeId, outgoingNodeId, HopRelayParams.FromHint(edge)) + val route = BlindedRouteCreation.createBlindedRouteToWallet(hop, hex"deadbeef", 1 msat, outgoingExpiry).route + val paymentInfo = BlindedRouteCreation.aggregatePaymentInfo(outgoingAmount, Seq(hop), CltvExpiryDelta(12)) + val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, features, Seq(PaymentBlindedRoute(route, paymentInfo))) + incomingMultiPart.map(incoming => { + val innerPayload = IntermediatePayload.NodeRelay.ToBlindedPaths(incoming.innerPayload.amountToForward, outgoingExpiry, invoice) + RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, innerPayload) + }) } } \ No newline at end of file diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/BlindedPathsResolverSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/BlindedPathsResolverSpec.scala index 038f15aa88..8a9ae7d641 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/BlindedPathsResolverSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/BlindedPathsResolverSpec.scala @@ -31,7 +31,7 @@ import fr.acinq.eclair.payment.send.BlindedPathsResolver.{FullBlindedRoute, Part import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams} import fr.acinq.eclair.router.{BlindedRouteCreation, Router} import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo -import fr.acinq.eclair.{BlockHeight, CltvExpiry, CltvExpiryDelta, EncodedNodeId, Features, MilliSatoshiLong, NodeParams, RealShortChannelId, TestConstants, randomBytes32, randomKey} +import fr.acinq.eclair.{Alias, BlockHeight, CltvExpiry, CltvExpiryDelta, EncodedNodeId, Features, MilliSatoshiLong, NodeParams, RealShortChannelId, TestConstants, randomBytes32, randomKey} import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike import scodec.bits.HexStringSyntax @@ -151,6 +151,31 @@ class BlindedPathsResolverSpec extends ScalaTestWithActorTestKit(ConfigFactory.l } } + test("resolve route starting at our node (wallet node)") { f => + import f._ + + val probe = TestProbe() + val walletNodeId = randomKey().publicKey + val edge = ExtraEdge(nodeParams.nodeId, walletNodeId, Alias(561), 5_000_000 msat, 200, CltvExpiryDelta(144), 1 msat, None) + val hop = ChannelHop(edge.shortChannelId, nodeParams.nodeId, walletNodeId, HopRelayParams.FromHint(edge)) + val route = BlindedRouteCreation.createBlindedRouteToWallet(hop, hex"deadbeef", 1 msat, CltvExpiry(800_000)).route + val paymentInfo = BlindedRouteCreation.aggregatePaymentInfo(100_000_000 msat, Seq(hop), CltvExpiryDelta(12)) + val resolver = testKit.spawn(BlindedPathsResolver(nodeParams, randomBytes32(), router.ref, register.ref)) + resolver ! Resolve(probe.ref, Seq(PaymentBlindedRoute(route, paymentInfo))) + // We are the introduction node: we decrypt the payload and discover that the next node is a wallet node. + val resolved = probe.expectMsgType[Seq[ResolvedPath]] + assert(resolved.size == 1) + assert(resolved.head.route.isInstanceOf[PartialBlindedRoute]) + val partialRoute = resolved.head.route.asInstanceOf[PartialBlindedRoute] + assert(partialRoute.firstNodeId == walletNodeId) + assert(partialRoute.nextNodeId == EncodedNodeId.WithPublicKey.Wallet(walletNodeId)) + assert(partialRoute.blindedNodes == route.subsequentNodes) + assert(partialRoute.nextBlinding != route.blindingKey) + // We don't need to resolve the nodeId. + register.expectNoMessage(100 millis) + router.expectNoMessage(100 millis) + } + test("ignore blinded paths that cannot be resolved") { f => import f._ @@ -181,8 +206,9 @@ class BlindedPathsResolverSpec extends ScalaTestWithActorTestKit(ConfigFactory.l val probe = TestProbe() val scid = RealShortChannelId(BlockHeight(750_000), 3, 7) - val edgeLowFees = ExtraEdge(nodeParams.nodeId, randomKey().publicKey, scid, 100 msat, 5, CltvExpiryDelta(144), 1 msat, None) - val edgeLowExpiryDelta = ExtraEdge(nodeParams.nodeId, randomKey().publicKey, scid, 600_000 msat, 100, CltvExpiryDelta(36), 1 msat, None) + val nextNodeId = randomKey().publicKey + val edgeLowFees = ExtraEdge(nodeParams.nodeId, nextNodeId, scid, 100 msat, 5, CltvExpiryDelta(144), 1 msat, None) + val edgeLowExpiryDelta = ExtraEdge(nodeParams.nodeId, nextNodeId, scid, 600_000 msat, 100, CltvExpiryDelta(36), 1 msat, None) val toResolve = Seq( // We don't allow paying blinded routes to ourselves. BlindedRouteCreation.createBlindedRouteWithoutHops(nodeParams.nodeId, hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, @@ -190,6 +216,8 @@ class BlindedPathsResolverSpec extends ScalaTestWithActorTestKit(ConfigFactory.l BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, // We reject blinded routes with low cltv_expiry_delta. BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowExpiryDelta.targetNodeId, HopRelayParams.FromHint(edgeLowExpiryDelta))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, + // We reject blinded routes with low fees, even when the next node seems to be a wallet node. + BlindedRouteCreation.createBlindedRouteToWallet(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees)), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, // We reject blinded routes that cannot be decrypted. BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route.copy(blindingKey = randomKey().publicKey) ).map(r => PaymentBlindedRoute(r, PaymentInfo(1_000_000 msat, 2500, CltvExpiryDelta(300), 1 msat, 500_000_000 msat, Features.empty))) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala index 13e22eca5e..c0b691b588 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala @@ -89,7 +89,7 @@ class PaymentOnionSpec extends AnyFunSuite { val Right(payload) = IntermediatePayload.ChannelRelay.Standard.validate(decoded) assert(payload.amountOut == 561.msat) assert(payload.cltvOut == CltvExpiry(42)) - assert(payload.outgoingChannelId == ShortChannelId(1105)) + assert(payload.outgoing.contains(ShortChannelId(1105))) val encoded = perHopPayloadCodec.encode(expected).require.bytes assert(encoded == bin) } @@ -110,7 +110,7 @@ class PaymentOnionSpec extends AnyFunSuite { val decoded = perHopPayloadCodec.decode(bin.bits).require.value assert(decoded == expected) val Right(payload) = IntermediatePayload.ChannelRelay.Blinded.validate(decoded, blindedTlvs, randomKey().publicKey) - assert(payload.outgoingChannelId == ShortChannelId(42)) + assert(payload.outgoing.contains(ShortChannelId(42))) assert(payload.amountToForward(10_000 msat) == 9990.msat) assert(payload.outgoingCltv(CltvExpiry(1000)) == CltvExpiry(856)) assert(payload.paymentRelayData.allowedFeatures.isEmpty) @@ -119,6 +119,20 @@ class PaymentOnionSpec extends AnyFunSuite { } } + test("encode/decode channel relay blinded per-hop-payload (with wallet node_id)") { + val walletNodeId = PublicKey(hex"0221cd519eba9c8b840a5e40b65dc2c040e159a766979723ed770efceb97260ec8") + val blindedTlvs = TlvStream[RouteBlindingEncryptedDataTlv]( + RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Wallet(walletNodeId)), + RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(144), 100, 10 msat), + RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat), + ) + val Right(payload) = IntermediatePayload.ChannelRelay.Blinded.validate(TlvStream(EncryptedRecipientData(hex"deadbeef")), blindedTlvs, randomKey().publicKey) + assert(payload.outgoing == Left(walletNodeId)) + assert(payload.amountToForward(10_000 msat) == 9990.msat) + assert(payload.outgoingCltv(CltvExpiry(1000)) == CltvExpiry(856)) + assert(payload.paymentRelayData.allowedFeatures.isEmpty) + } + test("encode/decode node relay per-hop payload") { val nodeId = PublicKey(hex"02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619") val expected = TlvStream[OnionPaymentPayloadTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), OutgoingNodeId(nodeId)) @@ -292,6 +306,8 @@ class PaymentOnionSpec extends AnyFunSuite { TestCase(MissingRequiredTlv(UInt64(10)), hex"23 0c21036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2", validBlindedTlvs), // Missing encrypted outgoing channel. TestCase(MissingRequiredTlv(UInt64(2)), hex"0a 0a080123456789abcdef", TlvStream(RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(144), 100, 10 msat), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat))), + // Forbidden encrypted outgoing plain node_id. + TestCase(ForbiddenTlv(UInt64(4)), hex"0a 0a080123456789abcdef", TlvStream(RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Plain(randomKey().publicKey)), RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(144), 100, 10 msat), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat))), // Missing encrypted payment relay data. TestCase(MissingRequiredTlv(UInt64(10)), hex"0a 0a080123456789abcdef", TlvStream(RouteBlindingEncryptedDataTlv.OutgoingChannelId(ShortChannelId(42)), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat))), // Missing encrypted payment constraint. From a395f9fa0270cb1f44155ad35100264686d14d65 Mon Sep 17 00:00:00 2001 From: t-bast Date: Mon, 26 Aug 2024 11:42:15 +0200 Subject: [PATCH 04/13] Create `peer-wake-up` config section Create a dedicated `peer-wake-up` configuration section. This can be enriched with mobile notification sub-sections. --- eclair-core/src/main/resources/reference.conf | 6 ++++++ .../src/main/scala/fr/acinq/eclair/NodeParams.scala | 8 +++++--- .../src/main/scala/fr/acinq/eclair/io/MessageRelay.scala | 4 +--- .../main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala | 2 ++ .../fr/acinq/eclair/payment/relay/ChannelRelay.scala | 2 +- .../scala/fr/acinq/eclair/payment/relay/NodeRelay.scala | 2 +- .../src/test/scala/fr/acinq/eclair/TestConstants.scala | 6 +++--- .../test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala | 3 ++- .../acinq/eclair/payment/relay/ChannelRelayerSpec.scala | 3 ++- .../fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala | 2 +- 10 files changed, 24 insertions(+), 14 deletions(-) diff --git a/eclair-core/src/main/resources/reference.conf b/eclair-core/src/main/resources/reference.conf index fd58844a1d..79d0438760 100644 --- a/eclair-core/src/main/resources/reference.conf +++ b/eclair-core/src/main/resources/reference.conf @@ -318,6 +318,12 @@ eclair { max-no-channels = 250 // maximum number of incoming connections from peers that do not have any channels with us } + // When relaying payments or messages to mobile peers who are disconnected, we may try to wake them up using a mobile + // notification system, or we attempt connecting to the last known address. + peer-wake-up { + timeout = 60 seconds + } + auto-reconnect = true initial-random-reconnect-delay = 5 seconds // we add a random delay before the first reconnection attempt, capped by this value max-reconnect-interval = 1 hour // max interval between two reconnection attempts, after the exponential backoff period diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala index 97c7c6ca5d..9c7858f1f1 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala @@ -28,7 +28,7 @@ import fr.acinq.eclair.crypto.Noise.KeyPair import fr.acinq.eclair.crypto.keymanager.{ChannelKeyManager, NodeKeyManager, OnChainKeyManager} import fr.acinq.eclair.db._ import fr.acinq.eclair.io.MessageRelay.{RelayAll, RelayChannelsOnly, RelayPolicy} -import fr.acinq.eclair.io.PeerConnection +import fr.acinq.eclair.io.{PeerConnection, PeerReadyNotifier} import fr.acinq.eclair.message.OnionMessages.OnionMessageConfig import fr.acinq.eclair.payment.relay.Relayer.{AsyncPaymentsParams, RelayFees, RelayParams} import fr.acinq.eclair.router.Announcements.AddressException @@ -88,7 +88,7 @@ case class NodeParams(nodeKeyManager: NodeKeyManager, onionMessageConfig: OnionMessageConfig, purgeInvoicesInterval: Option[FiniteDuration], revokedHtlcInfoCleanerConfig: RevokedHtlcInfoCleaner.Config, - wakeUpTimeout: FiniteDuration) { + peerWakeUpConfig: PeerReadyNotifier.WakeUpConfig) { val privateKey: Crypto.PrivateKey = nodeKeyManager.nodeKey.privateKey val nodeId: PublicKey = nodeKeyManager.nodeId @@ -613,7 +613,9 @@ object NodeParams extends Logging { batchSize = config.getInt("db.revoked-htlc-info-cleaner.batch-size"), interval = FiniteDuration(config.getDuration("db.revoked-htlc-info-cleaner.interval").getSeconds, TimeUnit.SECONDS) ), - wakeUpTimeout = 30 seconds, + peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig( + timeout = FiniteDuration(config.getDuration("peer-wake-up.timeout").getSeconds, TimeUnit.SECONDS) + ), ) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala index 7571362736..5c607083c0 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala @@ -34,8 +34,6 @@ import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.OnionMessage import fr.acinq.eclair.{EncodedNodeId, Logs, NodeParams, ShortChannelId} -import scala.concurrent.duration.DurationInt - object MessageRelay { // @formatter:off sealed trait Command @@ -148,7 +146,7 @@ private class MessageRelay(nodeParams: NodeParams, waitForConnection(msg, nodeId) } case EncodedNodeId.WithPublicKey.Wallet(nodeId) => - val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(nodeId, timeout_opt = Some(Left(nodeParams.wakeUpTimeout)))).onFailure(SupervisorStrategy.stop)) + val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(nodeId, timeout_opt = Some(Left(nodeParams.peerWakeUpConfig.timeout)))).onFailure(SupervisorStrategy.stop)) notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult)) waitForWalletNodeUp(msg, nodeId) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala index fbcd09c645..ac394f2d6f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala @@ -33,6 +33,8 @@ import scala.concurrent.duration.{DurationInt, FiniteDuration} */ object PeerReadyNotifier { + case class WakeUpConfig(timeout: FiniteDuration) + // @formatter:off sealed trait Command case class NotifyWhenPeerReady(replyTo: ActorRef[Result]) extends Command diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala index f025da43c4..7b8c0a212a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala @@ -149,7 +149,7 @@ class ChannelRelay private(nodeParams: NodeParams, private def wakeUp(walletNodeId: PublicKey): Behavior[Command] = { context.log.info("trying to wake up channel peer (nodeId={})", walletNodeId) - val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.wakeUpTimeout)))).onFailure(SupervisorStrategy.stop)) + val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.peerWakeUpConfig.timeout)))).onFailure(SupervisorStrategy.stop)) notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult)) Behaviors.receiveMessagePartial { case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) => diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index 00bd63e004..34727b1171 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -298,7 +298,7 @@ class NodeRelay private(nodeParams: NodeParams, */ private def waitForPeerReady(upstream: Upstream.Hot.Trampoline, walletNodeId: PublicKey, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { context.log.info("trying to wake up next peer (nodeId={})", walletNodeId) - val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.wakeUpTimeout)))).onFailure(SupervisorStrategy.stop)) + val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.peerWakeUpConfig.timeout)))).onFailure(SupervisorStrategy.stop)) notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult)) Behaviors.receiveMessagePartial { rejectExtraHtlcPartialFunction orElse { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala index 963a984609..72005dfd75 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala @@ -26,7 +26,7 @@ import fr.acinq.eclair.channel.{ChannelFlags, LocalParams, Origin, Upstream} import fr.acinq.eclair.crypto.keymanager.{LocalChannelKeyManager, LocalNodeKeyManager} import fr.acinq.eclair.db.RevokedHtlcInfoCleaner import fr.acinq.eclair.io.MessageRelay.RelayAll -import fr.acinq.eclair.io.{OpenChannelInterceptor, PeerConnection} +import fr.acinq.eclair.io.{OpenChannelInterceptor, PeerConnection, PeerReadyNotifier} import fr.acinq.eclair.message.OnionMessages.OnionMessageConfig import fr.acinq.eclair.payment.relay.Relayer.{AsyncPaymentsParams, RelayFees, RelayParams} import fr.acinq.eclair.router.Graph.{MessagePath, WeightRatios} @@ -232,7 +232,7 @@ object TestConstants { ), purgeInvoicesInterval = None, revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis), - wakeUpTimeout = 30 seconds, + peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(30 seconds), ) def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams( @@ -403,7 +403,7 @@ object TestConstants { ), purgeInvoicesInterval = None, revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis), - wakeUpTimeout = 30 seconds, + peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(30 seconds), ) def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams( diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala index 63ee150f2f..c296602bc1 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala @@ -28,6 +28,7 @@ import fr.acinq.eclair.TestConstants.{Alice, Bob} import fr.acinq.eclair.channel.Register import fr.acinq.eclair.io.MessageRelay._ import fr.acinq.eclair.io.Peer.{PeerInfo, PeerNotFound} +import fr.acinq.eclair.io.PeerReadyNotifier.WakeUpConfig import fr.acinq.eclair.io.Switchboard.GetPeerInfo import fr.acinq.eclair.message.OnionMessages import fr.acinq.eclair.message.OnionMessages.{IntermediateNode, Recipient} @@ -56,7 +57,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val peerConnection = TypedProbe[Nothing]("peerConnection") val peer = TypedProbe[Peer.RelayOnionMessage]("peer") val probe = TypedProbe[Status]("probe") - val nodeParams = if (test.tags.contains(wakeUpTimeout)) Alice.nodeParams.copy(wakeUpTimeout = 100 millis) else Alice.nodeParams + val nodeParams = if (test.tags.contains(wakeUpTimeout)) Alice.nodeParams.copy(peerWakeUpConfig = WakeUpConfig(100 millis)) else Alice.nodeParams val relay = testKit.spawn(MessageRelay(nodeParams, switchboard.ref, register.ref, router.ref)) try { withFixture(test.toNoArgTest(FixtureParam(relay, switchboard, register, router, peerConnection, peer, probe))) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala index 20a86e2f0e..c705d21621 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala @@ -30,6 +30,7 @@ import fr.acinq.eclair.TestConstants.emptyOnionPacket import fr.acinq.eclair.blockchain.fee.FeeratePerKw import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.io.PeerReadyNotifier.WakeUpConfig import fr.acinq.eclair.io.{Peer, Switchboard} import fr.acinq.eclair.payment.IncomingPaymentPacket.ChannelRelayPacket import fr.acinq.eclair.payment.relay.ChannelRelayer._ @@ -57,7 +58,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a override def withFixture(test: OneArgTest): Outcome = { // we are node B in the route A -> B -> C -> .... - val nodeParams = if (test.tags.contains(wakeUpTimeout)) TestConstants.Bob.nodeParams.copy(wakeUpTimeout = 100 millis) else TestConstants.Bob.nodeParams + val nodeParams = if (test.tags.contains(wakeUpTimeout)) TestConstants.Bob.nodeParams.copy(peerWakeUpConfig = WakeUpConfig(100 millis)) else TestConstants.Bob.nodeParams val register = TestProbe[Any]("register") val channelRelayer = testKit.spawn(ChannelRelayer.apply(nodeParams, register.ref.toClassic)) try { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala index b20df1f7d1..71e120befa 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala @@ -97,7 +97,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val nodeParams = TestConstants.Bob.nodeParams .modify(_.multiPartPaymentExpiry).setTo(5 seconds) .modify(_.relayParams.asyncPaymentsParams.holdTimeoutBlocks).setToIf(test.tags.contains("long_hold_timeout"))(200000) // timeout after payment expires - .modify(_.wakeUpTimeout).setToIf(test.tags.contains(wakeUpTimeout))(100 millis) + .modify(_.peerWakeUpConfig.timeout).setToIf(test.tags.contains(wakeUpTimeout))(100 millis) val router = TestProbe[Any]("router") val register = TestProbe[Any]("register") val eventListener = TestProbe[PaymentEvent]("event-listener") From 3a3d39e4cbe6bfc89b6f2a7379fa5dadbe95d27d Mon Sep 17 00:00:00 2001 From: t-bast Date: Mon, 26 Aug 2024 12:25:08 +0200 Subject: [PATCH 05/13] Make `walletNodeId_opt` a field in `ChannelRelay` It turns out that we can keep `requestedShortChannelId_opt` and `walletNodeId_opt` as fields of this actor instead of resolving them once and forwarding the value in every function. --- .../eclair/payment/relay/ChannelRelay.scala | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala index 7b8c0a212a..318e064d58 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala @@ -33,7 +33,7 @@ import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket} import fr.acinq.eclair.wire.protocol.FailureMessageCodecs.createBadOnionFailure import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{Logs, NodeParams, ShortChannelId, TimestampMilli, TimestampSecond, channel, nodeFee} +import fr.acinq.eclair.{Logs, NodeParams, TimestampMilli, TimestampSecond, channel, nodeFee} import java.util.UUID import java.util.concurrent.TimeUnit @@ -78,7 +78,7 @@ object ChannelRelay { * This helper method translates relaying errors (returned by the downstream outgoing channel) to BOLT 4 standard * errors that we should return upstream. */ - private def translateLocalError(error: Throwable, channelUpdate_opt: Option[ChannelUpdate]): FailureMessage = { + private def translateLocalError(error: ChannelException, channelUpdate_opt: Option[ChannelUpdate]): FailureMessage = { (error, channelUpdate_opt) match { case (_: ExpiryTooSmall, Some(channelUpdate)) => ExpiryTooSoon(Some(channelUpdate)) case (_: ExpiryTooBig, _) => ExpiryTooFar() @@ -136,14 +136,19 @@ class ChannelRelay private(nodeParams: NodeParams, } } + private val (requestedShortChannelId_opt, walletNodeId_opt) = r.payload.outgoing match { + case Left(walletNodeId) => (None, Some(walletNodeId)) + case Right(shortChannelId) => (Some(shortChannelId), None) + } + private case class PreviouslyTried(channelId: ByteVector32, failure: RES_ADD_FAILED[ChannelException]) def start(): Behavior[Command] = { - r.payload.outgoing match { - case Left(walletNodeId) => wakeUp(walletNodeId) - case Right(requestedShortChannelId) => + walletNodeId_opt match { + case Some(walletNodeId) => wakeUp(walletNodeId) + case None => context.self ! DoRelay - relay(Some(requestedShortChannelId), Seq.empty) + relay(Seq.empty) } } @@ -158,11 +163,11 @@ class ChannelRelay private(nodeParams: NodeParams, safeSendAndStop(r.add.channelId, CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true)) case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerReady) => context.self ! DoRelay - relay(None, Seq.empty) + relay(Seq.empty) } } - def relay(requestedShortChannelId_opt: Option[ShortChannelId], previousFailures: Seq[PreviouslyTried]): Behavior[Command] = { + def relay(previousFailures: Seq[PreviouslyTried]): Behavior[Command] = { Behaviors.receiveMessagePartial { case DoRelay => if (previousFailures.isEmpty) { @@ -170,7 +175,7 @@ class ChannelRelay private(nodeParams: NodeParams, context.log.info("relaying htlc #{} from channelId={} to requestedShortChannelId={} nextNode={}", r.add.id, r.add.channelId, requestedShortChannelId_opt, nextNodeId_opt.getOrElse("")) } context.log.debug("attempting relay previousAttempts={}", previousFailures.size) - handleRelay(requestedShortChannelId_opt, previousFailures) match { + handleRelay(previousFailures) match { case RelayFailure(cmdFail) => Metrics.recordPaymentRelayFailed(Tags.FailureType(cmdFail), Tags.RelayType.Channel) context.log.info("rejecting htlc reason={}", cmdFail.reason) @@ -178,12 +183,12 @@ class ChannelRelay private(nodeParams: NodeParams, case RelaySuccess(selectedChannelId, cmdAdd) => context.log.info("forwarding htlc #{} from channelId={} to channelId={}", r.add.id, r.add.channelId, selectedChannelId) register ! Register.Forward(forwardFailureAdapter, selectedChannelId, cmdAdd) - waitForAddResponse(selectedChannelId, requestedShortChannelId_opt, previousFailures) + waitForAddResponse(selectedChannelId, previousFailures) } } } - private def waitForAddResponse(selectedChannelId: ByteVector32, requestedShortChannelId_opt: Option[ShortChannelId], previousFailures: Seq[PreviouslyTried]): Behavior[Command] = + private def waitForAddResponse(selectedChannelId: ByteVector32, previousFailures: Seq[PreviouslyTried]): Behavior[Command] = Behaviors.receiveMessagePartial { case WrappedForwardFailure(Register.ForwardFailure(Register.Forward(_, channelId, _))) => context.log.warn(s"couldn't resolve downstream channel $channelId, failing htlc #${upstream.add.id}") @@ -194,7 +199,7 @@ class ChannelRelay private(nodeParams: NodeParams, case WrappedAddResponse(addFailed: RES_ADD_FAILED[_]) => context.log.info("attempt failed with reason={}", addFailed.t.getClass.getSimpleName) context.self ! DoRelay - relay(requestedShortChannelId_opt, previousFailures :+ PreviouslyTried(selectedChannelId, addFailed)) + relay(previousFailures :+ PreviouslyTried(selectedChannelId, addFailed)) case WrappedAddResponse(_: RES_SUCCESS[_]) => context.log.debug("sent htlc to the downstream channel") @@ -251,9 +256,9 @@ class ChannelRelay private(nodeParams: NodeParams, * - a CMD_FAIL_HTLC to be sent back upstream * - a CMD_ADD_HTLC to propagate downstream */ - private def handleRelay(requestedShortChannelId_opt: Option[ShortChannelId], previousFailures: Seq[PreviouslyTried]): RelayResult = { + private def handleRelay(previousFailures: Seq[PreviouslyTried]): RelayResult = { val alreadyTried = previousFailures.map(_.channelId) - selectPreferredChannel(requestedShortChannelId_opt, alreadyTried) match { + selectPreferredChannel(alreadyTried) match { case Some(outgoingChannel) => relayOrFail(outgoingChannel) case None => // No more channels to try. @@ -278,7 +283,7 @@ class ChannelRelay private(nodeParams: NodeParams, * * If no suitable channel is found we default to the originally requested channel. */ - private def selectPreferredChannel(requestedShortChannelId_opt: Option[ShortChannelId], alreadyTried: Seq[ByteVector32]): Option[OutgoingChannel] = { + private def selectPreferredChannel(alreadyTried: Seq[ByteVector32]): Option[OutgoingChannel] = { context.log.debug("selecting next channel with requestedShortChannelId={}", requestedShortChannelId_opt) // we filter out channels that we have already tried val candidateChannels: Map[ByteVector32, OutgoingChannel] = channels -- alreadyTried From fbc4577d681eaee5400bb94084a67cfc1ca992a8 Mon Sep 17 00:00:00 2001 From: t-bast Date: Mon, 26 Aug 2024 12:32:56 +0200 Subject: [PATCH 06/13] Include `recipient` in `NodeRelay.sending` This field is used in feature branches, so we include it to minimize the diff. --- .../fr/acinq/eclair/payment/relay/NodeRelay.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index 34727b1171..aa16f3c8eb 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -332,7 +332,7 @@ class NodeRelay private(nodeParams: NodeParams, } val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, useMultiPart) payFSM ! payment - sending(upstream, payloadOut, TimestampMilli.now(), fulfilledUpstream = false) + sending(upstream, payloadOut, recipient, TimestampMilli.now(), fulfilledUpstream = false) } /** @@ -342,7 +342,11 @@ class NodeRelay private(nodeParams: NodeParams, * @param nextPayload relay instructions. * @param fulfilledUpstream true if we already fulfilled the payment upstream. */ - private def sending(upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay, startedAt: TimestampMilli, fulfilledUpstream: Boolean): Behavior[Command] = + private def sending(upstream: Upstream.Hot.Trampoline, + nextPayload: IntermediatePayload.NodeRelay, + recipient: Recipient, + startedAt: TimestampMilli, + fulfilledUpstream: Boolean): Behavior[Command] = Behaviors.receiveMessagePartial { rejectExtraHtlcPartialFunction orElse { // this is the fulfill that arrives from downstream channels @@ -351,7 +355,7 @@ class NodeRelay private(nodeParams: NodeParams, // We want to fulfill upstream as soon as we receive the preimage (even if not all HTLCs have fulfilled downstream). context.log.debug("got preimage from downstream") fulfillPayment(upstream, paymentPreimage) - sending(upstream, nextPayload, startedAt, fulfilledUpstream = true) + sending(upstream, nextPayload, recipient, startedAt, fulfilledUpstream = true) } else { // we don't want to fulfill multiple times Behaviors.same From 893aac2c9914d7cabf544ca576baaa337e0d2d9a Mon Sep 17 00:00:00 2001 From: t-bast Date: Mon, 26 Aug 2024 14:26:09 +0200 Subject: [PATCH 07/13] Introduce `PeerReadyNotifier` private class This commit doesn't contain any changes apart from using a private class to factor common fields of the `PeerReadyNotifier`. --- .../acinq/eclair/io/PeerReadyNotifier.scala | 113 ++++++++++-------- 1 file changed, 62 insertions(+), 51 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala index ac394f2d6f..c6defa3533 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala @@ -72,33 +72,75 @@ object PeerReadyNotifier { // polling the switchboard. This makes more sense for long timeouts such as the ones used for async payments. context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerConnected](e => SomePeerConnected(e.nodeId))) context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerDisconnected](e => SomePeerDisconnected(e.nodeId))) - findSwitchboard(replyTo, remoteNodeId, context, timers) + new PeerReadyNotifier(replyTo, remoteNodeId, context, timers).findSwitchboard() } } } } } - private def findSwitchboard(replyTo: ActorRef[Result], remoteNodeId: PublicKey, context: ActorContext[Command], timers: TimerScheduler[Command]): Behavior[Command] = { + // We use an exhaustive pattern matching here to ensure we explicitly handle future new channel states. + // We only want to test that channels are not in an uninitialized state, we don't need them to be available to relay + // payments (channels closing or waiting to confirm are "ready" for our purposes). + private def isChannelReady(state: channel.ChannelState): Boolean = state match { + case channel.WAIT_FOR_INIT_INTERNAL => false + case channel.WAIT_FOR_INIT_SINGLE_FUNDED_CHANNEL => false + case channel.WAIT_FOR_INIT_DUAL_FUNDED_CHANNEL => false + case channel.OFFLINE => false + case channel.SYNCING => false + case channel.WAIT_FOR_OPEN_CHANNEL => true + case channel.WAIT_FOR_ACCEPT_CHANNEL => true + case channel.WAIT_FOR_FUNDING_INTERNAL => true + case channel.WAIT_FOR_FUNDING_CREATED => true + case channel.WAIT_FOR_FUNDING_SIGNED => true + case channel.WAIT_FOR_FUNDING_CONFIRMED => true + case channel.WAIT_FOR_CHANNEL_READY => true + case channel.WAIT_FOR_OPEN_DUAL_FUNDED_CHANNEL => true + case channel.WAIT_FOR_ACCEPT_DUAL_FUNDED_CHANNEL => true + case channel.WAIT_FOR_DUAL_FUNDING_CREATED => true + case channel.WAIT_FOR_DUAL_FUNDING_SIGNED => true + case channel.WAIT_FOR_DUAL_FUNDING_CONFIRMED => true + case channel.WAIT_FOR_DUAL_FUNDING_READY => true + case channel.NORMAL => true + case channel.SHUTDOWN => true + case channel.NEGOTIATING => true + case channel.CLOSING => true + case channel.CLOSED => true + case channel.WAIT_FOR_REMOTE_PUBLISH_FUTURE_COMMITMENT => true + case channel.ERR_INFORMATION_LEAK => true + } + +} + +private class PeerReadyNotifier(replyTo: ActorRef[PeerReadyNotifier.Result], + remoteNodeId: PublicKey, + context: ActorContext[PeerReadyNotifier.Command], + timers: TimerScheduler[PeerReadyNotifier.Command]) { + + import PeerReadyNotifier._ + + private val log = context.log + + private def findSwitchboard(): Behavior[Command] = { context.system.receptionist ! Receptionist.Find(Switchboard.SwitchboardServiceKey, context.messageAdapter[Receptionist.Listing](WrappedListing)) Behaviors.receiveMessagePartial { case WrappedListing(Switchboard.SwitchboardServiceKey.Listing(listings)) => listings.headOption match { case Some(switchboard) => - waitForPeerConnected(replyTo, remoteNodeId, switchboard, context, timers) + waitForPeerConnected(switchboard) case None => - context.log.error("no switchboard found") + log.error("no switchboard found") replyTo ! PeerUnavailable(remoteNodeId) Behaviors.stopped } case Timeout => - context.log.info("timed out finding switchboard actor") + log.info("timed out finding switchboard actor") replyTo ! PeerUnavailable(remoteNodeId) Behaviors.stopped } } - private def waitForPeerConnected(replyTo: ActorRef[Result], remoteNodeId: PublicKey, switchboard: ActorRef[Switchboard.GetPeerInfo], context: ActorContext[Command], timers: TimerScheduler[Command]): Behavior[Command] = { + private def waitForPeerConnected(switchboard: ActorRef[Switchboard.GetPeerInfo]): Behavior[Command] = { val peerInfoAdapter = context.messageAdapter[Peer.PeerInfoResponse] { // We receive this when we don't have any channel to the given peer and are not currently connected to them. // In that case we still want to wait for a connection, because we may want to open a channel to them. @@ -110,7 +152,7 @@ object PeerReadyNotifier { switchboard ! Switchboard.GetPeerInfo(peerInfoAdapter, remoteNodeId) Behaviors.receiveMessagePartial { case PeerNotConnected => - context.log.debug("peer is not connected yet") + log.debug("peer is not connected yet") Behaviors.same case SomePeerConnected(nodeId) => if (nodeId == remoteNodeId) { @@ -121,28 +163,28 @@ object PeerReadyNotifier { Behaviors.same case WrappedPeerInfo(peer, channelCount) => if (channelCount == 0) { - context.log.info("peer is ready with no channels") + log.info("peer is ready with no channels") replyTo ! PeerReady(remoteNodeId, peer.toClassic, Seq.empty) Behaviors.stopped } else { - context.log.debug("peer is connected with {} channels", channelCount) - waitForChannelsReady(replyTo, remoteNodeId, peer, switchboard, context, timers) + log.debug("peer is connected with {} channels", channelCount) + waitForChannelsReady(peer, switchboard) } case NewBlockNotTimedOut(currentBlockHeight) => - context.log.debug("waiting for peer to connect at block {}", currentBlockHeight) + log.debug("waiting for peer to connect at block {}", currentBlockHeight) Behaviors.same case Timeout => - context.log.info("timed out waiting for peer to connect") + log.info("timed out waiting for peer to connect") replyTo ! PeerUnavailable(remoteNodeId) Behaviors.stopped } } - private def waitForChannelsReady(replyTo: ActorRef[Result], remoteNodeId: PublicKey, peer: ActorRef[Peer.GetPeerChannels], switchboard: ActorRef[Switchboard.GetPeerInfo], context: ActorContext[Command], timers: TimerScheduler[Command]): Behavior[Command] = { + private def waitForChannelsReady(peer: ActorRef[Peer.GetPeerChannels], switchboard: ActorRef[Switchboard.GetPeerInfo]): Behavior[Command] = { timers.startTimerWithFixedDelay(ChannelsReadyTimerKey, CheckChannelsReady, initialDelay = 50 millis, delay = 1 second) Behaviors.receiveMessagePartial { case CheckChannelsReady => - context.log.debug("checking channel states") + log.debug("checking channel states") peer ! Peer.GetPeerChannels(context.messageAdapter[Peer.PeerChannels](WrappedPeerChannels)) Behaviors.same case WrappedPeerChannels(peerChannels) => @@ -150,58 +192,27 @@ object PeerReadyNotifier { replyTo ! PeerReady(remoteNodeId, peer.toClassic, peerChannels.channels) Behaviors.stopped } else { - context.log.debug("peer has {} channels that are not ready", peerChannels.channels.count(s => !isChannelReady(s.state))) + log.debug("peer has {} channels that are not ready", peerChannels.channels.count(s => !isChannelReady(s.state))) Behaviors.same } case NewBlockNotTimedOut(currentBlockHeight) => - context.log.debug("waiting for channels to be ready at block {}", currentBlockHeight) + log.debug("waiting for channels to be ready at block {}", currentBlockHeight) Behaviors.same case SomePeerConnected(_) => Behaviors.same case SomePeerDisconnected(nodeId) => if (nodeId == remoteNodeId) { - context.log.debug("peer disconnected, waiting for them to reconnect") + log.debug("peer disconnected, waiting for them to reconnect") timers.cancel(ChannelsReadyTimerKey) - waitForPeerConnected(replyTo, remoteNodeId, switchboard, context, timers) + waitForPeerConnected(switchboard) } else { Behaviors.same } case Timeout => - context.log.info("timed out waiting for channels to be ready") + log.info("timed out waiting for channels to be ready") replyTo ! PeerUnavailable(remoteNodeId) Behaviors.stopped } } - // We use an exhaustive pattern matching here to ensure we explicitly handle future new channel states. - // We only want to test that channels are not in an uninitialized state, we don't need them to be available to relay - // payments (channels closing or waiting to confirm are "ready" for our purposes). - private def isChannelReady(state: channel.ChannelState): Boolean = state match { - case channel.WAIT_FOR_INIT_INTERNAL => false - case channel.WAIT_FOR_INIT_SINGLE_FUNDED_CHANNEL => false - case channel.WAIT_FOR_INIT_DUAL_FUNDED_CHANNEL => false - case channel.OFFLINE => false - case channel.SYNCING => false - case channel.WAIT_FOR_OPEN_CHANNEL => true - case channel.WAIT_FOR_ACCEPT_CHANNEL => true - case channel.WAIT_FOR_FUNDING_INTERNAL => true - case channel.WAIT_FOR_FUNDING_CREATED => true - case channel.WAIT_FOR_FUNDING_SIGNED => true - case channel.WAIT_FOR_FUNDING_CONFIRMED => true - case channel.WAIT_FOR_CHANNEL_READY => true - case channel.WAIT_FOR_OPEN_DUAL_FUNDED_CHANNEL => true - case channel.WAIT_FOR_ACCEPT_DUAL_FUNDED_CHANNEL => true - case channel.WAIT_FOR_DUAL_FUNDING_CREATED => true - case channel.WAIT_FOR_DUAL_FUNDING_SIGNED => true - case channel.WAIT_FOR_DUAL_FUNDING_CONFIRMED => true - case channel.WAIT_FOR_DUAL_FUNDING_READY => true - case channel.NORMAL => true - case channel.SHUTDOWN => true - case channel.NEGOTIATING => true - case channel.CLOSING => true - case channel.CLOSED => true - case channel.WAIT_FOR_REMOTE_PUBLISH_FUTURE_COMMITMENT => true - case channel.ERR_INFORMATION_LEAK => true - } - -} +} \ No newline at end of file From aad3226716dd99a472384ebfc0f37fee83768769 Mon Sep 17 00:00:00 2001 From: t-bast Date: Mon, 26 Aug 2024 14:42:35 +0200 Subject: [PATCH 08/13] Refactor `PeerReadyNotifier` - clean-up `SomePeerConnected` / `SomePeerDisconnected` - include supervision directly inside the actor itself --- .../fr/acinq/eclair/io/MessageRelay.scala | 4 +- .../acinq/eclair/io/PeerReadyNotifier.scala | 70 ++++++++++--------- .../payment/relay/AsyncPaymentTriggerer.scala | 4 +- .../eclair/payment/relay/ChannelRelay.scala | 4 +- .../eclair/payment/relay/NodeRelay.scala | 4 +- 5 files changed, 45 insertions(+), 41 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala index 5c607083c0..368d883c14 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala @@ -16,10 +16,10 @@ package fr.acinq.eclair.io +import akka.actor.typed.Behavior import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import akka.actor.typed.scaladsl.{ActorContext, Behaviors} -import akka.actor.typed.{Behavior, SupervisorStrategy} import akka.actor.{ActorRef, typed} import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey @@ -146,7 +146,7 @@ private class MessageRelay(nodeParams: NodeParams, waitForConnection(msg, nodeId) } case EncodedNodeId.WithPublicKey.Wallet(nodeId) => - val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(nodeId, timeout_opt = Some(Left(nodeParams.peerWakeUpConfig.timeout)))).onFailure(SupervisorStrategy.stop)) + val notifier = context.spawnAnonymous(PeerReadyNotifier(nodeId, timeout_opt = Some(Left(nodeParams.peerWakeUpConfig.timeout)))) notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult)) waitForWalletNodeUp(msg, nodeId) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala index c6defa3533..70ab7c46c6 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala @@ -20,7 +20,7 @@ import akka.actor.typed.eventstream.EventStream import akka.actor.typed.receptionist.Receptionist import akka.actor.typed.scaladsl.adapter.{ClassicActorRefOps, TypedActorRefOps} import akka.actor.typed.scaladsl.{ActorContext, Behaviors, TimerScheduler} -import akka.actor.typed.{ActorRef, Behavior} +import akka.actor.typed.{ActorRef, Behavior, SupervisorStrategy} import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.blockchain.CurrentBlockHeight import fr.acinq.eclair.{BlockHeight, Logs, channel} @@ -40,15 +40,16 @@ object PeerReadyNotifier { case class NotifyWhenPeerReady(replyTo: ActorRef[Result]) extends Command private final case class WrappedListing(wrapped: Receptionist.Listing) extends Command private case object PeerNotConnected extends Command - private case class SomePeerConnected(nodeId: PublicKey) extends Command - private case class SomePeerDisconnected(nodeId: PublicKey) extends Command + private case object PeerConnected extends Command + private case object PeerDisconnected extends Command private case class WrappedPeerInfo(peer: ActorRef[Peer.GetPeerChannels], channelCount: Int) extends Command private case class NewBlockNotTimedOut(currentBlockHeight: BlockHeight) extends Command private case object CheckChannelsReady extends Command private case class WrappedPeerChannels(wrapped: Peer.PeerChannels) extends Command private case object Timeout extends Command + private case object ToBeIgnored extends Command - sealed trait Result + sealed trait Result { def remoteNodeId: PublicKey } case class PeerReady(remoteNodeId: PublicKey, peer: akka.actor.ActorRef, channelInfos: Seq[Peer.ChannelInfo]) extends Result { val channelsCount: Int = channelInfos.size } case class PeerUnavailable(remoteNodeId: PublicKey) extends Result @@ -59,21 +60,24 @@ object PeerReadyNotifier { Behaviors.setup { context => Behaviors.withTimers { timers => Behaviors.withMdc(Logs.mdc(remoteNodeId_opt = Some(remoteNodeId))) { - Behaviors.receiveMessagePartial { - case NotifyWhenPeerReady(replyTo) => - timeout_opt.foreach { - case Left(d) => timers.startSingleTimer(Timeout, d) - case Right(h) => context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[CurrentBlockHeight] { - case cbc if h <= cbc.blockHeight => Timeout - case cbc => NewBlockNotTimedOut(cbc.blockHeight) - }) - } - // In case the peer is not currently connected, we will wait for them to connect instead of regularly - // polling the switchboard. This makes more sense for long timeouts such as the ones used for async payments. - context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerConnected](e => SomePeerConnected(e.nodeId))) - context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerDisconnected](e => SomePeerDisconnected(e.nodeId))) - new PeerReadyNotifier(replyTo, remoteNodeId, context, timers).findSwitchboard() - } + Behaviors.receiveMessagePartial { + case NotifyWhenPeerReady(replyTo) => + timeout_opt.foreach { + case Left(d) => timers.startSingleTimer(Timeout, d) + case Right(h) => context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[CurrentBlockHeight] { + case cbc if h <= cbc.blockHeight => Timeout + case cbc => NewBlockNotTimedOut(cbc.blockHeight) + }) + } + // In case the peer is not currently connected, we will wait for them to connect instead of regularly + // polling the switchboard. This makes more sense for long timeouts such as the ones used for async payments. + context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerConnected](e => if (e.nodeId == remoteNodeId) PeerConnected else ToBeIgnored)) + context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerDisconnected](e => if (e.nodeId == remoteNodeId) PeerDisconnected else ToBeIgnored)) + // The actor should never throw, but for extra safety we wrap it with a supervisor. + Behaviors.supervise { + new PeerReadyNotifier(replyTo, remoteNodeId, context, timers).findSwitchboard() + }.onFailure(SupervisorStrategy.stop) + } } } } @@ -137,6 +141,8 @@ private class PeerReadyNotifier(replyTo: ActorRef[PeerReadyNotifier.Result], log.info("timed out finding switchboard actor") replyTo ! PeerUnavailable(remoteNodeId) Behaviors.stopped + case ToBeIgnored => + Behaviors.same } } @@ -154,12 +160,10 @@ private class PeerReadyNotifier(replyTo: ActorRef[PeerReadyNotifier.Result], case PeerNotConnected => log.debug("peer is not connected yet") Behaviors.same - case SomePeerConnected(nodeId) => - if (nodeId == remoteNodeId) { - switchboard ! Switchboard.GetPeerInfo(peerInfoAdapter, remoteNodeId) - } + case PeerConnected => + switchboard ! Switchboard.GetPeerInfo(peerInfoAdapter, remoteNodeId) Behaviors.same - case SomePeerDisconnected(_) => + case PeerDisconnected => Behaviors.same case WrappedPeerInfo(peer, channelCount) => if (channelCount == 0) { @@ -177,6 +181,8 @@ private class PeerReadyNotifier(replyTo: ActorRef[PeerReadyNotifier.Result], log.info("timed out waiting for peer to connect") replyTo ! PeerUnavailable(remoteNodeId) Behaviors.stopped + case ToBeIgnored => + Behaviors.same } } @@ -198,20 +204,18 @@ private class PeerReadyNotifier(replyTo: ActorRef[PeerReadyNotifier.Result], case NewBlockNotTimedOut(currentBlockHeight) => log.debug("waiting for channels to be ready at block {}", currentBlockHeight) Behaviors.same - case SomePeerConnected(_) => + case PeerConnected => Behaviors.same - case SomePeerDisconnected(nodeId) => - if (nodeId == remoteNodeId) { - log.debug("peer disconnected, waiting for them to reconnect") - timers.cancel(ChannelsReadyTimerKey) - waitForPeerConnected(switchboard) - } else { - Behaviors.same - } + case PeerDisconnected => + log.debug("peer disconnected, waiting for them to reconnect") + timers.cancel(ChannelsReadyTimerKey) + waitForPeerConnected(switchboard) case Timeout => log.info("timed out waiting for channels to be ready") replyTo ! PeerUnavailable(remoteNodeId) Behaviors.stopped + case ToBeIgnored => + Behaviors.same } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggerer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggerer.scala index 5adccac806..965e4b5203 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggerer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggerer.scala @@ -19,7 +19,7 @@ package fr.acinq.eclair.payment.relay import akka.actor.typed.ActorRef.ActorRefOps import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.{ActorContext, Behaviors} -import akka.actor.typed.{ActorRef, Behavior, SupervisorStrategy} +import akka.actor.typed.{ActorRef, Behavior} import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.Logs.LogCategory @@ -99,7 +99,7 @@ private class AsyncPaymentTriggerer(context: ActorContext[Command]) { case Watch(replyTo, remoteNodeId, paymentHash, timeout) => peers.get(remoteNodeId) match { case None => - val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(remoteNodeId, timeout_opt = None)).onFailure(SupervisorStrategy.stop)) + val notifier = context.spawnAnonymous(PeerReadyNotifier(remoteNodeId, timeout_opt = None)) context.watchWith(notifier, NotifierStopped(remoteNodeId)) notifier ! NotifyWhenPeerReady(context.messageAdapter[PeerReadyNotifier.Result](WrappedPeerReadyResult)) val peer = PeerPayments(notifier, Set(Payment(replyTo, timeout, paymentHash))) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala index 318e064d58..bd1a6d68fb 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala @@ -17,10 +17,10 @@ package fr.acinq.eclair.payment.relay import akka.actor.ActorRef +import akka.actor.typed.Behavior import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import akka.actor.typed.scaladsl.{ActorContext, Behaviors} -import akka.actor.typed.{Behavior, SupervisorStrategy} import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.channel._ @@ -154,7 +154,7 @@ class ChannelRelay private(nodeParams: NodeParams, private def wakeUp(walletNodeId: PublicKey): Behavior[Command] = { context.log.info("trying to wake up channel peer (nodeId={})", walletNodeId) - val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.peerWakeUpConfig.timeout)))).onFailure(SupervisorStrategy.stop)) + val notifier = context.spawnAnonymous(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.peerWakeUpConfig.timeout)))) notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult)) Behaviors.receiveMessagePartial { case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) => diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index aa16f3c8eb..5bf729e636 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -16,10 +16,10 @@ package fr.acinq.eclair.payment.relay +import akka.actor.typed.Behavior import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.adapter.{TypedActorContextOps, TypedActorRefOps} import akka.actor.typed.scaladsl.{ActorContext, Behaviors} -import akka.actor.typed.{Behavior, SupervisorStrategy} import akka.actor.{ActorRef, typed} import com.softwaremill.quicklens.ModifyPimp import fr.acinq.bitcoin.scalacompat.ByteVector32 @@ -298,7 +298,7 @@ class NodeRelay private(nodeParams: NodeParams, */ private def waitForPeerReady(upstream: Upstream.Hot.Trampoline, walletNodeId: PublicKey, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { context.log.info("trying to wake up next peer (nodeId={})", walletNodeId) - val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.peerWakeUpConfig.timeout)))).onFailure(SupervisorStrategy.stop)) + val notifier = context.spawnAnonymous(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.peerWakeUpConfig.timeout)))) notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult)) Behaviors.receiveMessagePartial { rejectExtraHtlcPartialFunction orElse { From 561320426432971d32e5f19ad9b3609946e50135 Mon Sep 17 00:00:00 2001 From: t-bast Date: Mon, 26 Aug 2024 15:50:16 +0200 Subject: [PATCH 09/13] Track pending `PeerReadyNotifier` instances It can be useful to track pending `PeerReadyNotifier` instances to avoid performing duplicate actions when multiple `PeerReadyNotifier` spawn for the same peer (e.g. sending a mobile notification). When a `PeerReadyNotifier` actor is started, it registers itself into a singleton `PeerReadyManager`, which tells it whether there are other pending attempts for the same peer. --- .../main/scala/fr/acinq/eclair/Setup.scala | 1 + .../acinq/eclair/io/PeerReadyNotifier.scala | 142 ++++++++++++++---- .../eclair/io/PeerReadyManagerSpec.scala | 55 +++++++ .../eclair/io/PeerReadyNotifierSpec.scala | 16 +- 4 files changed, 186 insertions(+), 28 deletions(-) create mode 100644 eclair-core/src/test/scala/fr/acinq/eclair/io/PeerReadyManagerSpec.scala diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala index 63587b7b33..2b28bb0415 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala @@ -360,6 +360,7 @@ class Setup(val datadir: File, offerManager = system.spawn(Behaviors.supervise(OfferManager(nodeParams, router, paymentTimeout = 1 minute)).onFailure(typed.SupervisorStrategy.resume), name = "offer-manager") paymentHandler = system.actorOf(SimpleSupervisor.props(PaymentHandler.props(nodeParams, register, offerManager), "payment-handler", SupervisorStrategy.Resume)) triggerer = system.spawn(Behaviors.supervise(AsyncPaymentTriggerer()).onFailure(typed.SupervisorStrategy.resume), name = "async-payment-triggerer") + peerReadyManager = system.spawn(Behaviors.supervise(PeerReadyManager()).onFailure(typed.SupervisorStrategy.restart), name = "peer-ready-manager") relayer = system.actorOf(SimpleSupervisor.props(Relayer.props(nodeParams, router, register, paymentHandler, Some(postRestartCleanUpInitialized)), "relayer", SupervisorStrategy.Resume)) _ = relayer ! PostRestartHtlcCleaner.Init(channels) // Before initializing the switchboard (which re-connects us to the network) and the user-facing parts of the system, diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala index 70ab7c46c6..684cbe9bcc 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala @@ -17,7 +17,7 @@ package fr.acinq.eclair.io import akka.actor.typed.eventstream.EventStream -import akka.actor.typed.receptionist.Receptionist +import akka.actor.typed.receptionist.{Receptionist, ServiceKey} import akka.actor.typed.scaladsl.adapter.{ClassicActorRefOps, TypedActorRefOps} import akka.actor.typed.scaladsl.{ActorContext, Behaviors, TimerScheduler} import akka.actor.typed.{ActorRef, Behavior, SupervisorStrategy} @@ -27,9 +27,73 @@ import fr.acinq.eclair.{BlockHeight, Logs, channel} import scala.concurrent.duration.{DurationInt, FiniteDuration} +/** + * This actor tracks the set of pending [[PeerReadyNotifier]]. + * It can be used to ensure that notifications are only sent once, even if there are multiple parallel operations + * waiting for that peer to come online. + */ +object PeerReadyManager { + + val PeerReadyManagerServiceKey: ServiceKey[Register] = ServiceKey[Register]("peer-ready-manager") + + // @formatter:off + sealed trait Command + case class Register(replyTo: ActorRef[Registered], remoteNodeId: PublicKey) extends Command + case class List(replyTo: ActorRef[Set[PublicKey]]) extends Command + private case class Completed(remoteNodeId: PublicKey, actor: ActorRef[Registered]) extends Command + // @formatter:on + + /** + * @param otherAttempts number of already pending [[PeerReadyNotifier]] instances for that peer. + */ + case class Registered(remoteNodeId: PublicKey, otherAttempts: Int) + + def apply(): Behavior[Command] = { + Behaviors.setup { context => + context.system.receptionist ! Receptionist.Register(PeerReadyManagerServiceKey, context.self) + watch(Map.empty, context) + } + } + + private def watch(pending: Map[PublicKey, Set[ActorRef[Registered]]], context: ActorContext[Command]): Behavior[Command] = { + Behaviors.receiveMessage { + case Register(replyTo, remoteNodeId) => + context.watchWith(replyTo, Completed(remoteNodeId, replyTo)) + pending.get(remoteNodeId) match { + case Some(attempts) => + replyTo ! Registered(remoteNodeId, otherAttempts = attempts.size) + val attempts1 = attempts + replyTo + watch(pending + (remoteNodeId -> attempts1), context) + case None => + replyTo ! Registered(remoteNodeId, otherAttempts = 0) + watch(pending + (remoteNodeId -> Set(replyTo)), context) + } + case Completed(remoteNodeId, actor) => + pending.get(remoteNodeId) match { + case Some(attempts) => + val attempts1 = attempts - actor + if (attempts1.isEmpty) { + watch(pending - remoteNodeId, context) + } else { + watch(pending + (remoteNodeId -> attempts1), context) + } + case None => + Behaviors.same + } + case List(replyTo) => + replyTo ! pending.keySet + Behaviors.same + } + } + +} + /** * This actor waits for a given peer to be online and ready to process payments. - * It automatically stops after the timeout provided. + * It automatically stops after the timeout provided if the peer doesn't connect. + * There may be multiple instances of this actor running in parallel for the same peer, which is fine because they + * may use different timeouts. + * Having separate actor instances for each caller guarantees that the caller will always receive a response. */ object PeerReadyNotifier { @@ -39,6 +103,7 @@ object PeerReadyNotifier { sealed trait Command case class NotifyWhenPeerReady(replyTo: ActorRef[Result]) extends Command private final case class WrappedListing(wrapped: Receptionist.Listing) extends Command + private final case class WrappedRegistered(registered: PeerReadyManager.Registered) extends Command private case object PeerNotConnected extends Command private case object PeerConnected extends Command private case object PeerDisconnected extends Command @@ -60,24 +125,24 @@ object PeerReadyNotifier { Behaviors.setup { context => Behaviors.withTimers { timers => Behaviors.withMdc(Logs.mdc(remoteNodeId_opt = Some(remoteNodeId))) { - Behaviors.receiveMessagePartial { - case NotifyWhenPeerReady(replyTo) => - timeout_opt.foreach { - case Left(d) => timers.startSingleTimer(Timeout, d) - case Right(h) => context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[CurrentBlockHeight] { - case cbc if h <= cbc.blockHeight => Timeout - case cbc => NewBlockNotTimedOut(cbc.blockHeight) - }) - } - // In case the peer is not currently connected, we will wait for them to connect instead of regularly - // polling the switchboard. This makes more sense for long timeouts such as the ones used for async payments. - context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerConnected](e => if (e.nodeId == remoteNodeId) PeerConnected else ToBeIgnored)) - context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerDisconnected](e => if (e.nodeId == remoteNodeId) PeerDisconnected else ToBeIgnored)) - // The actor should never throw, but for extra safety we wrap it with a supervisor. - Behaviors.supervise { - new PeerReadyNotifier(replyTo, remoteNodeId, context, timers).findSwitchboard() - }.onFailure(SupervisorStrategy.stop) - } + Behaviors.receiveMessagePartial { + case NotifyWhenPeerReady(replyTo) => + timeout_opt.foreach { + case Left(d) => timers.startSingleTimer(Timeout, d) + case Right(h) => context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[CurrentBlockHeight] { + case cbc if h <= cbc.blockHeight => Timeout + case cbc => NewBlockNotTimedOut(cbc.blockHeight) + }) + } + // In case the peer is not currently connected, we will wait for them to connect instead of regularly + // polling the switchboard. This makes more sense for long timeouts such as the ones used for async payments. + context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerConnected](e => if (e.nodeId == remoteNodeId) PeerConnected else ToBeIgnored)) + context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerDisconnected](e => if (e.nodeId == remoteNodeId) PeerDisconnected else ToBeIgnored)) + // The actor should never throw, but for extra safety we wrap it with a supervisor. + Behaviors.supervise { + new PeerReadyNotifier(replyTo, remoteNodeId, context, timers).register() + }.onFailure(SupervisorStrategy.stop) + } } } } @@ -125,13 +190,38 @@ private class PeerReadyNotifier(replyTo: ActorRef[PeerReadyNotifier.Result], private val log = context.log - private def findSwitchboard(): Behavior[Command] = { + private def register(): Behavior[Command] = { + context.system.receptionist ! Receptionist.Find(PeerReadyManager.PeerReadyManagerServiceKey, context.messageAdapter[Receptionist.Listing](WrappedListing)) + Behaviors.receiveMessagePartial { + case WrappedListing(PeerReadyManager.PeerReadyManagerServiceKey.Listing(listings)) => + listings.headOption match { + case Some(peerReadyManager) => + peerReadyManager ! PeerReadyManager.Register(context.messageAdapter[PeerReadyManager.Registered](WrappedRegistered), remoteNodeId) + Behaviors.same + case None => + log.error("no peer-ready-manager found") + replyTo ! PeerUnavailable(remoteNodeId) + Behaviors.stopped + } + case WrappedRegistered(registered) => + log.info("checking if peer is available ({} other attempts)", registered.otherAttempts) + findSwitchboard(isFirstAttempt = registered.otherAttempts == 0) + case Timeout => + log.info("timed out finding peer-ready-manager actor") + replyTo ! PeerUnavailable(remoteNodeId) + Behaviors.stopped + case ToBeIgnored => + Behaviors.same + } + } + + private def findSwitchboard(isFirstAttempt: Boolean): Behavior[Command] = { context.system.receptionist ! Receptionist.Find(Switchboard.SwitchboardServiceKey, context.messageAdapter[Receptionist.Listing](WrappedListing)) Behaviors.receiveMessagePartial { case WrappedListing(Switchboard.SwitchboardServiceKey.Listing(listings)) => listings.headOption match { case Some(switchboard) => - waitForPeerConnected(switchboard) + waitForPeerConnected(switchboard, isFirstAttempt) case None => log.error("no switchboard found") replyTo ! PeerUnavailable(remoteNodeId) @@ -146,7 +236,7 @@ private class PeerReadyNotifier(replyTo: ActorRef[PeerReadyNotifier.Result], } } - private def waitForPeerConnected(switchboard: ActorRef[Switchboard.GetPeerInfo]): Behavior[Command] = { + private def waitForPeerConnected(switchboard: ActorRef[Switchboard.GetPeerInfo], isFirstAttempt: Boolean): Behavior[Command] = { val peerInfoAdapter = context.messageAdapter[Peer.PeerInfoResponse] { // We receive this when we don't have any channel to the given peer and are not currently connected to them. // In that case we still want to wait for a connection, because we may want to open a channel to them. @@ -172,7 +262,7 @@ private class PeerReadyNotifier(replyTo: ActorRef[PeerReadyNotifier.Result], Behaviors.stopped } else { log.debug("peer is connected with {} channels", channelCount) - waitForChannelsReady(peer, switchboard) + waitForChannelsReady(peer, switchboard, isFirstAttempt) } case NewBlockNotTimedOut(currentBlockHeight) => log.debug("waiting for peer to connect at block {}", currentBlockHeight) @@ -186,7 +276,7 @@ private class PeerReadyNotifier(replyTo: ActorRef[PeerReadyNotifier.Result], } } - private def waitForChannelsReady(peer: ActorRef[Peer.GetPeerChannels], switchboard: ActorRef[Switchboard.GetPeerInfo]): Behavior[Command] = { + private def waitForChannelsReady(peer: ActorRef[Peer.GetPeerChannels], switchboard: ActorRef[Switchboard.GetPeerInfo], isFirstAttempt: Boolean): Behavior[Command] = { timers.startTimerWithFixedDelay(ChannelsReadyTimerKey, CheckChannelsReady, initialDelay = 50 millis, delay = 1 second) Behaviors.receiveMessagePartial { case CheckChannelsReady => @@ -209,7 +299,7 @@ private class PeerReadyNotifier(replyTo: ActorRef[PeerReadyNotifier.Result], case PeerDisconnected => log.debug("peer disconnected, waiting for them to reconnect") timers.cancel(ChannelsReadyTimerKey) - waitForPeerConnected(switchboard) + waitForPeerConnected(switchboard, isFirstAttempt) case Timeout => log.info("timed out waiting for channels to be ready") replyTo ! PeerUnavailable(remoteNodeId) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerReadyManagerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerReadyManagerSpec.scala new file mode 100644 index 0000000000..7ef663ec5a --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerReadyManagerSpec.scala @@ -0,0 +1,55 @@ +/* + * Copyright 2024 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.io + +import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe} +import com.typesafe.config.ConfigFactory +import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey +import fr.acinq.eclair.randomKey +import org.scalatest.funsuite.AnyFunSuiteLike + +class PeerReadyManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with AnyFunSuiteLike { + + test("watch pending notifiers") { + val manager = testKit.spawn(PeerReadyManager()) + val remoteNodeId1 = randomKey().publicKey + val notifier1a = TestProbe[PeerReadyManager.Registered]() + val notifier1b = TestProbe[PeerReadyManager.Registered]() + + manager ! PeerReadyManager.Register(notifier1a.ref, remoteNodeId1) + assert(notifier1a.expectMessageType[PeerReadyManager.Registered].otherAttempts == 0) + manager ! PeerReadyManager.Register(notifier1b.ref, remoteNodeId1) + assert(notifier1b.expectMessageType[PeerReadyManager.Registered].otherAttempts == 1) + + val remoteNodeId2 = randomKey().publicKey + val notifier2a = TestProbe[PeerReadyManager.Registered]() + val notifier2b = TestProbe[PeerReadyManager.Registered]() + + // Later attempts aren't affected by previously completed attempts. + manager ! PeerReadyManager.Register(notifier2a.ref, remoteNodeId2) + assert(notifier2a.expectMessageType[PeerReadyManager.Registered].otherAttempts == 0) + notifier2a.stop() + val probe = TestProbe[Set[PublicKey]]() + probe.awaitAssert({ + manager ! PeerReadyManager.List(probe.ref) + assert(probe.expectMessageType[Set[PublicKey]] == Set(remoteNodeId1)) + }) + manager ! PeerReadyManager.Register(notifier2b.ref, remoteNodeId2) + assert(notifier2b.expectMessageType[PeerReadyManager.Registered].otherAttempts == 0) + } + +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerReadyNotifierSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerReadyNotifierSpec.scala index 4d2dad7808..4247ae472f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerReadyNotifierSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerReadyNotifierSpec.scala @@ -33,17 +33,20 @@ import scala.concurrent.duration.DurationInt class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike { - case class FixtureParam(remoteNodeId: PublicKey, switchboard: TestProbe[Switchboard.GetPeerInfo], peer: TestProbe[Peer.GetPeerChannels], probe: TestProbe[PeerReadyNotifier.Result]) + case class FixtureParam(remoteNodeId: PublicKey, peerReadyManager: TestProbe[PeerReadyManager.Register], switchboard: TestProbe[Switchboard.GetPeerInfo], peer: TestProbe[Peer.GetPeerChannels], probe: TestProbe[PeerReadyNotifier.Result]) override def withFixture(test: OneArgTest): Outcome = { val remoteNodeId = randomKey().publicKey + val peerReadyManager = TestProbe[PeerReadyManager.Register]("peer-ready-manager") + system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) val switchboard = TestProbe[Switchboard.GetPeerInfo]("switchboard") system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) val peer = TestProbe[Peer.GetPeerChannels]("peer") val probe = TestProbe[PeerReadyNotifier.Result]() try { - withFixture(test.toNoArgTest(FixtureParam(remoteNodeId, switchboard, peer, probe))) + withFixture(test.toNoArgTest(FixtureParam(remoteNodeId, peerReadyManager, switchboard, peer, probe))) } finally { + system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) } } @@ -53,6 +56,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Left(10 millis)))) notifier ! NotifyWhenPeerReady(probe.ref) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == remoteNodeId) probe.expectMessage(PeerUnavailable(remoteNodeId)) } @@ -62,6 +66,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(100))))) notifier ! NotifyWhenPeerReady(probe.ref) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == remoteNodeId) // We haven't reached the timeout yet. @@ -78,6 +83,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500))))) notifier ! NotifyWhenPeerReady(probe.ref) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) val request = switchboard.expectMessageType[Switchboard.GetPeerInfo] request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set.empty) probe.expectMessage(PeerReadyNotifier.PeerReady(remoteNodeId, peer.ref.toClassic, Seq.empty)) @@ -88,6 +94,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500))))) notifier ! NotifyWhenPeerReady(probe.ref) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo] request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set(TestProbe().ref.toClassic, TestProbe().ref.toClassic)) @@ -115,6 +122,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500))))) notifier ! NotifyWhenPeerReady(probe.ref) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 1) val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo] request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic, TestProbe().ref.toClassic)) peer.expectNoMessage(100 millis) @@ -137,6 +145,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = None)) notifier ! NotifyWhenPeerReady(probe.ref) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo] request1.replyTo ! Peer.PeerNotFound(remoteNodeId) peer.expectNoMessage(100 millis) @@ -161,6 +170,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(500))))) notifier ! NotifyWhenPeerReady(probe.ref) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 5) val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo] request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set.empty) peer.expectNoMessage(100 millis) @@ -185,6 +195,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Left(1 second)))) notifier ! NotifyWhenPeerReady(probe.ref) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) val request = switchboard.expectMessageType[Switchboard.GetPeerInfo] request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set(TestProbe().ref.toClassic)) peer.expectMessageType[Peer.GetPeerChannels] @@ -196,6 +207,7 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Right(BlockHeight(100))))) notifier ! NotifyWhenPeerReady(probe.ref) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 2) val request = switchboard.expectMessageType[Switchboard.GetPeerInfo] request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.CONNECTED, None, Set(TestProbe().ref.toClassic)) peer.expectMessageType[Peer.GetPeerChannels] From ea613a9a86c8119a4b9ac06c824c9d9a81a3b397 Mon Sep 17 00:00:00 2001 From: t-bast Date: Mon, 26 Aug 2024 16:08:05 +0200 Subject: [PATCH 10/13] fixup! Track pending `PeerReadyNotifier` instances --- .../scala/fr/acinq/eclair/io/MessageRelaySpec.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala index c296602bc1..87d93d645e 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala @@ -47,9 +47,11 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val wakeUpTimeout = "wake_up_timeout" - case class FixtureParam(relay: ActorRef[Command], switchboard: TestProbe, register: TestProbe, router: TypedProbe[Router.GetNodeId], peerConnection: TypedProbe[Nothing], peer: TypedProbe[Peer.RelayOnionMessage], probe: TypedProbe[Status]) + case class FixtureParam(relay: ActorRef[Command], switchboard: TestProbe, register: TestProbe, router: TypedProbe[Router.GetNodeId], peerConnection: TypedProbe[Nothing], peer: TypedProbe[Peer.RelayOnionMessage], peerReadyManager: TestProbe, probe: TypedProbe[Status]) override def withFixture(test: OneArgTest): Outcome = { + val peerReadyManager = TestProbe("peer-ready-manager")(system.classicSystem) + system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref.toTyped) val switchboard = TestProbe("switchboard")(system.classicSystem) system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref.toTyped) val register = TestProbe("register")(system.classicSystem) @@ -60,8 +62,9 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val nodeParams = if (test.tags.contains(wakeUpTimeout)) Alice.nodeParams.copy(peerWakeUpConfig = WakeUpConfig(100 millis)) else Alice.nodeParams val relay = testKit.spawn(MessageRelay(nodeParams, switchboard.ref, register.ref, router.ref)) try { - withFixture(test.toNoArgTest(FixtureParam(relay, switchboard, register, router, peerConnection, peer, probe))) + withFixture(test.toNoArgTest(FixtureParam(relay, switchboard, register, router, peerConnection, peer, peerReadyManager, probe))) } finally { + system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref.toTyped) system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref.toTyped) testKit.stop(relay) } @@ -100,6 +103,10 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val messageId = randomBytes32() relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId.WithPublicKey.Wallet(bobId)), message, RelayChannelsOnly, None) + val register = peerReadyManager.expectMsgType[PeerReadyManager.Register] + assert(register.remoteNodeId == bobId) + register.replyTo ! PeerReadyManager.Registered(bobId, otherAttempts = 0) + val request = switchboard.expectMsgType[GetPeerInfo] assert(request.remoteNodeId == bobId) request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, bobId, Peer.CONNECTED, None, Set.empty) From 28e92c2d022964dc38a10e3787c2b4cdd25c22aa Mon Sep 17 00:00:00 2001 From: t-bast Date: Mon, 26 Aug 2024 16:38:52 +0200 Subject: [PATCH 11/13] fixup! Track pending `PeerReadyNotifier` instances --- .../relay/AsyncPaymentTriggererSpec.scala | 21 +++++++++++++++---- .../payment/relay/ChannelRelayerSpec.scala | 10 ++++++++- .../payment/relay/NodeRelayerSpec.scala | 10 ++++++++- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggererSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggererSpec.scala index 223b7af750..eb9b7e3ea2 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggererSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggererSpec.scala @@ -11,7 +11,7 @@ import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.blockchain.CurrentBlockHeight import fr.acinq.eclair.channel.NEGOTIATING import fr.acinq.eclair.io.Switchboard.GetPeerInfo -import fr.acinq.eclair.io.{Peer, PeerConnected, Switchboard} +import fr.acinq.eclair.io.{Peer, PeerConnected, PeerReadyManager, Switchboard} import fr.acinq.eclair.payment.relay.AsyncPaymentTriggerer._ import fr.acinq.eclair.{BlockHeight, TestConstants, randomKey} import org.scalatest.Outcome @@ -21,18 +21,21 @@ import scala.concurrent.duration.DurationInt class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike { - case class FixtureParam(remoteNodeId: PublicKey, switchboard: TestProbe[Switchboard.GetPeerInfo], peer: TestProbe[Peer.GetPeerChannels], probe: TestProbe[Result], triggerer: ActorRef[Command]) + case class FixtureParam(remoteNodeId: PublicKey, peerReadyManager: TestProbe[PeerReadyManager.Register], switchboard: TestProbe[Switchboard.GetPeerInfo], peer: TestProbe[Peer.GetPeerChannels], probe: TestProbe[Result], triggerer: ActorRef[Command]) override def withFixture(test: OneArgTest): Outcome = { val remoteNodeId = TestConstants.Alice.nodeParams.nodeId + val peerReadyManager = TestProbe[PeerReadyManager.Register]("peer-ready-manager") + system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) val switchboard = TestProbe[Switchboard.GetPeerInfo]("switchboard") system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) val peer = TestProbe[Peer.GetPeerChannels]("peer") val probe = TestProbe[Result]() val triggerer = testKit.spawn(AsyncPaymentTriggerer()) try { - withFixture(test.toNoArgTest(FixtureParam(remoteNodeId, switchboard, peer, probe, triggerer))) + withFixture(test.toNoArgTest(FixtureParam(remoteNodeId, peerReadyManager, switchboard, peer, probe, triggerer))) } finally { + system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) } } @@ -41,6 +44,7 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. import f._ triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) assert(switchboard.expectMessageType[GetPeerInfo].remoteNodeId == remoteNodeId) // We haven't reached the timeout yet. @@ -60,6 +64,7 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. import f._ triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) assert(switchboard.expectMessageType[GetPeerInfo].remoteNodeId == remoteNodeId) // cancel of an unwatched payment does nothing @@ -75,6 +80,7 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. // create two identical watches triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) assert(switchboard.expectMessageType[GetPeerInfo].remoteNodeId == remoteNodeId) triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) @@ -86,6 +92,7 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. // create two different watches val probe2 = TestProbe[Result]() triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 1) assert(switchboard.expectMessageType[GetPeerInfo].remoteNodeId == remoteNodeId) triggerer ! Watch(probe2.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) @@ -101,6 +108,7 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. // create watches for two payments with the same payment hash val probe2 = TestProbe[Result]() triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) assert(switchboard.expectMessageType[GetPeerInfo].remoteNodeId == remoteNodeId) triggerer ! Watch(probe2.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) @@ -114,6 +122,7 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. import f._ triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo] request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic)) @@ -137,6 +146,7 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. import f._ triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo] request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic)) @@ -162,6 +172,7 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. // watch remote node triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo] request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic)) @@ -169,8 +180,9 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. val remoteNodeId2 = TestConstants.Bob.nodeParams.nodeId val probe2 = TestProbe[Result]() triggerer ! Watch(probe2.ref, remoteNodeId2, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(101)) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId2, otherAttempts = 0) val request2 = switchboard.expectMessageType[Switchboard.GetPeerInfo] - request2.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic)) + request2.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId2, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic)) // First remote node times out system.eventStream ! EventStream.Publish(CurrentBlockHeight(BlockHeight(100))) @@ -193,6 +205,7 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. test("triggerer treats an unexpected stop of the notifier as a cancel") { f => import f._ triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) assert(switchboard.expectMessageType[GetPeerInfo].remoteNodeId == remoteNodeId) triggerer ! NotifierStopped(remoteNodeId) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala index c705d21621..5cab0aa877 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala @@ -31,7 +31,7 @@ import fr.acinq.eclair.blockchain.fee.FeeratePerKw import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.io.PeerReadyNotifier.WakeUpConfig -import fr.acinq.eclair.io.{Peer, Switchboard} +import fr.acinq.eclair.io.{Peer, PeerReadyManager, Switchboard} import fr.acinq.eclair.payment.IncomingPaymentPacket.ChannelRelayPacket import fr.acinq.eclair.payment.relay.ChannelRelayer._ import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket, PaymentPacketSpec} @@ -176,6 +176,8 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a test("relay blinded payment (wake up wallet node)") { f => import f._ + val peerReadyManager = TestProbe[PeerReadyManager.Register]() + system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) val switchboard = TestProbe[Switchboard.GetPeerInfo]() system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) @@ -188,12 +190,14 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) // We try to wake-up the next node. + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(outgoingNodeId, otherAttempts = 0) val wakeUp = switchboard.expectMessageType[Switchboard.GetPeerInfo] assert(wakeUp.remoteNodeId == outgoingNodeId) wakeUp.replyTo ! Peer.PeerInfo(TestProbe[Any]().ref.toClassic, outgoingNodeId, Peer.CONNECTED, None, Set.empty) expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry, 7) }) + system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) } @@ -327,6 +331,8 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a test("fail to relay blinded payment (cannot wake up remote node)", Tag(wakeUpTimeout)) { f => import f._ + val peerReadyManager = TestProbe[PeerReadyManager.Register]() + system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) val switchboard = TestProbe[Switchboard.GetPeerInfo]() system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) @@ -338,10 +344,12 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a channelRelayer ! Relay(r, TestConstants.Alice.nodeParams.nodeId) // We try to wake-up the next node, but we timeout before they connect. + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(outgoingNodeId, otherAttempts = 0) assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == outgoingNodeId) val fail = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] assert(fail.message.reason.contains(InvalidOnionBlinding(Sphinx.hash(r.add.onionRoutingPacket)))) + system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala index 71e120befa..3d734a2fb3 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala @@ -32,7 +32,7 @@ import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.Features.{AsyncPaymentPrototype, BasicMultiPartPayment, PaymentSecret, VariableLengthOnion} import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Register, Upstream} import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.io.{Peer, Switchboard} +import fr.acinq.eclair.io.{Peer, PeerReadyManager, Switchboard} import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop import fr.acinq.eclair.payment.IncomingPaymentPacket.{RelayToBlindedPathsPacket, RelayToTrampolinePacket} import fr.acinq.eclair.payment.Invoice.ExtraEdge @@ -790,6 +790,8 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl test("relay to blinded path with wake-up") { f => import f._ + val peerReadyManager = TestProbe[PeerReadyManager.Register]() + system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) val switchboard = TestProbe[Switchboard.GetPeerInfo]() system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) @@ -798,9 +800,11 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) // The remote node is a wallet node: we try to wake them up before relaying the payment. + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(outgoingNodeId, otherAttempts = 0) val wakeUp = switchboard.expectMessageType[Switchboard.GetPeerInfo] assert(wakeUp.remoteNodeId == outgoingNodeId) wakeUp.replyTo ! Peer.PeerInfo(TestProbe[Any]().ref.toClassic, outgoingNodeId, Peer.CONNECTED, None, Set.empty) + system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] @@ -832,6 +836,8 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl test("fail to relay to blinded path when wake-up fails", Tag(wakeUpTimeout)) { f => import f._ + val peerReadyManager = TestProbe[PeerReadyManager.Register]() + system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) val switchboard = TestProbe[Switchboard.GetPeerInfo]() system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) @@ -840,7 +846,9 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming, randomKey().publicKey)) // The remote node is a wallet node: we try to wake them up before relaying the payment, but it times out. + peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(outgoingNodeId, otherAttempts = 0) assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == outgoingNodeId) + system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) mockPayFSM.expectNoMessage(100 millis) From bfb9090212bb55abc75ec20d6d99f09f0870652e Mon Sep 17 00:00:00 2001 From: t-bast Date: Mon, 26 Aug 2024 17:23:49 +0200 Subject: [PATCH 12/13] Move `register` outside of private class This way we can have `isFirstAttempt` as a class field, which makes sense since this value won't change. --- .../acinq/eclair/io/PeerReadyNotifier.scala | 73 ++++++++++--------- .../eclair/io/PeerReadyNotifierSpec.scala | 1 - 2 files changed, 37 insertions(+), 37 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala index 684cbe9bcc..4533f8c9ac 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala @@ -134,13 +134,9 @@ object PeerReadyNotifier { case cbc => NewBlockNotTimedOut(cbc.blockHeight) }) } - // In case the peer is not currently connected, we will wait for them to connect instead of regularly - // polling the switchboard. This makes more sense for long timeouts such as the ones used for async payments. - context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerConnected](e => if (e.nodeId == remoteNodeId) PeerConnected else ToBeIgnored)) - context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerDisconnected](e => if (e.nodeId == remoteNodeId) PeerDisconnected else ToBeIgnored)) // The actor should never throw, but for extra safety we wrap it with a supervisor. Behaviors.supervise { - new PeerReadyNotifier(replyTo, remoteNodeId, context, timers).register() + start(replyTo, remoteNodeId, context, timers) }.onFailure(SupervisorStrategy.stop) } } @@ -148,6 +144,35 @@ object PeerReadyNotifier { } } + private def start(replyTo: ActorRef[Result], remoteNodeId: PublicKey, context: ActorContext[Command], timers: TimerScheduler[Command]): Behavior[Command] = { + // We start by registering ourself to see if other instances are running. + context.system.receptionist ! Receptionist.Find(PeerReadyManager.PeerReadyManagerServiceKey, context.messageAdapter[Receptionist.Listing](WrappedListing)) + Behaviors.receiveMessagePartial { + case WrappedListing(PeerReadyManager.PeerReadyManagerServiceKey.Listing(listings)) => + listings.headOption match { + case Some(peerReadyManager) => + peerReadyManager ! PeerReadyManager.Register(context.messageAdapter[PeerReadyManager.Registered](WrappedRegistered), remoteNodeId) + Behaviors.same + case None => + context.log.error("no peer-ready-manager found") + replyTo ! PeerUnavailable(remoteNodeId) + Behaviors.stopped + } + case WrappedRegistered(registered) => + context.log.info("checking if peer is ready ({} other attempts)", registered.otherAttempts) + val isFirstAttempt = registered.otherAttempts == 0 + // In case the peer is not currently connected, we will wait for them to connect instead of regularly + // polling the switchboard. This makes more sense for long timeouts such as the ones used for async payments. + context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerConnected](e => if (e.nodeId == remoteNodeId) PeerConnected else ToBeIgnored)) + context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[PeerDisconnected](e => if (e.nodeId == remoteNodeId) PeerDisconnected else ToBeIgnored)) + new PeerReadyNotifier(replyTo, remoteNodeId, isFirstAttempt, context, timers).findSwitchboard() + case Timeout => + context.log.info("timed out finding peer-ready-manager actor") + replyTo ! PeerUnavailable(remoteNodeId) + Behaviors.stopped + } + } + // We use an exhaustive pattern matching here to ensure we explicitly handle future new channel states. // We only want to test that channels are not in an uninitialized state, we don't need them to be available to relay // payments (channels closing or waiting to confirm are "ready" for our purposes). @@ -183,6 +208,7 @@ object PeerReadyNotifier { private class PeerReadyNotifier(replyTo: ActorRef[PeerReadyNotifier.Result], remoteNodeId: PublicKey, + isFirstAttempt: Boolean, context: ActorContext[PeerReadyNotifier.Command], timers: TimerScheduler[PeerReadyNotifier.Command]) { @@ -190,38 +216,13 @@ private class PeerReadyNotifier(replyTo: ActorRef[PeerReadyNotifier.Result], private val log = context.log - private def register(): Behavior[Command] = { - context.system.receptionist ! Receptionist.Find(PeerReadyManager.PeerReadyManagerServiceKey, context.messageAdapter[Receptionist.Listing](WrappedListing)) - Behaviors.receiveMessagePartial { - case WrappedListing(PeerReadyManager.PeerReadyManagerServiceKey.Listing(listings)) => - listings.headOption match { - case Some(peerReadyManager) => - peerReadyManager ! PeerReadyManager.Register(context.messageAdapter[PeerReadyManager.Registered](WrappedRegistered), remoteNodeId) - Behaviors.same - case None => - log.error("no peer-ready-manager found") - replyTo ! PeerUnavailable(remoteNodeId) - Behaviors.stopped - } - case WrappedRegistered(registered) => - log.info("checking if peer is available ({} other attempts)", registered.otherAttempts) - findSwitchboard(isFirstAttempt = registered.otherAttempts == 0) - case Timeout => - log.info("timed out finding peer-ready-manager actor") - replyTo ! PeerUnavailable(remoteNodeId) - Behaviors.stopped - case ToBeIgnored => - Behaviors.same - } - } - - private def findSwitchboard(isFirstAttempt: Boolean): Behavior[Command] = { + private def findSwitchboard(): Behavior[Command] = { context.system.receptionist ! Receptionist.Find(Switchboard.SwitchboardServiceKey, context.messageAdapter[Receptionist.Listing](WrappedListing)) Behaviors.receiveMessagePartial { case WrappedListing(Switchboard.SwitchboardServiceKey.Listing(listings)) => listings.headOption match { case Some(switchboard) => - waitForPeerConnected(switchboard, isFirstAttempt) + waitForPeerConnected(switchboard) case None => log.error("no switchboard found") replyTo ! PeerUnavailable(remoteNodeId) @@ -236,7 +237,7 @@ private class PeerReadyNotifier(replyTo: ActorRef[PeerReadyNotifier.Result], } } - private def waitForPeerConnected(switchboard: ActorRef[Switchboard.GetPeerInfo], isFirstAttempt: Boolean): Behavior[Command] = { + private def waitForPeerConnected(switchboard: ActorRef[Switchboard.GetPeerInfo]): Behavior[Command] = { val peerInfoAdapter = context.messageAdapter[Peer.PeerInfoResponse] { // We receive this when we don't have any channel to the given peer and are not currently connected to them. // In that case we still want to wait for a connection, because we may want to open a channel to them. @@ -262,7 +263,7 @@ private class PeerReadyNotifier(replyTo: ActorRef[PeerReadyNotifier.Result], Behaviors.stopped } else { log.debug("peer is connected with {} channels", channelCount) - waitForChannelsReady(peer, switchboard, isFirstAttempt) + waitForChannelsReady(peer, switchboard) } case NewBlockNotTimedOut(currentBlockHeight) => log.debug("waiting for peer to connect at block {}", currentBlockHeight) @@ -276,7 +277,7 @@ private class PeerReadyNotifier(replyTo: ActorRef[PeerReadyNotifier.Result], } } - private def waitForChannelsReady(peer: ActorRef[Peer.GetPeerChannels], switchboard: ActorRef[Switchboard.GetPeerInfo], isFirstAttempt: Boolean): Behavior[Command] = { + private def waitForChannelsReady(peer: ActorRef[Peer.GetPeerChannels], switchboard: ActorRef[Switchboard.GetPeerInfo]): Behavior[Command] = { timers.startTimerWithFixedDelay(ChannelsReadyTimerKey, CheckChannelsReady, initialDelay = 50 millis, delay = 1 second) Behaviors.receiveMessagePartial { case CheckChannelsReady => @@ -299,7 +300,7 @@ private class PeerReadyNotifier(replyTo: ActorRef[PeerReadyNotifier.Result], case PeerDisconnected => log.debug("peer disconnected, waiting for them to reconnect") timers.cancel(ChannelsReadyTimerKey) - waitForPeerConnected(switchboard, isFirstAttempt) + waitForPeerConnected(switchboard) case Timeout => log.info("timed out waiting for channels to be ready") replyTo ! PeerUnavailable(remoteNodeId) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerReadyNotifierSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerReadyNotifierSpec.scala index 4247ae472f..bb5f174095 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerReadyNotifierSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerReadyNotifierSpec.scala @@ -57,7 +57,6 @@ class PeerReadyNotifierSpec extends ScalaTestWithActorTestKit(ConfigFactory.load val notifier = testKit.spawn(PeerReadyNotifier(remoteNodeId, timeout_opt = Some(Left(10 millis)))) notifier ! NotifyWhenPeerReady(probe.ref) peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) - assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == remoteNodeId) probe.expectMessage(PeerUnavailable(remoteNodeId)) } From 3e23b859ae13d692af4466ebdd07139f2cbaa0a5 Mon Sep 17 00:00:00 2001 From: t-bast Date: Tue, 27 Aug 2024 11:59:05 +0200 Subject: [PATCH 13/13] Allow disabling `wake-up` This makes testing easier. We also do small refactorings in the relay actors without any behavior changes. --- eclair-core/src/main/resources/reference.conf | 1 + .../scala/fr/acinq/eclair/NodeParams.scala | 1 + .../acinq/eclair/io/PeerReadyNotifier.scala | 2 +- .../eclair/payment/relay/ChannelRelay.scala | 4 +-- .../eclair/payment/relay/NodeRelay.scala | 4 +-- .../scala/fr/acinq/eclair/TestConstants.scala | 4 +-- .../fr/acinq/eclair/io/MessageRelaySpec.scala | 11 +++--- .../relay/AsyncPaymentTriggererSpec.scala | 34 ++++++++++--------- .../payment/relay/ChannelRelayerSpec.scala | 10 +++--- .../payment/relay/NodeRelayerSpec.scala | 6 ++-- 10 files changed, 44 insertions(+), 33 deletions(-) diff --git a/eclair-core/src/main/resources/reference.conf b/eclair-core/src/main/resources/reference.conf index 79d0438760..9bf4cb33f1 100644 --- a/eclair-core/src/main/resources/reference.conf +++ b/eclair-core/src/main/resources/reference.conf @@ -321,6 +321,7 @@ eclair { // When relaying payments or messages to mobile peers who are disconnected, we may try to wake them up using a mobile // notification system, or we attempt connecting to the last known address. peer-wake-up { + enabled = false timeout = 60 seconds } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala index 9c7858f1f1..25b3556049 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala @@ -614,6 +614,7 @@ object NodeParams extends Logging { interval = FiniteDuration(config.getDuration("db.revoked-htlc-info-cleaner.interval").getSeconds, TimeUnit.SECONDS) ), peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig( + enabled = config.getBoolean("peer-wake-up.enabled"), timeout = FiniteDuration(config.getDuration("peer-wake-up.timeout").getSeconds, TimeUnit.SECONDS) ), ) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala index 4533f8c9ac..f4a8b8d677 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala @@ -97,7 +97,7 @@ object PeerReadyManager { */ object PeerReadyNotifier { - case class WakeUpConfig(timeout: FiniteDuration) + case class WakeUpConfig(enabled: Boolean, timeout: FiniteDuration) // @formatter:off sealed trait Command diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala index bd1a6d68fb..1181f75bca 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala @@ -145,8 +145,8 @@ class ChannelRelay private(nodeParams: NodeParams, def start(): Behavior[Command] = { walletNodeId_opt match { - case Some(walletNodeId) => wakeUp(walletNodeId) - case None => + case Some(walletNodeId) if nodeParams.peerWakeUpConfig.enabled => wakeUp(walletNodeId) + case _ => context.self ! DoRelay relay(Seq.empty) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index 5bf729e636..de22feba6b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -287,8 +287,8 @@ class NodeRelay private(nodeParams: NodeParams, */ private def ensureRecipientReady(upstream: Upstream.Hot.Trampoline, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { nextWalletNodeId(nodeParams, recipient) match { - case Some(walletNodeId) => waitForPeerReady(upstream, walletNodeId, recipient, nextPayload, nextPacket_opt) - case None => relay(upstream, recipient, nextPayload, nextPacket_opt) + case Some(walletNodeId) if nodeParams.peerWakeUpConfig.enabled => waitForPeerReady(upstream, walletNodeId, recipient, nextPayload, nextPacket_opt) + case _ => relay(upstream, recipient, nextPayload, nextPacket_opt) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala index 72005dfd75..0e70083fa9 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala @@ -232,7 +232,7 @@ object TestConstants { ), purgeInvoicesInterval = None, revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis), - peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(30 seconds), + peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(enabled = false, timeout = 30 seconds), ) def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams( @@ -403,7 +403,7 @@ object TestConstants { ), purgeInvoicesInterval = None, revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis), - peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(30 seconds), + peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(enabled = false, timeout = 30 seconds), ) def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams( diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala index 87d93d645e..ae45c4c95b 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala @@ -22,13 +22,13 @@ import akka.actor.typed.eventstream.EventStream import akka.actor.typed.receptionist.Receptionist import akka.actor.typed.scaladsl.adapter.{ClassicActorRefOps, TypedActorRefOps} import akka.testkit.TestProbe +import com.softwaremill.quicklens.ModifyPimp import com.typesafe.config.ConfigFactory import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.TestConstants.{Alice, Bob} import fr.acinq.eclair.channel.Register import fr.acinq.eclair.io.MessageRelay._ import fr.acinq.eclair.io.Peer.{PeerInfo, PeerNotFound} -import fr.acinq.eclair.io.PeerReadyNotifier.WakeUpConfig import fr.acinq.eclair.io.Switchboard.GetPeerInfo import fr.acinq.eclair.message.OnionMessages import fr.acinq.eclair.message.OnionMessages.{IntermediateNode, Recipient} @@ -45,6 +45,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val aliceId: PublicKey = Alice.nodeParams.nodeId val bobId: PublicKey = Bob.nodeParams.nodeId + val wakeUpEnabled = "wake_up_enabled" val wakeUpTimeout = "wake_up_timeout" case class FixtureParam(relay: ActorRef[Command], switchboard: TestProbe, register: TestProbe, router: TypedProbe[Router.GetNodeId], peerConnection: TypedProbe[Nothing], peer: TypedProbe[Peer.RelayOnionMessage], peerReadyManager: TestProbe, probe: TypedProbe[Status]) @@ -59,7 +60,9 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val peerConnection = TypedProbe[Nothing]("peerConnection") val peer = TypedProbe[Peer.RelayOnionMessage]("peer") val probe = TypedProbe[Status]("probe") - val nodeParams = if (test.tags.contains(wakeUpTimeout)) Alice.nodeParams.copy(peerWakeUpConfig = WakeUpConfig(100 millis)) else Alice.nodeParams + val nodeParams = Alice.nodeParams + .modify(_.peerWakeUpConfig.enabled).setToIf(test.tags.contains(wakeUpEnabled))(true) + .modify(_.peerWakeUpConfig.timeout).setToIf(test.tags.contains(wakeUpTimeout))(100 millis) val relay = testKit.spawn(MessageRelay(nodeParams, switchboard.ref, register.ref, router.ref)) try { withFixture(test.toNoArgTest(FixtureParam(relay, switchboard, register, router, peerConnection, peer, peerReadyManager, probe))) @@ -96,7 +99,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app assert(peer.expectMessageType[Peer.RelayOnionMessage].msg == message) } - test("relay after waking up next node") { f => + test("relay after waking up next node", Tag(wakeUpEnabled)) { f => import f._ val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) @@ -126,7 +129,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app probe.expectMessage(ConnectionFailure(messageId, PeerConnection.ConnectionResult.NoAddressFound)) } - test("can't wake up next node", Tag(wakeUpTimeout)) { f => + test("can't wake up next node", Tag(wakeUpEnabled), Tag(wakeUpTimeout)) { f => import f._ val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggererSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggererSpec.scala index eb9b7e3ea2..eb19cfd3ad 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggererSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/AsyncPaymentTriggererSpec.scala @@ -1,10 +1,11 @@ package fr.acinq.eclair.payment.relay import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe} -import akka.actor.typed.ActorRef import akka.actor.typed.eventstream.EventStream import akka.actor.typed.receptionist.Receptionist +import akka.actor.typed.scaladsl.Behaviors import akka.actor.typed.scaladsl.adapter.TypedActorRefOps +import akka.actor.typed.{ActorRef, Behavior} import com.typesafe.config.ConfigFactory import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey @@ -21,21 +22,31 @@ import scala.concurrent.duration.DurationInt class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike { - case class FixtureParam(remoteNodeId: PublicKey, peerReadyManager: TestProbe[PeerReadyManager.Register], switchboard: TestProbe[Switchboard.GetPeerInfo], peer: TestProbe[Peer.GetPeerChannels], probe: TestProbe[Result], triggerer: ActorRef[Command]) + case class FixtureParam(remoteNodeId: PublicKey, switchboard: TestProbe[Switchboard.GetPeerInfo], peer: TestProbe[Peer.GetPeerChannels], probe: TestProbe[Result], triggerer: ActorRef[Command]) + + object DummyPeerReadyManager { + def apply(): Behavior[PeerReadyManager.Command] = { + Behaviors.receiveMessagePartial { + case PeerReadyManager.Register(replyTo, remoteNodeId) => + replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) + Behaviors.same + } + } + } override def withFixture(test: OneArgTest): Outcome = { val remoteNodeId = TestConstants.Alice.nodeParams.nodeId - val peerReadyManager = TestProbe[PeerReadyManager.Register]("peer-ready-manager") - system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) + val peerReadyManager = testKit.spawn(DummyPeerReadyManager()) + system.receptionist ! Receptionist.Register(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager) val switchboard = TestProbe[Switchboard.GetPeerInfo]("switchboard") system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) val peer = TestProbe[Peer.GetPeerChannels]("peer") val probe = TestProbe[Result]() val triggerer = testKit.spawn(AsyncPaymentTriggerer()) try { - withFixture(test.toNoArgTest(FixtureParam(remoteNodeId, peerReadyManager, switchboard, peer, probe, triggerer))) + withFixture(test.toNoArgTest(FixtureParam(remoteNodeId, switchboard, peer, probe, triggerer))) } finally { - system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager.ref) + system.receptionist ! Receptionist.Deregister(PeerReadyManager.PeerReadyManagerServiceKey, peerReadyManager) system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) } } @@ -44,7 +55,6 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. import f._ triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) - peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) assert(switchboard.expectMessageType[GetPeerInfo].remoteNodeId == remoteNodeId) // We haven't reached the timeout yet. @@ -64,7 +74,6 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. import f._ triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) - peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) assert(switchboard.expectMessageType[GetPeerInfo].remoteNodeId == remoteNodeId) // cancel of an unwatched payment does nothing @@ -80,7 +89,6 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. // create two identical watches triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) - peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) assert(switchboard.expectMessageType[GetPeerInfo].remoteNodeId == remoteNodeId) triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) @@ -92,7 +100,6 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. // create two different watches val probe2 = TestProbe[Result]() triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) - peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 1) assert(switchboard.expectMessageType[GetPeerInfo].remoteNodeId == remoteNodeId) triggerer ! Watch(probe2.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) @@ -108,7 +115,6 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. // create watches for two payments with the same payment hash val probe2 = TestProbe[Result]() triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) - peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) assert(switchboard.expectMessageType[GetPeerInfo].remoteNodeId == remoteNodeId) triggerer ! Watch(probe2.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) @@ -122,7 +128,6 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. import f._ triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) - peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo] request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic)) @@ -146,7 +151,6 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. import f._ triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) - peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo] request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic)) @@ -172,7 +176,6 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. // watch remote node triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) - peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo] request1.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic)) @@ -180,7 +183,6 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. val remoteNodeId2 = TestConstants.Bob.nodeParams.nodeId val probe2 = TestProbe[Result]() triggerer ! Watch(probe2.ref, remoteNodeId2, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(101)) - peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId2, otherAttempts = 0) val request2 = switchboard.expectMessageType[Switchboard.GetPeerInfo] request2.replyTo ! Peer.PeerInfo(peer.ref.toClassic, remoteNodeId2, Peer.DISCONNECTED, None, Set(TestProbe().ref.toClassic)) @@ -204,8 +206,8 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory. test("triggerer treats an unexpected stop of the notifier as a cancel") { f => import f._ + triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100)) - peerReadyManager.expectMessageType[PeerReadyManager.Register].replyTo ! PeerReadyManager.Registered(remoteNodeId, otherAttempts = 0) assert(switchboard.expectMessageType[GetPeerInfo].remoteNodeId == remoteNodeId) triggerer ! NotifierStopped(remoteNodeId) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala index 5cab0aa877..433f49c5d9 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala @@ -30,7 +30,6 @@ import fr.acinq.eclair.TestConstants.emptyOnionPacket import fr.acinq.eclair.blockchain.fee.FeeratePerKw import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx -import fr.acinq.eclair.io.PeerReadyNotifier.WakeUpConfig import fr.acinq.eclair.io.{Peer, PeerReadyManager, Switchboard} import fr.acinq.eclair.payment.IncomingPaymentPacket.ChannelRelayPacket import fr.acinq.eclair.payment.relay.ChannelRelayer._ @@ -52,13 +51,16 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a import ChannelRelayerSpec._ + val wakeUpEnabled = "wake_up_enabled" val wakeUpTimeout = "wake_up_timeout" case class FixtureParam(nodeParams: NodeParams, channelRelayer: typed.ActorRef[ChannelRelayer.Command], register: TestProbe[Any]) override def withFixture(test: OneArgTest): Outcome = { // we are node B in the route A -> B -> C -> .... - val nodeParams = if (test.tags.contains(wakeUpTimeout)) TestConstants.Bob.nodeParams.copy(peerWakeUpConfig = WakeUpConfig(100 millis)) else TestConstants.Bob.nodeParams + val nodeParams = TestConstants.Bob.nodeParams + .modify(_.peerWakeUpConfig.enabled).setToIf(test.tags.contains(wakeUpEnabled))(true) + .modify(_.peerWakeUpConfig.timeout).setToIf(test.tags.contains(wakeUpTimeout))(100 millis) val register = TestProbe[Any]("register") val channelRelayer = testKit.spawn(ChannelRelayer.apply(nodeParams, register.ref.toClassic)) try { @@ -173,7 +175,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry, 7) } - test("relay blinded payment (wake up wallet node)") { f => + test("relay blinded payment (wake up wallet node)", Tag(wakeUpEnabled)) { f => import f._ val peerReadyManager = TestProbe[PeerReadyManager.Register]() @@ -328,7 +330,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a } } - test("fail to relay blinded payment (cannot wake up remote node)", Tag(wakeUpTimeout)) { f => + test("fail to relay blinded payment (cannot wake up remote node)", Tag(wakeUpEnabled), Tag(wakeUpTimeout)) { f => import f._ val peerReadyManager = TestProbe[PeerReadyManager.Register]() diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala index 3d734a2fb3..21c1891e44 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala @@ -67,6 +67,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl import NodeRelayerSpec._ + val wakeUpEnabled = "wake_up_enabled" val wakeUpTimeout = "wake_up_timeout" case class FixtureParam(nodeParams: NodeParams, router: TestProbe[Any], register: TestProbe[Any], mockPayFSM: TestProbe[Any], eventListener: TestProbe[PaymentEvent]) { @@ -97,6 +98,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val nodeParams = TestConstants.Bob.nodeParams .modify(_.multiPartPaymentExpiry).setTo(5 seconds) .modify(_.relayParams.asyncPaymentsParams.holdTimeoutBlocks).setToIf(test.tags.contains("long_hold_timeout"))(200000) // timeout after payment expires + .modify(_.peerWakeUpConfig.enabled).setToIf(test.tags.contains(wakeUpEnabled))(true) .modify(_.peerWakeUpConfig.timeout).setToIf(test.tags.contains(wakeUpTimeout))(100 millis) val router = TestProbe[Any]("router") val register = TestProbe[Any]("register") @@ -787,7 +789,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl register.expectNoMessage(100 millis) } - test("relay to blinded path with wake-up") { f => + test("relay to blinded path with wake-up", Tag(wakeUpEnabled)) { f => import f._ val peerReadyManager = TestProbe[PeerReadyManager.Register]() @@ -833,7 +835,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl register.expectNoMessage(100 millis) } - test("fail to relay to blinded path when wake-up fails", Tag(wakeUpTimeout)) { f => + test("fail to relay to blinded path when wake-up fails", Tag(wakeUpEnabled), Tag(wakeUpTimeout)) { f => import f._ val peerReadyManager = TestProbe[PeerReadyManager.Register]()