From 8da0fc6dac9a912d3ed6593ea7118c4e22257790 Mon Sep 17 00:00:00 2001 From: t-bast Date: Tue, 11 Jun 2024 18:02:49 +0200 Subject: [PATCH] Add wake-up step to channel and message relay We allow relaying data to contain a wallet `node_id` instead of an scid. When that's the case, we start by waking up that wallet node before we try relaying onion messages or payments. --- .../scala/fr/acinq/eclair/NodeParams.scala | 6 +- .../fr/acinq/eclair/io/MessageRelay.scala | 102 +++++----- .../acinq/eclair/io/PeerReadyNotifier.scala | 2 +- .../fr/acinq/eclair/payment/Monitoring.scala | 1 + .../acinq/eclair/payment/PaymentPacket.scala | 2 +- .../eclair/payment/relay/ChannelRelay.scala | 92 ++++++--- .../eclair/payment/relay/ChannelRelayer.scala | 11 +- .../eclair/payment/relay/NodeRelay.scala | 12 +- .../payment/send/BlindedPathsResolver.scala | 62 +++--- .../eclair/router/BlindedRouteCreation.scala | 19 +- .../eclair/wire/protocol/PaymentOnion.scala | 13 +- .../eclair/wire/protocol/RouteBlinding.scala | 10 +- .../scala/fr/acinq/eclair/TestConstants.scala | 6 +- .../fr/acinq/eclair/crypto/SphinxSpec.scala | 6 +- .../fr/acinq/eclair/io/MessageRelaySpec.scala | 34 +++- .../eclair/payment/PaymentPacketSpec.scala | 18 +- .../payment/relay/ChannelRelayerSpec.scala | 68 ++++++- .../payment/relay/NodeRelayerSpec.scala | 177 +++++++++++++----- .../send/BlindedPathsResolverSpec.scala | 34 +++- .../wire/protocol/PaymentOnionSpec.scala | 20 +- 20 files changed, 495 insertions(+), 200 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala index 86cc13e091..13c5715df4 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala @@ -87,7 +87,8 @@ case class NodeParams(nodeKeyManager: NodeKeyManager, blockchainWatchdogSources: Seq[String], onionMessageConfig: OnionMessageConfig, purgeInvoicesInterval: Option[FiniteDuration], - revokedHtlcInfoCleanerConfig: RevokedHtlcInfoCleaner.Config) { + revokedHtlcInfoCleanerConfig: RevokedHtlcInfoCleaner.Config, + wakeUpTimeout: FiniteDuration) { val privateKey: Crypto.PrivateKey = nodeKeyManager.nodeKey.privateKey val nodeId: PublicKey = nodeKeyManager.nodeId @@ -610,7 +611,8 @@ object NodeParams extends Logging { revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config( batchSize = config.getInt("db.revoked-htlc-info-cleaner.batch-size"), interval = FiniteDuration(config.getDuration("db.revoked-htlc-info-cleaner.interval").getSeconds, TimeUnit.SECONDS) - ) + ), + wakeUpTimeout = 30 seconds, ) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala index 11b9fc9079..49568254b1 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala @@ -16,10 +16,10 @@ package fr.acinq.eclair.io -import akka.actor.typed.Behavior import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import akka.actor.typed.scaladsl.{ActorContext, Behaviors} +import akka.actor.typed.{Behavior, SupervisorStrategy} import akka.actor.{ActorRef, typed} import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey @@ -32,6 +32,8 @@ import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.OnionMessage import fr.acinq.eclair.{EncodedNodeId, NodeParams, ShortChannelId} +import scala.concurrent.duration.DurationInt + object MessageRelay { // @formatter:off sealed trait Command @@ -42,29 +44,18 @@ object MessageRelay { policy: RelayPolicy, replyTo_opt: Option[typed.ActorRef[Status]]) extends Command case class WrappedPeerInfo(peerInfo: PeerInfoResponse) extends Command - case class WrappedConnectionResult(result: PeerConnection.ConnectionResult) extends Command - case class WrappedOptionalNodeId(nodeId_opt: Option[PublicKey]) extends Command + private case class WrappedConnectionResult(result: PeerConnection.ConnectionResult) extends Command + private case class WrappedOptionalNodeId(nodeId_opt: Option[PublicKey]) extends Command + private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command - sealed trait Status { - val messageId: ByteVector32 - } + sealed trait Status { val messageId: ByteVector32 } case class Sent(messageId: ByteVector32) extends Status sealed trait Failure extends Status - case class AgainstPolicy(messageId: ByteVector32, policy: RelayPolicy) extends Failure { - override def toString: String = s"Relay prevented by policy $policy" - } - case class ConnectionFailure(messageId: ByteVector32, failure: PeerConnection.ConnectionResult.Failure) extends Failure { - override def toString: String = s"Can't connect to peer: ${failure.toString}" - } - case class Disconnected(messageId: ByteVector32) extends Failure { - override def toString: String = "Peer is not connected" - } - case class UnknownChannel(messageId: ByteVector32, channelId: ShortChannelId) extends Failure { - override def toString: String = s"Unknown channel: $channelId" - } - case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure { - override def toString: String = s"Message dropped: $reason" - } + case class AgainstPolicy(messageId: ByteVector32, policy: RelayPolicy) extends Failure { override def toString: String = s"Relay prevented by policy $policy" } + case class ConnectionFailure(messageId: ByteVector32, failure: PeerConnection.ConnectionResult.Failure) extends Failure { override def toString: String = s"Can't connect to peer: ${failure.toString}" } + case class Disconnected(messageId: ByteVector32) extends Failure { override def toString: String = "Peer is not connected" } + case class UnknownChannel(messageId: ByteVector32, channelId: ShortChannelId) extends Failure { override def toString: String = s"Unknown channel: $channelId" } + case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure { override def toString: String = s"Message dropped: $reason" } sealed trait RelayPolicy case object RelayChannelsOnly extends RelayPolicy @@ -100,7 +91,7 @@ private class MessageRelay(nodeParams: NodeParams, def queryNextNodeId(msg: OnionMessage, nextNode: Either[ShortChannelId, EncodedNodeId]): Behavior[Command] = { nextNode match { case Left(outgoingChannelId) if outgoingChannelId == ShortChannelId.toSelf => - withNextNodeId(msg, nodeParams.nodeId) + withNextNodeId(msg, EncodedNodeId.WithPublicKey.Plain(nodeParams.nodeId)) case Left(outgoingChannelId) => register ! Register.GetNextNodeId(context.messageAdapter(WrappedOptionalNodeId), outgoingChannelId) waitForNextNodeId(msg, outgoingChannelId) @@ -108,7 +99,7 @@ private class MessageRelay(nodeParams: NodeParams, router ! Router.GetNodeId(context.messageAdapter(WrappedOptionalNodeId), scid, isNode1) waitForNextNodeId(msg, scid) case Right(encodedNodeId: EncodedNodeId.WithPublicKey) => - withNextNodeId(msg, encodedNodeId.publicKey) + withNextNodeId(msg, encodedNodeId) } } @@ -118,33 +109,39 @@ private class MessageRelay(nodeParams: NodeParams, replyTo_opt.foreach(_ ! UnknownChannel(messageId, channelId)) Behaviors.stopped case WrappedOptionalNodeId(Some(nextNodeId)) => - withNextNodeId(msg, nextNodeId) + withNextNodeId(msg, EncodedNodeId.WithPublicKey.Plain(nextNodeId)) } } - private def withNextNodeId(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = { - if (nextNodeId == nodeParams.nodeId) { - OnionMessages.process(nodeParams.privateKey, msg) match { - case OnionMessages.DropMessage(reason) => - replyTo_opt.foreach(_ ! DroppedMessage(messageId, reason)) - Behaviors.stopped - case OnionMessages.SendMessage(nextNode, nextMessage) => - // We need to repeat the process until we identify the (real) next node, or find out that we're the recipient. - queryNextNodeId(nextMessage, nextNode) - case received: OnionMessages.ReceiveMessage => - context.system.eventStream ! EventStream.Publish(received) - replyTo_opt.foreach(_ ! Sent(messageId)) - Behaviors.stopped - } - } else { - policy match { - case RelayChannelsOnly => - switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId) - waitForPreviousPeerForPolicyCheck(msg, nextNodeId) - case RelayAll => - switchboard ! Peer.Connect(nextNodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false) - waitForConnection(msg) - } + private def withNextNodeId(msg: OnionMessage, nextNodeId: EncodedNodeId.WithPublicKey): Behavior[Command] = { + nextNodeId match { + case EncodedNodeId.WithPublicKey.Plain(nodeId) if nodeId == nodeParams.nodeId => + OnionMessages.process(nodeParams.privateKey, msg) match { + case OnionMessages.DropMessage(reason) => + replyTo_opt.foreach(_ ! DroppedMessage(messageId, reason)) + Behaviors.stopped + case OnionMessages.SendMessage(nextNode, nextMessage) => + // We need to repeat the process until we identify the (real) next node, or find out that we're the recipient. + queryNextNodeId(nextMessage, nextNode) + case received: OnionMessages.ReceiveMessage => + context.system.eventStream ! EventStream.Publish(received) + replyTo_opt.foreach(_ ! Sent(messageId)) + Behaviors.stopped + } + case EncodedNodeId.WithPublicKey.Plain(nodeId) => + policy match { + case RelayChannelsOnly => + switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId) + waitForPreviousPeerForPolicyCheck(msg, nodeId) + case RelayAll => + switchboard ! Peer.Connect(nodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false) + waitForConnection(msg) + } + case EncodedNodeId.WithPublicKey.Wallet(nodeId) => + context.log.info("trying to wake up next peer to relay onion message (nodeId={})", nodeId) + val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(nodeId, timeout_opt = Some(Left(nodeParams.wakeUpTimeout)))).onFailure(SupervisorStrategy.stop)) + notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult)) + waitForWalletNodeUp(msg) } } @@ -180,4 +177,15 @@ private class MessageRelay(nodeParams: NodeParams, Behaviors.stopped } } + + private def waitForWalletNodeUp(msg: OnionMessage): Behavior[Command] = { + Behaviors.receiveMessagePartial { + case WrappedPeerReadyResult(r: PeerReadyNotifier.PeerReady) => + r.peer ! Peer.RelayOnionMessage(messageId, msg, replyTo_opt) + Behaviors.stopped + case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) => + replyTo_opt.foreach(_ ! Disconnected(messageId)) + Behaviors.stopped + } + } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala index 4fad93e55e..fbcd09c645 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerReadyNotifier.scala @@ -83,7 +83,7 @@ object PeerReadyNotifier { case WrappedListing(Switchboard.SwitchboardServiceKey.Listing(listings)) => listings.headOption match { case Some(switchboard) => - waitForPeerConnected(replyTo, remoteNodeId, switchboard, context, timers) + waitForPeerConnected(replyTo, remoteNodeId, switchboard, context, timers) case None => context.log.error("no switchboard found") replyTo ! PeerUnavailable(remoteNodeId) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala index d59491c903..20ae3fe823 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala @@ -115,6 +115,7 @@ object Monitoring { val Failure = "failure" object FailureType { + val WakeUp = "WakeUp" val Remote = "Remote" val Malformed = "MalformedHtlc" diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala index 2f72fe2846..723cc89073 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala @@ -128,7 +128,7 @@ object IncomingPaymentPacket { decryptEncryptedRecipientData(add, privateKey, payload, encrypted.data).flatMap { case DecodedEncryptedRecipientData(blindedPayload, nextBlinding) => validateBlindedChannelRelayPayload(add, payload, blindedPayload, nextBlinding, nextPacket).flatMap { - case ChannelRelayPacket(_, payload, nextPacket) if payload.outgoingChannelId == ShortChannelId.toSelf => + case ChannelRelayPacket(_, payload, nextPacket) if payload.outgoing == Right(ShortChannelId.toSelf) => decrypt(add.copy(onionRoutingPacket = nextPacket, tlvStream = add.tlvStream.copy(records = Set(UpdateAddHtlcTlv.BlindingPoint(nextBlinding)))), privateKey, features) case relayPacket => Right(relayPacket) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala index a68ff75dd7..0457fdc01e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala @@ -17,21 +17,23 @@ package fr.acinq.eclair.payment.relay import akka.actor.ActorRef -import akka.actor.typed.Behavior import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import akka.actor.typed.scaladsl.{ActorContext, Behaviors} +import akka.actor.typed.{Behavior, SupervisorStrategy} import fr.acinq.bitcoin.scalacompat.ByteVector32 +import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.db.PendingCommandsDb +import fr.acinq.eclair.io.PeerReadyNotifier import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment.relay.Relayer.{OutgoingChannel, OutgoingChannelParams} import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket} import fr.acinq.eclair.wire.protocol.FailureMessageCodecs.createBadOnionFailure import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{Logs, NodeParams, TimestampMilli, TimestampSecond, channel, nodeFee} +import fr.acinq.eclair.{Logs, NodeParams, ShortChannelId, TimestampMilli, TimestampSecond, channel, nodeFee} import java.util.UUID import java.util.concurrent.TimeUnit @@ -43,6 +45,7 @@ object ChannelRelay { // @formatter:off sealed trait Command private case object DoRelay extends Command + private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command private case class WrappedForwardFailure(failure: Register.ForwardFailure[CMD_ADD_HTLC]) extends Command private case class WrappedAddResponse(res: CommandResponse[CMD_ADD_HTLC]) extends Command // @formatter:on @@ -60,8 +63,7 @@ object ChannelRelay { parentPaymentId_opt = Some(relayId), // for a channel relay, parent payment id = relay id paymentHash_opt = Some(r.add.paymentHash), nodeAlias_opt = Some(nodeParams.alias))) { - context.self ! DoRelay - new ChannelRelay(nodeParams, register, channels, r, context).relay(Seq.empty) + new ChannelRelay(nodeParams, register, channels, r, context).start() } } @@ -69,7 +71,7 @@ object ChannelRelay { * This helper method translates relaying errors (returned by the downstream outgoing channel) to BOLT 4 standard * errors that we should return upstream. */ - def translateLocalError(error: Throwable, channelUpdate_opt: Option[ChannelUpdate]): FailureMessage = { + private def translateLocalError(error: Throwable, channelUpdate_opt: Option[ChannelUpdate]): FailureMessage = { (error, channelUpdate_opt) match { case (_: ExpiryTooSmall, Some(channelUpdate)) => ExpiryTooSoon(Some(channelUpdate)) case (_: ExpiryTooBig, _) => ExpiryTooFar() @@ -112,16 +114,55 @@ class ChannelRelay private(nodeParams: NodeParams, private val forwardFailureAdapter = context.messageAdapter[Register.ForwardFailure[CMD_ADD_HTLC]](WrappedForwardFailure) private val addResponseAdapter = context.messageAdapter[CommandResponse[CMD_ADD_HTLC]](WrappedAddResponse) + private val nextNodeId_opt = r.payload.outgoing match { + case Left(walletNodeId) => Some(walletNodeId) + // All the channels point to the same next node, we take the first one. + case Right(_) => channels.headOption.map(_._2.nextNodeId) + } + + /** Channel id explicitly requested in the onion payload. */ + private val requestedChannelId_opt = r.payload.outgoing match { + case Left(_) => None + case Right(outgoingChannelId) => channels.collectFirst { + case (channelId, channel) if channel.shortIds.localAlias == outgoingChannelId => channelId + case (channelId, channel) if channel.shortIds.real.toOption.contains(outgoingChannelId) => channelId + } + } + private case class PreviouslyTried(channelId: ByteVector32, failure: RES_ADD_FAILED[ChannelException]) - def relay(previousFailures: Seq[PreviouslyTried]): Behavior[Command] = { + def start(): Behavior[Command] = { + r.payload.outgoing match { + case Left(walletNodeId) => wakeUp(walletNodeId) + case Right(requestedShortChannelId) => + context.self ! DoRelay + relay(Some(requestedShortChannelId), Seq.empty) + } + } + + private def wakeUp(walletNodeId: PublicKey): Behavior[Command] = { + context.log.info("trying to wake up channel peer (nodeId={})", walletNodeId) + val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(nodeParams.wakeUpTimeout)))).onFailure(SupervisorStrategy.stop)) + notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult)) + Behaviors.receiveMessagePartial { + case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) => + Metrics.recordPaymentRelayFailed(Tags.FailureType.WakeUp, Tags.RelayType.Channel) + context.log.info("rejecting htlc: failed to wake-up remote peer") + safeSendAndStop(r.add.channelId, CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true)) + case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerReady) => + context.self ! DoRelay + relay(None, Seq.empty) + } + } + + def relay(requestedShortChannelId_opt: Option[ShortChannelId], previousFailures: Seq[PreviouslyTried]): Behavior[Command] = { Behaviors.receiveMessagePartial { case DoRelay => if (previousFailures.isEmpty) { - context.log.info("relaying htlc #{} from channelId={} to requestedShortChannelId={} nextNode={}", r.add.id, r.add.channelId, r.payload.outgoingChannelId, nextNodeId_opt.getOrElse("")) + context.log.info("relaying htlc #{} from channelId={} to requestedShortChannelId={} nextNode={}", r.add.id, r.add.channelId, requestedShortChannelId_opt, nextNodeId_opt.getOrElse("")) } context.log.debug("attempting relay previousAttempts={}", previousFailures.size) - handleRelay(previousFailures) match { + handleRelay(requestedShortChannelId_opt, previousFailures) match { case RelayFailure(cmdFail) => Metrics.recordPaymentRelayFailed(Tags.FailureType(cmdFail), Tags.RelayType.Channel) context.log.info("rejecting htlc reason={}", cmdFail.reason) @@ -129,12 +170,12 @@ class ChannelRelay private(nodeParams: NodeParams, case RelaySuccess(selectedChannelId, cmdAdd) => context.log.info("forwarding htlc #{} from channelId={} to channelId={}", r.add.id, r.add.channelId, selectedChannelId) register ! Register.Forward(forwardFailureAdapter, selectedChannelId, cmdAdd) - waitForAddResponse(selectedChannelId, previousFailures) + waitForAddResponse(selectedChannelId, requestedShortChannelId_opt, previousFailures) } } } - def waitForAddResponse(selectedChannelId: ByteVector32, previousFailures: Seq[PreviouslyTried]): Behavior[Command] = + private def waitForAddResponse(selectedChannelId: ByteVector32, requestedShortChannelId_opt: Option[ShortChannelId], previousFailures: Seq[PreviouslyTried]): Behavior[Command] = Behaviors.receiveMessagePartial { case WrappedForwardFailure(Register.ForwardFailure(Register.Forward(_, channelId, CMD_ADD_HTLC(_, _, _, _, _, _, o: Origin.ChannelRelayedHot, _)))) => context.log.warn(s"couldn't resolve downstream channel $channelId, failing htlc #${o.add.id}") @@ -145,14 +186,14 @@ class ChannelRelay private(nodeParams: NodeParams, case WrappedAddResponse(addFailed@RES_ADD_FAILED(CMD_ADD_HTLC(_, _, _, _, _, _, _: Origin.ChannelRelayedHot, _), _, _)) => context.log.info("attempt failed with reason={}", addFailed.t.getClass.getSimpleName) context.self ! DoRelay - relay(previousFailures :+ PreviouslyTried(selectedChannelId, addFailed)) + relay(requestedShortChannelId_opt, previousFailures :+ PreviouslyTried(selectedChannelId, addFailed)) case WrappedAddResponse(_: RES_SUCCESS[_]) => context.log.debug("sent htlc to the downstream channel") waitForAddSettled() } - def waitForAddSettled(): Behavior[Command] = + private def waitForAddSettled(): Behavior[Command] = Behaviors.receiveMessagePartial { case WrappedAddResponse(RES_ADD_SETTLED(o: Origin.ChannelRelayedHot, htlc, fulfill: HtlcResult.Fulfill)) => context.log.debug("relaying fulfill to upstream") @@ -169,7 +210,7 @@ class ChannelRelay private(nodeParams: NodeParams, safeSendAndStop(o.originChannelId, cmd) } - def safeSendAndStop(channelId: ByteVector32, cmd: channel.HtlcSettlementCommand): Behavior[Command] = { + private def safeSendAndStop(channelId: ByteVector32, cmd: channel.HtlcSettlementCommand): Behavior[Command] = { val toSend = cmd match { case _: CMD_FULFILL_HTLC => cmd case _: CMD_FAIL_HTLC | _: CMD_FAIL_MALFORMED_HTLC => r.payload match { @@ -200,9 +241,9 @@ class ChannelRelay private(nodeParams: NodeParams, * - a CMD_FAIL_HTLC to be sent back upstream * - a CMD_ADD_HTLC to propagate downstream */ - def handleRelay(previousFailures: Seq[PreviouslyTried]): RelayResult = { + private def handleRelay(requestedShortChannelId_opt: Option[ShortChannelId], previousFailures: Seq[PreviouslyTried]): RelayResult = { val alreadyTried = previousFailures.map(_.channelId) - selectPreferredChannel(alreadyTried) match { + selectPreferredChannel(requestedShortChannelId_opt, alreadyTried) match { case None if previousFailures.nonEmpty => // no more channels to try val error = previousFailures @@ -217,24 +258,14 @@ class ChannelRelay private(nodeParams: NodeParams, } } - /** all the channels point to the same next node, we take the first one */ - private val nextNodeId_opt = channels.headOption.map(_._2.nextNodeId) - - /** channel id explicitly requested in the onion payload */ - private val requestedChannelId_opt = channels.collectFirst { - case (channelId, channel) if channel.shortIds.localAlias == r.payload.outgoingChannelId => channelId - case (channelId, channel) if channel.shortIds.real.toOption.contains(r.payload.outgoingChannelId) => channelId - } - /** * Select a channel to the same node to relay the payment to, that has the lowest capacity and balance and is * compatible in terms of fees, expiry_delta, etc. * * If no suitable channel is found we default to the originally requested channel. */ - def selectPreferredChannel(alreadyTried: Seq[ByteVector32]): Option[OutgoingChannel] = { - val requestedShortChannelId = r.payload.outgoingChannelId - context.log.debug("selecting next channel with requestedShortChannelId={}", requestedShortChannelId) + private def selectPreferredChannel(requestedShortChannelId_opt: Option[ShortChannelId], alreadyTried: Seq[ByteVector32]): Option[OutgoingChannel] = { + context.log.debug("selecting next channel with requestedShortChannelId={}", requestedShortChannelId_opt) // we filter out channels that we have already tried val candidateChannels: Map[ByteVector32, OutgoingChannel] = channels -- alreadyTried // and we filter again to keep the ones that are compatible with this payment (mainly fees, expiry delta) @@ -242,7 +273,8 @@ class ChannelRelay private(nodeParams: NodeParams, .values .map { channel => val relayResult = relayOrFail(Some(channel)) - context.log.debug(s"candidate channel: channelId=${channel.channelId} availableForSend={} capacity={} channelUpdate={} result={}", + context.log.debug("candidate channel: channelId={} availableForSend={} capacity={} channelUpdate={} result={}", + channel.channelId, channel.commitments.availableBalanceForSend, channel.commitments.latest.capacity, channel.channelUpdate, @@ -268,7 +300,7 @@ class ChannelRelay private(nodeParams: NodeParams, context.log.debug("requested short channel id is our preferred channel") Some(channel) } else { - context.log.debug("replacing requestedShortChannelId={} by preferredShortChannelId={} with availableBalanceMsat={}", requestedShortChannelId, channel.channelUpdate.shortChannelId, channel.commitments.availableBalanceForSend) + context.log.debug("replacing requestedShortChannelId={} by preferredShortChannelId={} with availableBalanceMsat={}", requestedShortChannelId_opt, channel.channelUpdate.shortChannelId, channel.commitments.availableBalanceForSend) Some(channel) } case None => @@ -289,7 +321,7 @@ class ChannelRelay private(nodeParams: NodeParams, * channel, because some parameters don't match with our settings for that channel. In that case we directly fail the * htlc. */ - def relayOrFail(outgoingChannel_opt: Option[OutgoingChannelParams]): RelayResult = { + private def relayOrFail(outgoingChannel_opt: Option[OutgoingChannelParams]): RelayResult = { outgoingChannel_opt match { case None => RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true)) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala index 59eb7b58b0..355ea83c22 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala @@ -24,7 +24,7 @@ import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.channel._ import fr.acinq.eclair.payment.IncomingPaymentPacket -import fr.acinq.eclair.{SubscriptionsComplete, Logs, NodeParams, ShortChannelId} +import fr.acinq.eclair.{Logs, NodeParams, ShortChannelId, SubscriptionsComplete} import java.util.UUID import scala.collection.mutable @@ -71,9 +71,12 @@ object ChannelRelayer { Behaviors.receiveMessage { case Relay(channelRelayPacket) => val relayId = UUID.randomUUID() - val nextNodeId_opt: Option[PublicKey] = scid2channels.get(channelRelayPacket.payload.outgoingChannelId) match { - case Some(channelId) => channels.get(channelId).map(_.nextNodeId) - case None => None + val nextNodeId_opt: Option[PublicKey] = channelRelayPacket.payload.outgoing match { + case Left(walletNodeId) => Some(walletNodeId) + case Right(outgoingChannelId) => scid2channels.get(outgoingChannelId) match { + case Some(channelId) => channels.get(channelId).map(_.nextNodeId) + case None => None + } } val nextChannels: Map[ByteVector32, Relayer.OutgoingChannel] = nextNodeId_opt match { case Some(nextNodeId) => node2channels.get(nextNodeId).flatMap(channels.get).map(c => c.channelId -> c).toMap diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index 9051cb37ef..c2679b758e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -42,7 +42,7 @@ import fr.acinq.eclair.router.Router.RouteParams import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound} import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, Features, Logs, MilliSatoshi, NodeParams, TimestampMilli, UInt64, nodeFee, randomBytes32} +import fr.acinq.eclair.{CltvExpiry, EncodedNodeId, Features, Logs, MilliSatoshi, NodeParams, TimestampMilli, UInt64, nodeFee, randomBytes32} import java.util.UUID import java.util.concurrent.TimeUnit @@ -137,7 +137,13 @@ object NodeRelay { } private def shouldWakeUpNextNode(nodeParams: NodeParams, recipient: Recipient): Boolean = { - false + recipient match { + case r: BlindedRecipient => r.blindedHops.head.resolved.route match { + case BlindedPathsResolver.PartialBlindedRoute(_: EncodedNodeId.WithPublicKey.Wallet, _, _) => true + case _ => false + } + case _ => false + } } /** When we have identified that the next node is one of our peers, return their (real) nodeId. */ @@ -299,7 +305,7 @@ class NodeRelay private(nodeParams: NodeParams, private def waitForPeerReady(upstream: Upstream.Trampoline, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { val nextNodeId = nodeIdToWakeUp(recipient) context.log.info("trying to wake up next peer (nodeId={})", nextNodeId) - val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(nextNodeId, timeout_opt = Some(Left(30 seconds)))).onFailure(SupervisorStrategy.stop)) + val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(nextNodeId, timeout_opt = Some(Left(nodeParams.wakeUpTimeout)))).onFailure(SupervisorStrategy.stop)) notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult)) Behaviors.receiveMessagePartial { rejectExtraHtlcPartialFunction orElse { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala index a12799d8d2..12ee84f463 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala @@ -14,7 +14,7 @@ import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.RouteBlindingDecryptedData import fr.acinq.eclair.wire.protocol.{BlindedRouteData, OfferTypes, RouteBlindingEncryptedDataCodecs} -import fr.acinq.eclair.{EncodedNodeId, Logs, NodeParams} +import fr.acinq.eclair.{EncodedNodeId, Logs, MilliSatoshiLong, NodeParams, ShortChannelId} import scodec.bits.ByteVector import scala.annotation.tailrec @@ -45,8 +45,8 @@ object BlindedPathsResolver { override val firstNodeId: PublicKey = introductionNodeId } /** A partially unwrapped blinded route that started at our node: it only contains the part of the route after our node. */ - case class PartialBlindedRoute(nextNodeId: PublicKey, nextBlinding: PublicKey, blindedNodes: Seq[BlindedNode]) extends ResolvedBlindedRoute { - override val firstNodeId: PublicKey = nextNodeId + case class PartialBlindedRoute(nextNodeId: EncodedNodeId.WithPublicKey, nextBlinding: PublicKey, blindedNodes: Seq[BlindedNode]) extends ResolvedBlindedRoute { + override val firstNodeId: PublicKey = nextNodeId.publicKey } // @formatter:on @@ -111,8 +111,14 @@ private class BlindedPathsResolver(nodeParams: NodeParams, feeProportionalMillionths = nextFeeProportionalMillionths, cltvExpiryDelta = nextCltvExpiryDelta ) - register ! Register.GetNextNodeId(context.messageAdapter(WrappedNodeId), paymentRelayData.outgoingChannelId) - waitForNextNodeId(nextPaymentInfo, paymentRelayData, nextBlinding, paymentRoute.route.subsequentNodes, toResolve.tail, resolved) + paymentRelayData.outgoing match { + case Left(outgoingNodeId) => + // The next node seems to be a wallet node directly connected to us. + validateRelay(EncodedNodeId.WithPublicKey.Wallet(outgoingNodeId), nextPaymentInfo, paymentRelayData, nextBlinding, paymentRoute.route.subsequentNodes, toResolve.tail, resolved) + case Right(outgoingChannelId) => + register ! Register.GetNextNodeId(context.messageAdapter(WrappedNodeId), outgoingChannelId) + waitForNextNodeId(outgoingChannelId, nextPaymentInfo, paymentRelayData, nextBlinding, paymentRoute.route.subsequentNodes, toResolve.tail, resolved) + } } } case encodedNodeId: EncodedNodeId.WithPublicKey => @@ -129,7 +135,8 @@ private class BlindedPathsResolver(nodeParams: NodeParams, } /** Resolve the next node in the blinded path when we are the introduction node. */ - private def waitForNextNodeId(nextPaymentInfo: OfferTypes.PaymentInfo, + private def waitForNextNodeId(outgoingChannelId: ShortChannelId, + nextPaymentInfo: OfferTypes.PaymentInfo, paymentRelayData: BlindedRouteData.PaymentRelayData, nextBlinding: PublicKey, nextBlindedNodes: Seq[RouteBlinding.BlindedNode], @@ -137,28 +144,41 @@ private class BlindedPathsResolver(nodeParams: NodeParams, resolved: Seq[ResolvedPath]): Behavior[Command] = Behaviors.receiveMessagePartial { case WrappedNodeId(None) => - context.log.warn("ignoring blinded path starting at our node: could not resolve outgoingChannelId={}", paymentRelayData.outgoingChannelId) + context.log.warn("ignoring blinded path starting at our node: could not resolve outgoingChannelId={}", outgoingChannelId) resolveBlindedPaths(toResolve, resolved) case WrappedNodeId(Some(nodeId)) if nodeId == nodeParams.nodeId => // The next node in the route is also our node: this is fishy, there is not reason to include us in the route twice. context.log.warn("ignoring blinded path starting at our node relaying to ourselves") resolveBlindedPaths(toResolve, resolved) case WrappedNodeId(Some(nodeId)) => - // Note that we default to private fees if we don't have a channel yet with that node. - // The announceChannel parameter is ignored if we already have a channel. - val relayFees = getRelayFees(nodeParams, nodeId, announceChannel = false) - val shouldRelay = paymentRelayData.paymentRelay.feeBase >= relayFees.feeBase && - paymentRelayData.paymentRelay.feeProportionalMillionths >= relayFees.feeProportionalMillionths && - paymentRelayData.paymentRelay.cltvExpiryDelta >= nodeParams.channelConf.expiryDelta - if (shouldRelay) { - context.log.debug("unwrapped blinded path starting at our node: next_node={}", nodeId) - val path = ResolvedPath(PartialBlindedRoute(nodeId, nextBlinding, nextBlindedNodes), nextPaymentInfo) - resolveBlindedPaths(toResolve, resolved :+ path) - } else { - context.log.warn("ignoring blinded path starting at our node: allocated fees are too low (base={}, proportional={}, expiryDelta={})", paymentRelayData.paymentRelay.feeBase, paymentRelayData.paymentRelay.feeProportionalMillionths, paymentRelayData.paymentRelay.cltvExpiryDelta) - resolveBlindedPaths(toResolve, resolved) - } + validateRelay(EncodedNodeId.WithPublicKey.Plain(nodeId), nextPaymentInfo, paymentRelayData, nextBlinding, nextBlindedNodes, toResolve, resolved) + } + + private def validateRelay(nextNodeId: EncodedNodeId.WithPublicKey, + nextPaymentInfo: OfferTypes.PaymentInfo, + paymentRelayData: BlindedRouteData.PaymentRelayData, + nextBlinding: PublicKey, + nextBlindedNodes: Seq[RouteBlinding.BlindedNode], + toResolve: Seq[PaymentBlindedRoute], + resolved: Seq[ResolvedPath]): Behavior[Command] = { + // Note that we default to private fees if we don't have a channel yet with that node. + // The announceChannel parameter is ignored if we already have a channel. + val relayFees = getRelayFees(nodeParams, nextNodeId.publicKey, announceChannel = false) + val shouldRelay = paymentRelayData.paymentRelay.feeBase >= relayFees.feeBase && + paymentRelayData.paymentRelay.feeProportionalMillionths >= relayFees.feeProportionalMillionths && + paymentRelayData.paymentRelay.cltvExpiryDelta >= nodeParams.channelConf.expiryDelta && + nextPaymentInfo.feeBase >= 0.msat && + nextPaymentInfo.feeProportionalMillionths >= 0 && + nextPaymentInfo.cltvExpiryDelta.toInt >= 0 + if (shouldRelay) { + context.log.debug("unwrapped blinded path starting at our node: next_node={}", nextNodeId.publicKey) + val path = ResolvedPath(PartialBlindedRoute(nextNodeId, nextBlinding, nextBlindedNodes), nextPaymentInfo) + resolveBlindedPaths(toResolve, resolved :+ path) + } else { + context.log.warn("ignoring blinded path starting at our node: allocated fees are too low (base={}, proportional={}, expiryDelta={})", paymentRelayData.paymentRelay.feeBase, paymentRelayData.paymentRelay.feeProportionalMillionths, paymentRelayData.paymentRelay.cltvExpiryDelta) + resolveBlindedPaths(toResolve, resolved) } + } /** Resolve the introduction node's [[EncodedNodeId.ShortChannelIdDir]] to the corresponding [[EncodedNodeId.WithPublicKey]]. */ private def waitForNodeId(paymentRoute: PaymentBlindedRoute, toResolve: Seq[PaymentBlindedRoute], resolved: Seq[ResolvedPath]): Behavior[Command] = diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala index 81605c5cc1..37e31c7ff8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala @@ -21,7 +21,7 @@ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.router.Router.ChannelHop import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo import fr.acinq.eclair.wire.protocol.{RouteBlindingEncryptedDataCodecs, RouteBlindingEncryptedDataTlv, TlvStream} -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, EncodedNodeId, Features, MilliSatoshi, MilliSatoshiLong, randomKey} import scodec.bits.ByteVector object BlindedRouteCreation { @@ -77,7 +77,7 @@ object BlindedRouteCreation { Total: 24 to 36 bytes */ val targetLength = 36 - val paddedPayloads = payloads.map(tlvs =>{ + val paddedPayloads = payloads.map(tlvs => { val payloadLength = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(tlvs).require.bytes.length tlvs.copy(records = tlvs.records + RouteBlindingEncryptedDataTlv.Padding(ByteVector.fill(targetLength - payloadLength)(0))) }) @@ -95,4 +95,19 @@ object BlindedRouteCreation { Sphinx.RouteBlinding.create(randomKey(), Seq(nodeId), Seq(finalPayload)) } + /** Create a blinded route where the recipient is a wallet node. */ + def createBlindedRouteToWallet(hop: Router.ChannelHop, pathId: ByteVector, minAmount: MilliSatoshi, routeFinalExpiry: CltvExpiry): Sphinx.RouteBlinding.BlindedRouteDetails = { + val routeExpiry = routeFinalExpiry + hop.cltvExpiryDelta + val finalPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( + RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, minAmount), + RouteBlindingEncryptedDataTlv.PathId(pathId), + )).require.bytes + val intermediatePayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream[RouteBlindingEncryptedDataTlv]( + RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Wallet(hop.nextNodeId)), + RouteBlindingEncryptedDataTlv.PaymentRelay(hop.cltvExpiryDelta, hop.params.relayFees.feeProportionalMillionths, hop.params.relayFees.feeBase), + RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, minAmount), + )).require.bytes + Sphinx.RouteBlinding.create(randomKey(), Seq(hop.nodeId, hop.nextNodeId), Seq(intermediatePayload, finalPayload)) + } + } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala index 10bd99f443..4468ed7172 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala @@ -23,7 +23,7 @@ import fr.acinq.eclair.wire.protocol.BlindedRouteData.PaymentRelayData import fr.acinq.eclair.wire.protocol.CommonCodecs._ import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.TlvCodecs._ -import fr.acinq.eclair.{CltvExpiry, Features, MilliSatoshi, MilliSatoshiLong, ShortChannelId, UInt64} +import fr.acinq.eclair.{CltvExpiry, Features, MilliSatoshi, ShortChannelId, UInt64} import scodec.bits.{BitVector, ByteVector} /** @@ -227,7 +227,8 @@ object PaymentOnion { object IntermediatePayload { sealed trait ChannelRelay extends IntermediatePayload { // @formatter:off - def outgoingChannelId: ShortChannelId + /** The outgoing channel, or the nodeId of one of our peers. */ + def outgoing: Either[PublicKey, ShortChannelId] def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry // @formatter:on @@ -238,7 +239,7 @@ object PaymentOnion { // @formatter:off val amountOut = records.get[AmountToForward].get.amount val cltvOut = records.get[OutgoingCltv].get.cltv - override val outgoingChannelId = records.get[OutgoingChannelId].get.shortChannelId + override val outgoing = Right(records.get[OutgoingChannelId].get.shortChannelId) override def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi = amountOut override def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry = cltvOut // @formatter:on @@ -258,12 +259,12 @@ object PaymentOnion { } /** - * @param blindedRecords decrypted tlv stream from the encrypted_recipient_data tlv. - * @param nextBlinding blinding point that must be forwarded to the next hop. + * @param paymentRelayData decrypted relaying data from the encrypted_recipient_data tlv. + * @param nextBlinding blinding point that must be forwarded to the next hop. */ case class Blinded(records: TlvStream[OnionPaymentPayloadTlv], paymentRelayData: PaymentRelayData, nextBlinding: PublicKey) extends ChannelRelay { // @formatter:off - override val outgoingChannelId = paymentRelayData.outgoingChannelId + override val outgoing = paymentRelayData.outgoing override def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi = paymentRelayData.amountToForward(incomingAmount) override def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry = paymentRelayData.outgoingCltv(incomingCltv) // @formatter:on diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala index 096c1f8edf..be53f4aab0 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala @@ -98,7 +98,11 @@ object BlindedRouteData { } case class PaymentRelayData(records: TlvStream[RouteBlindingEncryptedDataTlv]) { - val outgoingChannelId: ShortChannelId = records.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId].get.shortChannelId + // This is usually a channel, unless the next node is a mobile wallet connected to our node. + val outgoing: Either[PublicKey, ShortChannelId] = records.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId] match { + case Some(r) => Right(r.shortChannelId) + case None => Left(records.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId].get.nodeId.asInstanceOf[EncodedNodeId.WithPublicKey.Wallet].publicKey) + } val paymentRelay: PaymentRelay = records.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get val paymentConstraints: PaymentConstraints = records.get[RouteBlindingEncryptedDataTlv.PaymentConstraints].get val allowedFeatures: Features[Feature] = records.get[RouteBlindingEncryptedDataTlv.AllowedFeatures].map(_.features).getOrElse(Features.empty) @@ -110,7 +114,9 @@ object BlindedRouteData { } def validatePaymentRelayData(records: TlvStream[RouteBlindingEncryptedDataTlv]): Either[InvalidTlvPayload, PaymentRelayData] = { - if (records.get[OutgoingChannelId].isEmpty) return Left(MissingRequiredTlv(UInt64(2))) + // Note that the BOLTs require using an OutgoingChannelId, but we optionally support a wallet node_id. + if (records.get[OutgoingChannelId].isEmpty && records.get[OutgoingNodeId].isEmpty) return Left(MissingRequiredTlv(UInt64(2))) + if (records.get[OutgoingNodeId].nonEmpty && !records.get[OutgoingNodeId].get.nodeId.isInstanceOf[EncodedNodeId.WithPublicKey.Wallet]) return Left(ForbiddenTlv(UInt64(4))) if (records.get[PaymentRelay].isEmpty) return Left(MissingRequiredTlv(UInt64(10))) if (records.get[PaymentConstraints].isEmpty) return Left(MissingRequiredTlv(UInt64(12))) if (records.get[PathId].nonEmpty) return Left(ForbiddenTlv(UInt64(6))) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala index 421bc31e30..a92aea65a4 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala @@ -227,7 +227,8 @@ object TestConstants { maxAttempts = 2, ), purgeInvoicesInterval = None, - revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis) + revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis), + wakeUpTimeout = 30 seconds, ) def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams( @@ -394,7 +395,8 @@ object TestConstants { maxAttempts = 2, ), purgeInvoicesInterval = None, - revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis) + revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(10, 100 millis), + wakeUpTimeout = 30 seconds, ) def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams( diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala index e4f064ee3e..57a93308a1 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala @@ -505,7 +505,7 @@ class SphinxSpec extends AnyFunSuite { val Right(decryptedPayloadBob) = RouteBlindingEncryptedDataCodecs.decode(bob, blinding, tlvsBob.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) val blindingEphemeralKeyForCarol = decryptedPayloadBob.nextBlinding val Right(payloadBob) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsBob, decryptedPayloadBob.tlvs, blindingEphemeralKeyForCarol) - assert(payloadBob.outgoingChannelId == ShortChannelId(1)) + assert(payloadBob.outgoing.contains(ShortChannelId(1))) assert(payloadBob.amountToForward(110_125 msat) == 100_125.msat) assert(payloadBob.outgoingCltv(CltvExpiry(749150)) == CltvExpiry(749100)) assert(payloadBob.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(50), 0, 10_000 msat)) @@ -523,7 +523,7 @@ class SphinxSpec extends AnyFunSuite { val Right(decryptedPayloadCarol) = RouteBlindingEncryptedDataCodecs.decode(carol, blindingEphemeralKeyForCarol, tlvsCarol.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) val blindingEphemeralKeyForDave = decryptedPayloadCarol.nextBlinding val Right(payloadCarol) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsCarol, decryptedPayloadCarol.tlvs, blindingEphemeralKeyForDave) - assert(payloadCarol.outgoingChannelId == ShortChannelId(2)) + assert(payloadCarol.outgoing.contains(ShortChannelId(2))) assert(payloadCarol.amountToForward(100_125 msat) == 100_010.msat) assert(payloadCarol.outgoingCltv(CltvExpiry(749100)) == CltvExpiry(749025)) assert(payloadCarol.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(75), 150, 100 msat)) @@ -544,7 +544,7 @@ class SphinxSpec extends AnyFunSuite { val Right(decryptedPayloadDave) = RouteBlindingEncryptedDataCodecs.decode(dave, blindingOverride, tlvsDave.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) val blindingEphemeralKeyForEve = decryptedPayloadDave.nextBlinding val Right(payloadDave) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsDave, decryptedPayloadDave.tlvs, blindingEphemeralKeyForEve) - assert(payloadDave.outgoingChannelId == ShortChannelId(3)) + assert(payloadDave.outgoing.contains(ShortChannelId(3))) assert(payloadDave.amountToForward(100_010 msat) == 100_000.msat) assert(payloadDave.outgoingCltv(CltvExpiry(749025)) == CltvExpiry(749000)) assert(payloadDave.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(25), 100, 0 msat)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala index 50da4b3c03..63ee150f2f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala @@ -19,7 +19,8 @@ package fr.acinq.eclair.io import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe => TypedProbe} import akka.actor.typed.ActorRef import akka.actor.typed.eventstream.EventStream -import akka.actor.typed.scaladsl.adapter.TypedActorRefOps +import akka.actor.typed.receptionist.Receptionist +import akka.actor.typed.scaladsl.adapter.{ClassicActorRefOps, TypedActorRefOps} import akka.testkit.TestProbe import com.typesafe.config.ConfigFactory import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey @@ -33,8 +34,8 @@ import fr.acinq.eclair.message.OnionMessages.{IntermediateNode, Recipient} import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.{GenericTlv, OnionMessagePayloadTlv, TlvStream} import fr.acinq.eclair.{EncodedNodeId, RealShortChannelId, ShortChannelId, UInt64, randomBytes32, randomKey} -import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike +import org.scalatest.{Outcome, Tag} import scodec.bits.HexStringSyntax import scala.concurrent.duration.DurationInt @@ -43,19 +44,24 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val aliceId: PublicKey = Alice.nodeParams.nodeId val bobId: PublicKey = Bob.nodeParams.nodeId + val wakeUpTimeout = "wake_up_timeout" + case class FixtureParam(relay: ActorRef[Command], switchboard: TestProbe, register: TestProbe, router: TypedProbe[Router.GetNodeId], peerConnection: TypedProbe[Nothing], peer: TypedProbe[Peer.RelayOnionMessage], probe: TypedProbe[Status]) override def withFixture(test: OneArgTest): Outcome = { val switchboard = TestProbe("switchboard")(system.classicSystem) + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref.toTyped) val register = TestProbe("register")(system.classicSystem) val router = TypedProbe[Router.GetNodeId]("router") val peerConnection = TypedProbe[Nothing]("peerConnection") val peer = TypedProbe[Peer.RelayOnionMessage]("peer") val probe = TypedProbe[Status]("probe") - val relay = testKit.spawn(MessageRelay(Alice.nodeParams, switchboard.ref, register.ref, router.ref)) + val nodeParams = if (test.tags.contains(wakeUpTimeout)) Alice.nodeParams.copy(wakeUpTimeout = 100 millis) else Alice.nodeParams + val relay = testKit.spawn(MessageRelay(nodeParams, switchboard.ref, register.ref, router.ref)) try { withFixture(test.toNoArgTest(FixtureParam(relay, switchboard, register, router, peerConnection, peer, probe))) } finally { + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref.toTyped) testKit.stop(relay) } } @@ -86,6 +92,19 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app assert(peer.expectMessageType[Peer.RelayOnionMessage].msg == message) } + test("relay after waking up next node") { f => + import f._ + + val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) + val messageId = randomBytes32() + relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId.WithPublicKey.Wallet(bobId)), message, RelayChannelsOnly, None) + + val request = switchboard.expectMsgType[GetPeerInfo] + assert(request.remoteNodeId == bobId) + request.replyTo ! Peer.PeerInfo(peer.ref.toClassic, bobId, Peer.CONNECTED, None, Set.empty) + assert(peer.expectMessageType[Peer.RelayOnionMessage].msg == message) + } + test("can't open new connection") { f => import f._ @@ -99,6 +118,15 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app probe.expectMessage(ConnectionFailure(messageId, PeerConnection.ConnectionResult.NoAddressFound)) } + test("can't wake up next node", Tag(wakeUpTimeout)) { f => + import f._ + + val Right(message) = OnionMessages.buildMessage(randomKey(), randomKey(), Seq(), Recipient(bobId, None), TlvStream.empty) + val messageId = randomBytes32() + relay ! RelayMessage(messageId, randomKey().publicKey, Right(EncodedNodeId.WithPublicKey.Wallet(bobId)), message, RelayChannelsOnly, Some(probe.ref)) + probe.expectMessage(Disconnected(messageId)) + } + test("no channel with previous node") { f => import f._ diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala index c5f8e9c16f..66afdec537 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala @@ -86,7 +86,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_b.amountToForward == amount_bc) assert(relay_b.outgoingCltv == expiry_bc) - assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId) + assert(payload_b.outgoing.contains(channelUpdate_bc.shortChannelId)) assert(relay_b.relayFeeMsat == fee_b) assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) @@ -96,7 +96,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_d.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_c.amountToForward == amount_cd) assert(relay_c.outgoingCltv == expiry_cd) - assert(payload_c.outgoingChannelId == channelUpdate_cd.shortChannelId) + assert(payload_c.outgoing.contains(channelUpdate_cd.shortChannelId)) assert(relay_c.relayFeeMsat == fee_c) assert(relay_c.expiryDelta == channelUpdate_cd.cltvExpiryDelta) @@ -106,7 +106,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_e.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_d.amountToForward == amount_de) assert(relay_d.outgoingCltv == expiry_de) - assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId)) assert(relay_d.relayFeeMsat == fee_d) assert(relay_d.expiryDelta == channelUpdate_de.cltvExpiryDelta) @@ -176,7 +176,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_b.amountToForward >= amount_bc) assert(relay_b.outgoingCltv == expiry_bc) - assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId) + assert(payload_b.outgoing.contains(channelUpdate_bc.shortChannelId)) assert(relay_b.relayFeeMsat == fee_b) assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) assert(payload_b.isInstanceOf[IntermediatePayload.ChannelRelay.Standard]) @@ -186,7 +186,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_d.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_c.amountToForward >= amount_cd) assert(relay_c.outgoingCltv == expiry_cd) - assert(payload_c.outgoingChannelId == channelUpdate_cd.shortChannelId) + assert(payload_c.outgoing.contains(channelUpdate_cd.shortChannelId)) assert(relay_c.relayFeeMsat == fee_c) assert(relay_c.expiryDelta == channelUpdate_cd.cltvExpiryDelta) assert(payload_c.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) @@ -197,7 +197,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_e.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_d.amountToForward >= amount_de) assert(relay_d.outgoingCltv == expiry_de) - assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId)) assert(relay_d.relayFeeMsat == fee_d) assert(relay_d.expiryDelta == channelUpdate_de.cltvExpiryDelta) assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) @@ -239,7 +239,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_b.amountToForward >= amount_bc) assert(relay_b.outgoingCltv == expiry_bc) - assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId) + assert(payload_b.outgoing.contains(channelUpdate_bc.shortChannelId)) assert(relay_b.relayFeeMsat == fee_b) assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) assert(payload_b.isInstanceOf[IntermediatePayload.ChannelRelay.Standard]) @@ -548,7 +548,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // A smaller amount is sent to d, who doesn't know that it's invalid. val add_d = UpdateAddHtlc(randomBytes32(), 0, amount_de, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional)) - assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId)) assert(relay_d.amountToForward < amount_de) assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding @@ -570,7 +570,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val invalidExpiry = payment.cmd.cltvExpiry - Channel.MIN_CLTV_EXPIRY_DELTA - CltvExpiryDelta(1) val add_d = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, paymentHash, invalidExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional)) - assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId)) assert(relay_d.outgoingCltv < CltvExpiry(currentBlockCount)) assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala index 923dbf05b9..96881276a6 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/ChannelRelayerSpec.scala @@ -19,6 +19,7 @@ package fr.acinq.eclair.payment.relay import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe} import akka.actor.typed import akka.actor.typed.eventstream.EventStream +import akka.actor.typed.receptionist.Receptionist import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import com.softwaremill.quicklens.ModifyPimp import com.typesafe.config.ConfigFactory @@ -29,6 +30,7 @@ import fr.acinq.eclair.TestConstants.emptyOnionPacket import fr.acinq.eclair.blockchain.fee.FeeratePerKw import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.io.{Peer, Switchboard} import fr.acinq.eclair.payment.IncomingPaymentPacket.ChannelRelayPacket import fr.acinq.eclair.payment.relay.ChannelRelayer._ import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket, PaymentPacketSpec} @@ -39,19 +41,23 @@ import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload.ChannelRel import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{CltvExpiry, NodeParams, RealShortChannelId, TestConstants, randomBytes32, _} import org.scalatest.Inside.inside -import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike +import org.scalatest.{Outcome, Tag} import scodec.bits.HexStringSyntax +import scala.concurrent.duration.DurationInt + class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike { import ChannelRelayerSpec._ + val wakeUpTimeout = "wake_up_timeout" + case class FixtureParam(nodeParams: NodeParams, channelRelayer: typed.ActorRef[ChannelRelayer.Command], register: TestProbe[Any]) override def withFixture(test: OneArgTest): Outcome = { // we are node B in the route A -> B -> C -> .... - val nodeParams = TestConstants.Bob.nodeParams + val nodeParams = if (test.tags.contains(wakeUpTimeout)) TestConstants.Bob.nodeParams.copy(wakeUpTimeout = 100 millis) else TestConstants.Bob.nodeParams val register = TestProbe[Any]("register") val channelRelayer = testKit.spawn(ChannelRelayer.apply(nodeParams, register.ref.toClassic)) try { @@ -159,7 +165,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a import f._ val u = createLocalUpdate(channelId1, feeBaseMsat = 2500 msat, feeProportionalMillionths = 0) - val payload = createBlindedPayload(u.channelUpdate, isIntroduction = false) + val payload = createBlindedPayload(Right(u.channelUpdate.shortChannelId), u.channelUpdate, isIntroduction = false) val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) channelRelayer ! WrappedLocalChannelUpdate(u) @@ -168,6 +174,30 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry) } + test("relay blinded payment (wake up wallet node)") { f => + import f._ + + val switchboard = TestProbe[Switchboard.GetPeerInfo]() + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val u = createLocalUpdate(channelId1, feeBaseMsat = 2500 msat, feeProportionalMillionths = 0) + Seq(true, false).foreach(isIntroduction => { + val payload = createBlindedPayload(Left(outgoingNodeId), u.channelUpdate, isIntroduction) + val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) + + channelRelayer ! WrappedLocalChannelUpdate(u) + channelRelayer ! Relay(r) + + // We try to wake-up the next node. + val wakeUp = switchboard.expectMessageType[Switchboard.GetPeerInfo] + assert(wakeUp.remoteNodeId == outgoingNodeId) + wakeUp.replyTo ! Peer.PeerInfo(TestProbe[Any]().ref.toClassic, outgoingNodeId, Peer.CONNECTED, None, Set.empty) + expectFwdAdd(register, channelIds(realScid1), outgoingAmount, outgoingExpiry) + }) + + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) + } + test("relay with retries") { f => import f._ @@ -272,7 +302,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a Seq(true, false).foreach { isIntroduction => // The outgoing channel is disabled, so we won't be able to relay the payment. val u = createLocalUpdate(channelId1, feeBaseMsat = 5000 msat, feeProportionalMillionths = 0, enabled = false) - val r = createValidIncomingPacket(createBlindedPayload(u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) + val r = createValidIncomingPacket(createBlindedPayload(Right(u.channelUpdate.shortChannelId), u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) channelRelayer ! WrappedLocalChannelUpdate(u) channelRelayer ! Relay(r) @@ -295,6 +325,27 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a } } + test("fail to relay blinded payment (cannot wake up remote node)", Tag(wakeUpTimeout)) { f => + import f._ + + val switchboard = TestProbe[Switchboard.GetPeerInfo]() + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val u = createLocalUpdate(channelId1, feeBaseMsat = 2500 msat, feeProportionalMillionths = 0) + val payload = createBlindedPayload(Left(outgoingNodeId), u.channelUpdate, isIntroduction = true) + val r = createValidIncomingPacket(payload, outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) + + channelRelayer ! WrappedLocalChannelUpdate(u) + channelRelayer ! Relay(r) + + // We try to wake-up the next node, but we timeout before they connect. + assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == outgoingNodeId) + val fail = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] + assert(fail.message.reason.contains(InvalidOnionBlinding(Sphinx.hash(r.add.onionRoutingPacket)))) + + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) + } + test("relay when expiry larger than our requirements") { f => import f._ @@ -521,7 +572,7 @@ class ChannelRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("a Seq(true, false).foreach { isIntroduction => testCases.foreach { htlcResult => - val r = createValidIncomingPacket(createBlindedPayload(u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) + val r = createValidIncomingPacket(createBlindedPayload(Right(u.channelUpdate.shortChannelId), u.channelUpdate, isIntroduction), outgoingAmount + u.channelUpdate.feeBaseMsat, outgoingExpiry + u.channelUpdate.cltvExpiryDelta) channelRelayer ! WrappedLocalChannelUpdate(u) channelRelayer ! Relay(r) val fwd = expectFwdAdd(register, channelId1, outgoingAmount, outgoingExpiry) @@ -655,13 +706,16 @@ object ChannelRelayerSpec { localAlias2 -> channelId2, ) - def createBlindedPayload(update: ChannelUpdate, isIntroduction: Boolean): ChannelRelay.Blinded = { + def createBlindedPayload(outgoing: Either[PublicKey, ShortChannelId], update: ChannelUpdate, isIntroduction: Boolean): ChannelRelay.Blinded = { val tlvs = TlvStream[OnionPaymentPayloadTlv](Set( Some(OnionPaymentPayloadTlv.EncryptedRecipientData(hex"2a")), if (isIntroduction) Some(OnionPaymentPayloadTlv.BlindingPoint(randomKey().publicKey)) else None, ).flatten[OnionPaymentPayloadTlv]) val blindedTlvs = TlvStream[RouteBlindingEncryptedDataTlv]( - RouteBlindingEncryptedDataTlv.OutgoingChannelId(update.shortChannelId), + outgoing match { + case Left(nodeId) => RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Wallet(nodeId)) + case Right(scid) => RouteBlindingEncryptedDataTlv.OutgoingChannelId(scid) + }, RouteBlindingEncryptedDataTlv.PaymentRelay(update.cltvExpiryDelta, update.feeProportionalMillionths, update.feeBaseMsat), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(500_000), 0 msat), ) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala index d294bd1d58..8213838f2c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala @@ -20,17 +20,19 @@ import akka.actor.Status import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe} import akka.actor.typed.ActorRef import akka.actor.typed.eventstream.EventStream +import akka.actor.typed.receptionist.Receptionist import akka.actor.typed.scaladsl.ActorContext import akka.actor.typed.scaladsl.adapter._ import com.softwaremill.quicklens.ModifyPimp import com.typesafe.config.ConfigFactory import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} -import fr.acinq.bitcoin.scalacompat.{Block, BlockHash, ByteVector32, Crypto} +import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, Crypto} import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.Features.{AsyncPaymentPrototype, BasicMultiPartPayment, PaymentSecret, VariableLengthOnion} import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Register} import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.io.{Peer, Switchboard} import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop import fr.acinq.eclair.payment.IncomingPaymentPacket.{RelayToBlindedPathsPacket, RelayToTrampolinePacket} import fr.acinq.eclair.payment.Invoice.ExtraEdge @@ -41,16 +43,16 @@ import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPaymentToNode import fr.acinq.eclair.payment.send.{BlindedRecipient, ClearRecipient} -import fr.acinq.eclair.router.Router.RouteRequest -import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound, Router} +import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams, RouteRequest} +import fr.acinq.eclair.router.{BalanceTooLow, BlindedRouteCreation, RouteNotFound, Router} import fr.acinq.eclair.wire.protocol.OfferTypes._ import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload} import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.{AllowedFeatures, PathId, PaymentConstraints} import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{BlockHeight, Bolt11Feature, CltvExpiry, CltvExpiryDelta, FeatureSupport, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, RealShortChannelId, ShortChannelId, TestConstants, TimestampMilli, UInt64, randomBytes, randomBytes32, randomKey} -import org.scalatest.Outcome +import fr.acinq.eclair.{Alias, BlockHeight, Bolt11Feature, Bolt12Feature, CltvExpiry, CltvExpiryDelta, EncodedNodeId, FeatureSupport, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, RealShortChannelId, ShortChannelId, TestConstants, TimestampMilli, UInt64, randomBytes, randomBytes32, randomKey} import org.scalatest.funsuite.FixtureAnyFunSuiteLike +import org.scalatest.{Outcome, Tag} import scodec.bits.{ByteVector, HexStringSyntax} import java.util.UUID @@ -93,6 +95,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val nodeParams = TestConstants.Bob.nodeParams .modify(_.multiPartPaymentExpiry).setTo(5 seconds) .modify(_.relayParams.asyncPaymentsParams.holdTimeoutBlocks).setToIf(test.tags.contains("long_hold_timeout"))(200000) // timeout after payment expires + .modify(_.wakeUpTimeout).setToIf(test.tags.contains("wake_up_timeout"))(100 millis) val router = TestProbe[Any]("router") val register = TestProbe[Any]("register") val eventListener = TestProbe[PaymentEvent]("event-listener") @@ -716,26 +719,15 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl } } - def createPaymentBlindedRoute(nodeId: PublicKey, sessionKey: PrivateKey = randomKey(), pathId: ByteVector = randomBytes32()): PaymentBlindedRoute = { - val selfPayload = blindedRouteDataCodec.encode(TlvStream(PathId(pathId), PaymentConstraints(CltvExpiry(1234567), 0 msat), AllowedFeatures(Features.empty))).require.bytes - PaymentBlindedRoute(Sphinx.RouteBlinding.create(sessionKey, Seq(nodeId), Seq(selfPayload)).route, PaymentInfo(1 msat, 2, CltvExpiryDelta(3), 4 msat, 5 msat, Features.empty)) - } - test("relay to blinded paths without multi-part") { f => import f._ - val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32())) - val offer = Offer(None, Some("test offer"), outgoingNodeId, Features.empty, chain) - val request = InvoiceRequest(offer, outgoingAmount, 1, Features.empty, payerKey, chain) - val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features.empty, Seq(createPaymentBlindedRoute(outgoingNodeId))) - val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths( - incoming.innerPayload.amountToForward, outgoingExpiry, invoice - ))) + val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features.empty, None) val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Trampoline(incomingMultiPart.map(p => Upstream.ReceivedHtlc(p.add, TimestampMilli.now()))), ignoreNodeId = true) + validateOutgoingCfg(outgoingCfg, Upstream.Trampoline(incomingPayments.map(p => Upstream.ReceivedHtlc(p.add, TimestampMilli.now()))), ignoreNodeId = true) val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode] assert(outgoingPayment.amount == outgoingAmount) assert(outgoingPayment.recipient.expiry == outgoingExpiry) @@ -745,7 +737,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val nodeRelayerAdapters = outgoingPayment.replyTo nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage) - incomingMultiPart.foreach { p => + incomingPayments.foreach { p => val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]] assert(fwd.channelId == p.add.channelId) assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)) @@ -754,7 +746,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl nodeRelayerAdapters ! createSuccessEvent() val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed] validateRelayEvent(relayEvent) - assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingMultiPart.map(i => (i.add.amountMsat, i.add.channelId))) + assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId))) assert(relayEvent.outgoing.length == 1) parent.expectMessageType[NodeRelayer.RelayComplete] register.expectNoMessage(100 millis) @@ -763,18 +755,12 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl test("relay to blinded paths with multi-part") { f => import f._ - val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32())) - val offer = Offer(None, Some("test offer"), outgoingNodeId, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), chain) - val request = InvoiceRequest(offer, outgoingAmount, 1, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), payerKey, chain) - val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), Seq(createPaymentBlindedRoute(outgoingNodeId))) - val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths( - incoming.innerPayload.amountToForward, outgoingExpiry, invoice - ))) + val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), None) val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Trampoline(incomingMultiPart.map(p => Upstream.ReceivedHtlc(p.add, TimestampMilli.now()))), ignoreNodeId = true) + validateOutgoingCfg(outgoingCfg, Upstream.Trampoline(incomingPayments.map(p => Upstream.ReceivedHtlc(p.add, TimestampMilli.now()))), ignoreNodeId = true) val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] assert(outgoingPayment.recipient.totalAmount == outgoingAmount) assert(outgoingPayment.recipient.expiry == outgoingExpiry) @@ -784,7 +770,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val nodeRelayerAdapters = outgoingPayment.replyTo nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage) - incomingMultiPart.foreach { p => + incomingPayments.foreach { p => val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]] assert(fwd.channelId == p.add.channelId) assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)) @@ -793,25 +779,81 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl nodeRelayerAdapters ! createSuccessEvent() val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed] validateRelayEvent(relayEvent) - assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingMultiPart.map(i => (i.add.amountMsat, i.add.channelId))) + assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId))) + assert(relayEvent.outgoing.length == 1) + parent.expectMessageType[NodeRelayer.RelayComplete] + register.expectNoMessage(100 millis) + } + + test("relay to blinded path with wake-up") { f => + import f._ + + val switchboard = TestProbe[Switchboard.GetPeerInfo]() + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val incomingPayments = createIncomingPaymentsToWalletBlindedPath(nodeParams) + val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) + incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) + + // The remote node is a wallet node: we try to wake them up before relaying the payment. + val wakeUp = switchboard.expectMessageType[Switchboard.GetPeerInfo] + assert(wakeUp.remoteNodeId == outgoingNodeId) + wakeUp.replyTo ! Peer.PeerInfo(TestProbe[Any]().ref.toClassic, outgoingNodeId, Peer.CONNECTED, None, Set.empty) + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] + validateOutgoingCfg(outgoingCfg, Upstream.Trampoline(incomingPayments.map(p => Upstream.ReceivedHtlc(p.add, TimestampMilli.now()))), ignoreNodeId = true) + val outgoingPayment = mockPayFSM.expectMessageType[SendMultiPartPayment] + assert(outgoingPayment.recipient.totalAmount == outgoingAmount) + assert(outgoingPayment.recipient.expiry == outgoingExpiry) + assert(outgoingPayment.recipient.isInstanceOf[BlindedRecipient]) + + // those are adapters for pay-fsm messages + val nodeRelayerAdapters = outgoingPayment.replyTo + + nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage) + incomingPayments.foreach { p => + val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]] + assert(fwd.channelId == p.add.channelId) + assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)) + } + + nodeRelayerAdapters ! createSuccessEvent() + val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed] + validateRelayEvent(relayEvent) + assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId))) assert(relayEvent.outgoing.length == 1) parent.expectMessageType[NodeRelayer.RelayComplete] register.expectNoMessage(100 millis) } + test("fail to relay to blinded path when wake-up fails", Tag("wake_up_timeout")) { f => + import f._ + + val switchboard = TestProbe[Switchboard.GetPeerInfo]() + system.receptionist ! Receptionist.Register(Switchboard.SwitchboardServiceKey, switchboard.ref) + + val incomingPayments = createIncomingPaymentsToWalletBlindedPath(nodeParams) + val (nodeRelayer, _) = f.createNodeRelay(incomingPayments.head) + incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) + + // The remote node is a wallet node: we try to wake them up before relaying the payment, but it times out. + assert(switchboard.expectMessageType[Switchboard.GetPeerInfo].remoteNodeId == outgoingNodeId) + system.receptionist ! Receptionist.Deregister(Switchboard.SwitchboardServiceKey, switchboard.ref) + mockPayFSM.expectNoMessage(100 millis) + + incomingPayments.foreach { p => + val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] + assert(fwd.channelId == p.add.channelId) + assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(UnknownNextPeer()), commit = true)) + } + } + test("relay to compact blinded paths") { f => import f._ - val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32())) - val offer = Offer(None, Some("test offer"), outgoingNodeId, Features.empty, chain) - val request = InvoiceRequest(offer, outgoingAmount, 1, Features.empty, payerKey, chain) - val paymentBlindedRoute = createPaymentBlindedRoute(outgoingNodeId) val scidDir = ShortChannelIdDir(isNode1 = true, RealShortChannelId(123456L)) - val compactPaymentBlindedRoute = paymentBlindedRoute.copy(route = paymentBlindedRoute.route.copy(introductionNodeId = scidDir)) - val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features.empty, Seq(compactPaymentBlindedRoute)) - val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths( - incoming.innerPayload.amountToForward, outgoingExpiry, invoice - ))) + val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features.empty, Some(scidDir)) val (nodeRelayer, parent) = f.createNodeRelay(incomingPayments.head) incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) @@ -821,7 +863,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl getNodeId.replyTo ! Some(outgoingNodeId) val outgoingCfg = mockPayFSM.expectMessageType[SendPaymentConfig] - validateOutgoingCfg(outgoingCfg, Upstream.Trampoline(incomingMultiPart.map(p => Upstream.ReceivedHtlc(p.add, TimestampMilli.now()))), ignoreNodeId = true) + validateOutgoingCfg(outgoingCfg, Upstream.Trampoline(incomingPayments.map(p => Upstream.ReceivedHtlc(p.add, TimestampMilli.now()))), ignoreNodeId = true) val outgoingPayment = mockPayFSM.expectMessageType[SendPaymentToNode] assert(outgoingPayment.amount == outgoingAmount) assert(outgoingPayment.recipient.expiry == outgoingExpiry) @@ -831,7 +873,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl val nodeRelayerAdapters = outgoingPayment.replyTo nodeRelayerAdapters ! PreimageReceived(paymentHash, paymentPreimage) - incomingMultiPart.foreach { p => + incomingPayments.foreach { p => val fwd = register.expectMessageType[Register.Forward[CMD_FULFILL_HTLC]] assert(fwd.channelId == p.add.channelId) assert(fwd.message == CMD_FULFILL_HTLC(p.add.id, paymentPreimage, commit = true)) @@ -840,7 +882,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl nodeRelayerAdapters ! createSuccessEvent() val relayEvent = eventListener.expectMessageType[TrampolinePaymentRelayed] validateRelayEvent(relayEvent) - assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingMultiPart.map(i => (i.add.amountMsat, i.add.channelId))) + assert(relayEvent.incoming.map(p => (p.amount, p.channelId)) == incomingPayments.map(i => (i.add.amountMsat, i.add.channelId))) assert(relayEvent.outgoing.length == 1) parent.expectMessageType[NodeRelayer.RelayComplete] register.expectNoMessage(100 millis) @@ -849,16 +891,8 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl test("fail to relay to compact blinded paths with unknown scid") { f => import f._ - val (payerKey, chain) = (randomKey(), BlockHash(randomBytes32())) - val offer = Offer(None, Some("test offer"), outgoingNodeId, Features.empty, chain) - val request = InvoiceRequest(offer, outgoingAmount, 1, Features.empty, payerKey, chain) - val paymentBlindedRoute = createPaymentBlindedRoute(outgoingNodeId) val scidDir = ShortChannelIdDir(isNode1 = true, RealShortChannelId(123456L)) - val compactPaymentBlindedRoute = paymentBlindedRoute.copy(route = paymentBlindedRoute.route.copy(introductionNodeId = scidDir)) - val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, Features.empty, Seq(compactPaymentBlindedRoute)) - val incomingPayments = incomingMultiPart.map(incoming => RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, IntermediatePayload.NodeRelay.ToBlindedPaths( - incoming.innerPayload.amountToForward, outgoingExpiry, invoice - ))) + val incomingPayments = createIncomingPaymentsToRemoteBlindedPath(Features.empty, Some(scidDir)) val (nodeRelayer, _) = f.createNodeRelay(incomingPayments.head) incomingPayments.foreach(incoming => nodeRelayer ! NodeRelay.Relay(incoming)) @@ -869,7 +903,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl mockPayFSM.expectNoMessage(100 millis) - incomingMultiPart.foreach { p => + incomingPayments.foreach { p => val fwd = register.expectMessageType[Register.Forward[CMD_FAIL_HTLC]] assert(fwd.channelId == p.add.channelId) assert(fwd.message == CMD_FAIL_HTLC(p.add.id, Right(UnknownNextPeer()), commit = true)) @@ -962,4 +996,43 @@ object NodeRelayerSpec { nextTrampolinePacket) } + def createPaymentBlindedRoute(nodeId: PublicKey, sessionKey: PrivateKey = randomKey(), pathId: ByteVector = randomBytes32()): PaymentBlindedRoute = { + val selfPayload = blindedRouteDataCodec.encode(TlvStream(PathId(pathId), PaymentConstraints(CltvExpiry(1234567), 0 msat), AllowedFeatures(Features.empty))).require.bytes + PaymentBlindedRoute(Sphinx.RouteBlinding.create(sessionKey, Seq(nodeId), Seq(selfPayload)).route, PaymentInfo(1 msat, 2, CltvExpiryDelta(3), 4 msat, 5 msat, Features.empty)) + } + + /** Create payments to a blinded path that starts at a remote node. */ + def createIncomingPaymentsToRemoteBlindedPath(features: Features[Bolt12Feature], scidDir_opt: Option[EncodedNodeId.ShortChannelIdDir]): Seq[RelayToBlindedPathsPacket] = { + val offer = Offer(None, Some("test offer"), outgoingNodeId, features, Block.RegtestGenesisBlock.hash) + val request = InvoiceRequest(offer, outgoingAmount, 1, features, randomKey(), Block.RegtestGenesisBlock.hash) + val paymentBlindedRoute = scidDir_opt match { + case Some(scidDir) => + val nonCompact = createPaymentBlindedRoute(outgoingNodeId) + nonCompact.copy(route = nonCompact.route.copy(introductionNodeId = scidDir)) + case None => + createPaymentBlindedRoute(outgoingNodeId) + } + val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, features, Seq(paymentBlindedRoute)) + incomingMultiPart.map(incoming => { + val innerPayload = IntermediatePayload.NodeRelay.ToBlindedPaths(incoming.innerPayload.amountToForward, outgoingExpiry, invoice) + RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, innerPayload) + }) + } + + /** Create payments to a blinded path that starts at our node and relays to a wallet node. */ + def createIncomingPaymentsToWalletBlindedPath(nodeParams: NodeParams): Seq[RelayToBlindedPathsPacket] = { + val features: Features[Bolt12Feature] = Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional) + val offer = Offer(None, Some("test offer"), outgoingNodeId, features, Block.RegtestGenesisBlock.hash) + val request = InvoiceRequest(offer, outgoingAmount, 1, Features(Features.BasicMultiPartPayment -> FeatureSupport.Optional), randomKey(), Block.RegtestGenesisBlock.hash) + val edge = ExtraEdge(nodeParams.nodeId, outgoingNodeId, Alias(561), 2_000_000 msat, 250, CltvExpiryDelta(144), 1 msat, None) + val hop = ChannelHop(edge.shortChannelId, nodeParams.nodeId, outgoingNodeId, HopRelayParams.FromHint(edge)) + val route = BlindedRouteCreation.createBlindedRouteToWallet(hop, hex"deadbeef", 1 msat, outgoingExpiry).route + val paymentInfo = BlindedRouteCreation.aggregatePaymentInfo(outgoingAmount, Seq(hop), CltvExpiryDelta(12)) + val invoice = Bolt12Invoice(request, randomBytes32(), outgoingNodeKey, 300 seconds, features, Seq(PaymentBlindedRoute(route, paymentInfo))) + incomingMultiPart.map(incoming => { + val innerPayload = IntermediatePayload.NodeRelay.ToBlindedPaths(incoming.innerPayload.amountToForward, outgoingExpiry, invoice) + RelayToBlindedPathsPacket(incoming.add, incoming.outerPayload, innerPayload) + }) + } + } \ No newline at end of file diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/BlindedPathsResolverSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/BlindedPathsResolverSpec.scala index 038f15aa88..8a9ae7d641 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/BlindedPathsResolverSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/BlindedPathsResolverSpec.scala @@ -31,7 +31,7 @@ import fr.acinq.eclair.payment.send.BlindedPathsResolver.{FullBlindedRoute, Part import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams} import fr.acinq.eclair.router.{BlindedRouteCreation, Router} import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo -import fr.acinq.eclair.{BlockHeight, CltvExpiry, CltvExpiryDelta, EncodedNodeId, Features, MilliSatoshiLong, NodeParams, RealShortChannelId, TestConstants, randomBytes32, randomKey} +import fr.acinq.eclair.{Alias, BlockHeight, CltvExpiry, CltvExpiryDelta, EncodedNodeId, Features, MilliSatoshiLong, NodeParams, RealShortChannelId, TestConstants, randomBytes32, randomKey} import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike import scodec.bits.HexStringSyntax @@ -151,6 +151,31 @@ class BlindedPathsResolverSpec extends ScalaTestWithActorTestKit(ConfigFactory.l } } + test("resolve route starting at our node (wallet node)") { f => + import f._ + + val probe = TestProbe() + val walletNodeId = randomKey().publicKey + val edge = ExtraEdge(nodeParams.nodeId, walletNodeId, Alias(561), 5_000_000 msat, 200, CltvExpiryDelta(144), 1 msat, None) + val hop = ChannelHop(edge.shortChannelId, nodeParams.nodeId, walletNodeId, HopRelayParams.FromHint(edge)) + val route = BlindedRouteCreation.createBlindedRouteToWallet(hop, hex"deadbeef", 1 msat, CltvExpiry(800_000)).route + val paymentInfo = BlindedRouteCreation.aggregatePaymentInfo(100_000_000 msat, Seq(hop), CltvExpiryDelta(12)) + val resolver = testKit.spawn(BlindedPathsResolver(nodeParams, randomBytes32(), router.ref, register.ref)) + resolver ! Resolve(probe.ref, Seq(PaymentBlindedRoute(route, paymentInfo))) + // We are the introduction node: we decrypt the payload and discover that the next node is a wallet node. + val resolved = probe.expectMsgType[Seq[ResolvedPath]] + assert(resolved.size == 1) + assert(resolved.head.route.isInstanceOf[PartialBlindedRoute]) + val partialRoute = resolved.head.route.asInstanceOf[PartialBlindedRoute] + assert(partialRoute.firstNodeId == walletNodeId) + assert(partialRoute.nextNodeId == EncodedNodeId.WithPublicKey.Wallet(walletNodeId)) + assert(partialRoute.blindedNodes == route.subsequentNodes) + assert(partialRoute.nextBlinding != route.blindingKey) + // We don't need to resolve the nodeId. + register.expectNoMessage(100 millis) + router.expectNoMessage(100 millis) + } + test("ignore blinded paths that cannot be resolved") { f => import f._ @@ -181,8 +206,9 @@ class BlindedPathsResolverSpec extends ScalaTestWithActorTestKit(ConfigFactory.l val probe = TestProbe() val scid = RealShortChannelId(BlockHeight(750_000), 3, 7) - val edgeLowFees = ExtraEdge(nodeParams.nodeId, randomKey().publicKey, scid, 100 msat, 5, CltvExpiryDelta(144), 1 msat, None) - val edgeLowExpiryDelta = ExtraEdge(nodeParams.nodeId, randomKey().publicKey, scid, 600_000 msat, 100, CltvExpiryDelta(36), 1 msat, None) + val nextNodeId = randomKey().publicKey + val edgeLowFees = ExtraEdge(nodeParams.nodeId, nextNodeId, scid, 100 msat, 5, CltvExpiryDelta(144), 1 msat, None) + val edgeLowExpiryDelta = ExtraEdge(nodeParams.nodeId, nextNodeId, scid, 600_000 msat, 100, CltvExpiryDelta(36), 1 msat, None) val toResolve = Seq( // We don't allow paying blinded routes to ourselves. BlindedRouteCreation.createBlindedRouteWithoutHops(nodeParams.nodeId, hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, @@ -190,6 +216,8 @@ class BlindedPathsResolverSpec extends ScalaTestWithActorTestKit(ConfigFactory.l BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, // We reject blinded routes with low cltv_expiry_delta. BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowExpiryDelta.targetNodeId, HopRelayParams.FromHint(edgeLowExpiryDelta))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, + // We reject blinded routes with low fees, even when the next node seems to be a wallet node. + BlindedRouteCreation.createBlindedRouteToWallet(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees)), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, // We reject blinded routes that cannot be decrypted. BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route.copy(blindingKey = randomKey().publicKey) ).map(r => PaymentBlindedRoute(r, PaymentInfo(1_000_000 msat, 2500, CltvExpiryDelta(300), 1 msat, 500_000_000 msat, Features.empty))) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala index 13e22eca5e..c0b691b588 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala @@ -89,7 +89,7 @@ class PaymentOnionSpec extends AnyFunSuite { val Right(payload) = IntermediatePayload.ChannelRelay.Standard.validate(decoded) assert(payload.amountOut == 561.msat) assert(payload.cltvOut == CltvExpiry(42)) - assert(payload.outgoingChannelId == ShortChannelId(1105)) + assert(payload.outgoing.contains(ShortChannelId(1105))) val encoded = perHopPayloadCodec.encode(expected).require.bytes assert(encoded == bin) } @@ -110,7 +110,7 @@ class PaymentOnionSpec extends AnyFunSuite { val decoded = perHopPayloadCodec.decode(bin.bits).require.value assert(decoded == expected) val Right(payload) = IntermediatePayload.ChannelRelay.Blinded.validate(decoded, blindedTlvs, randomKey().publicKey) - assert(payload.outgoingChannelId == ShortChannelId(42)) + assert(payload.outgoing.contains(ShortChannelId(42))) assert(payload.amountToForward(10_000 msat) == 9990.msat) assert(payload.outgoingCltv(CltvExpiry(1000)) == CltvExpiry(856)) assert(payload.paymentRelayData.allowedFeatures.isEmpty) @@ -119,6 +119,20 @@ class PaymentOnionSpec extends AnyFunSuite { } } + test("encode/decode channel relay blinded per-hop-payload (with wallet node_id)") { + val walletNodeId = PublicKey(hex"0221cd519eba9c8b840a5e40b65dc2c040e159a766979723ed770efceb97260ec8") + val blindedTlvs = TlvStream[RouteBlindingEncryptedDataTlv]( + RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Wallet(walletNodeId)), + RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(144), 100, 10 msat), + RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat), + ) + val Right(payload) = IntermediatePayload.ChannelRelay.Blinded.validate(TlvStream(EncryptedRecipientData(hex"deadbeef")), blindedTlvs, randomKey().publicKey) + assert(payload.outgoing == Left(walletNodeId)) + assert(payload.amountToForward(10_000 msat) == 9990.msat) + assert(payload.outgoingCltv(CltvExpiry(1000)) == CltvExpiry(856)) + assert(payload.paymentRelayData.allowedFeatures.isEmpty) + } + test("encode/decode node relay per-hop payload") { val nodeId = PublicKey(hex"02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619") val expected = TlvStream[OnionPaymentPayloadTlv](AmountToForward(561 msat), OutgoingCltv(CltvExpiry(42)), OutgoingNodeId(nodeId)) @@ -292,6 +306,8 @@ class PaymentOnionSpec extends AnyFunSuite { TestCase(MissingRequiredTlv(UInt64(10)), hex"23 0c21036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e2", validBlindedTlvs), // Missing encrypted outgoing channel. TestCase(MissingRequiredTlv(UInt64(2)), hex"0a 0a080123456789abcdef", TlvStream(RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(144), 100, 10 msat), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat))), + // Forbidden encrypted outgoing plain node_id. + TestCase(ForbiddenTlv(UInt64(4)), hex"0a 0a080123456789abcdef", TlvStream(RouteBlindingEncryptedDataTlv.OutgoingNodeId(EncodedNodeId.WithPublicKey.Plain(randomKey().publicKey)), RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(144), 100, 10 msat), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat))), // Missing encrypted payment relay data. TestCase(MissingRequiredTlv(UInt64(10)), hex"0a 0a080123456789abcdef", TlvStream(RouteBlindingEncryptedDataTlv.OutgoingChannelId(ShortChannelId(42)), RouteBlindingEncryptedDataTlv.PaymentConstraints(CltvExpiry(1500), 1 msat))), // Missing encrypted payment constraint.