diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala index 11b9fc9079..2ff70733df 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala @@ -16,10 +16,10 @@ package fr.acinq.eclair.io -import akka.actor.typed.Behavior import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import akka.actor.typed.scaladsl.{ActorContext, Behaviors} +import akka.actor.typed.{Behavior, SupervisorStrategy} import akka.actor.{ActorRef, typed} import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey @@ -32,6 +32,8 @@ import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.OnionMessage import fr.acinq.eclair.{EncodedNodeId, NodeParams, ShortChannelId} +import scala.concurrent.duration.DurationInt + object MessageRelay { // @formatter:off sealed trait Command @@ -42,29 +44,18 @@ object MessageRelay { policy: RelayPolicy, replyTo_opt: Option[typed.ActorRef[Status]]) extends Command case class WrappedPeerInfo(peerInfo: PeerInfoResponse) extends Command - case class WrappedConnectionResult(result: PeerConnection.ConnectionResult) extends Command - case class WrappedOptionalNodeId(nodeId_opt: Option[PublicKey]) extends Command + private case class WrappedConnectionResult(result: PeerConnection.ConnectionResult) extends Command + private case class WrappedOptionalNodeId(nodeId_opt: Option[PublicKey]) extends Command + private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command - sealed trait Status { - val messageId: ByteVector32 - } + sealed trait Status { val messageId: ByteVector32 } case class Sent(messageId: ByteVector32) extends Status sealed trait Failure extends Status - case class AgainstPolicy(messageId: ByteVector32, policy: RelayPolicy) extends Failure { - override def toString: String = s"Relay prevented by policy $policy" - } - case class ConnectionFailure(messageId: ByteVector32, failure: PeerConnection.ConnectionResult.Failure) extends Failure { - override def toString: String = s"Can't connect to peer: ${failure.toString}" - } - case class Disconnected(messageId: ByteVector32) extends Failure { - override def toString: String = "Peer is not connected" - } - case class UnknownChannel(messageId: ByteVector32, channelId: ShortChannelId) extends Failure { - override def toString: String = s"Unknown channel: $channelId" - } - case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure { - override def toString: String = s"Message dropped: $reason" - } + case class AgainstPolicy(messageId: ByteVector32, policy: RelayPolicy) extends Failure { override def toString: String = s"Relay prevented by policy $policy" } + case class ConnectionFailure(messageId: ByteVector32, failure: PeerConnection.ConnectionResult.Failure) extends Failure { override def toString: String = s"Can't connect to peer: ${failure.toString}" } + case class Disconnected(messageId: ByteVector32) extends Failure { override def toString: String = "Peer is not connected" } + case class UnknownChannel(messageId: ByteVector32, channelId: ShortChannelId) extends Failure { override def toString: String = s"Unknown channel: $channelId" } + case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure { override def toString: String = s"Message dropped: $reason" } sealed trait RelayPolicy case object RelayChannelsOnly extends RelayPolicy @@ -100,7 +91,7 @@ private class MessageRelay(nodeParams: NodeParams, def queryNextNodeId(msg: OnionMessage, nextNode: Either[ShortChannelId, EncodedNodeId]): Behavior[Command] = { nextNode match { case Left(outgoingChannelId) if outgoingChannelId == ShortChannelId.toSelf => - withNextNodeId(msg, nodeParams.nodeId) + withNextNodeId(msg, EncodedNodeId.WithPublicKey.Plain(nodeParams.nodeId)) case Left(outgoingChannelId) => register ! Register.GetNextNodeId(context.messageAdapter(WrappedOptionalNodeId), outgoingChannelId) waitForNextNodeId(msg, outgoingChannelId) @@ -108,7 +99,7 @@ private class MessageRelay(nodeParams: NodeParams, router ! Router.GetNodeId(context.messageAdapter(WrappedOptionalNodeId), scid, isNode1) waitForNextNodeId(msg, scid) case Right(encodedNodeId: EncodedNodeId.WithPublicKey) => - withNextNodeId(msg, encodedNodeId.publicKey) + withNextNodeId(msg, encodedNodeId) } } @@ -118,33 +109,39 @@ private class MessageRelay(nodeParams: NodeParams, replyTo_opt.foreach(_ ! UnknownChannel(messageId, channelId)) Behaviors.stopped case WrappedOptionalNodeId(Some(nextNodeId)) => - withNextNodeId(msg, nextNodeId) + withNextNodeId(msg, EncodedNodeId.WithPublicKey.Plain(nextNodeId)) } } - private def withNextNodeId(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = { - if (nextNodeId == nodeParams.nodeId) { - OnionMessages.process(nodeParams.privateKey, msg) match { - case OnionMessages.DropMessage(reason) => - replyTo_opt.foreach(_ ! DroppedMessage(messageId, reason)) - Behaviors.stopped - case OnionMessages.SendMessage(nextNode, nextMessage) => - // We need to repeat the process until we identify the (real) next node, or find out that we're the recipient. - queryNextNodeId(nextMessage, nextNode) - case received: OnionMessages.ReceiveMessage => - context.system.eventStream ! EventStream.Publish(received) - replyTo_opt.foreach(_ ! Sent(messageId)) - Behaviors.stopped - } - } else { - policy match { - case RelayChannelsOnly => - switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId) - waitForPreviousPeerForPolicyCheck(msg, nextNodeId) - case RelayAll => - switchboard ! Peer.Connect(nextNodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false) - waitForConnection(msg) - } + private def withNextNodeId(msg: OnionMessage, nextNodeId: EncodedNodeId.WithPublicKey): Behavior[Command] = { + nextNodeId match { + case EncodedNodeId.WithPublicKey.Plain(nodeId) if nodeId == nodeParams.nodeId => + OnionMessages.process(nodeParams.privateKey, msg) match { + case OnionMessages.DropMessage(reason) => + replyTo_opt.foreach(_ ! DroppedMessage(messageId, reason)) + Behaviors.stopped + case OnionMessages.SendMessage(nextNode, nextMessage) => + // We need to repeat the process until we identify the (real) next node, or find out that we're the recipient. + queryNextNodeId(nextMessage, nextNode) + case received: OnionMessages.ReceiveMessage => + context.system.eventStream ! EventStream.Publish(received) + replyTo_opt.foreach(_ ! Sent(messageId)) + Behaviors.stopped + } + case EncodedNodeId.WithPublicKey.Plain(nodeId) => + policy match { + case RelayChannelsOnly => + switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId) + waitForPreviousPeerForPolicyCheck(msg, nodeId) + case RelayAll => + switchboard ! Peer.Connect(nodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false) + waitForConnection(msg) + } + case EncodedNodeId.WithPublicKey.Wallet(nodeId) => + context.log.info("trying to wake up next peer to relay onion message (nodeId={})", nodeId) + val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(nodeId, timeout_opt = Some(Left(30 seconds)))).onFailure(SupervisorStrategy.stop)) + notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter[PeerReadyNotifier.Result](WrappedPeerReadyResult)) + waitForWalletNodeUp(msg) } } @@ -180,4 +177,15 @@ private class MessageRelay(nodeParams: NodeParams, Behaviors.stopped } } + + private def waitForWalletNodeUp(msg: OnionMessage): Behavior[Command] = { + Behaviors.receiveMessagePartial { + case WrappedPeerReadyResult(r: PeerReadyNotifier.PeerReady) => + r.peer ! Peer.RelayOnionMessage(messageId, msg, replyTo_opt) + Behaviors.stopped + case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) => + replyTo_opt.foreach(_ ! Disconnected(messageId)) + Behaviors.stopped + } + } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala index d59491c903..20ae3fe823 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/Monitoring.scala @@ -115,6 +115,7 @@ object Monitoring { val Failure = "failure" object FailureType { + val WakeUp = "WakeUp" val Remote = "Remote" val Malformed = "MalformedHtlc" diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala index 2f72fe2846..723cc89073 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala @@ -128,7 +128,7 @@ object IncomingPaymentPacket { decryptEncryptedRecipientData(add, privateKey, payload, encrypted.data).flatMap { case DecodedEncryptedRecipientData(blindedPayload, nextBlinding) => validateBlindedChannelRelayPayload(add, payload, blindedPayload, nextBlinding, nextPacket).flatMap { - case ChannelRelayPacket(_, payload, nextPacket) if payload.outgoingChannelId == ShortChannelId.toSelf => + case ChannelRelayPacket(_, payload, nextPacket) if payload.outgoing == Right(ShortChannelId.toSelf) => decrypt(add.copy(onionRoutingPacket = nextPacket, tlvStream = add.tlvStream.copy(records = Set(UpdateAddHtlcTlv.BlindingPoint(nextBlinding)))), privateKey, features) case relayPacket => Right(relayPacket) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala index a68ff75dd7..7576ef025b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala @@ -17,21 +17,23 @@ package fr.acinq.eclair.payment.relay import akka.actor.ActorRef -import akka.actor.typed.Behavior import akka.actor.typed.eventstream.EventStream import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import akka.actor.typed.scaladsl.{ActorContext, Behaviors} +import akka.actor.typed.{Behavior, SupervisorStrategy} import fr.acinq.bitcoin.scalacompat.ByteVector32 +import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.db.PendingCommandsDb +import fr.acinq.eclair.io.PeerReadyNotifier import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment.relay.Relayer.{OutgoingChannel, OutgoingChannelParams} import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket} import fr.acinq.eclair.wire.protocol.FailureMessageCodecs.createBadOnionFailure import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{Logs, NodeParams, TimestampMilli, TimestampSecond, channel, nodeFee} +import fr.acinq.eclair.{Logs, NodeParams, ShortChannelId, TimestampMilli, TimestampSecond, channel, nodeFee} import java.util.UUID import java.util.concurrent.TimeUnit @@ -43,6 +45,7 @@ object ChannelRelay { // @formatter:off sealed trait Command private case object DoRelay extends Command + private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command private case class WrappedForwardFailure(failure: Register.ForwardFailure[CMD_ADD_HTLC]) extends Command private case class WrappedAddResponse(res: CommandResponse[CMD_ADD_HTLC]) extends Command // @formatter:on @@ -61,7 +64,7 @@ object ChannelRelay { paymentHash_opt = Some(r.add.paymentHash), nodeAlias_opt = Some(nodeParams.alias))) { context.self ! DoRelay - new ChannelRelay(nodeParams, register, channels, r, context).relay(Seq.empty) + new ChannelRelay(nodeParams, register, channels, r, context).start() } } @@ -69,7 +72,7 @@ object ChannelRelay { * This helper method translates relaying errors (returned by the downstream outgoing channel) to BOLT 4 standard * errors that we should return upstream. */ - def translateLocalError(error: Throwable, channelUpdate_opt: Option[ChannelUpdate]): FailureMessage = { + private def translateLocalError(error: Throwable, channelUpdate_opt: Option[ChannelUpdate]): FailureMessage = { (error, channelUpdate_opt) match { case (_: ExpiryTooSmall, Some(channelUpdate)) => ExpiryTooSoon(Some(channelUpdate)) case (_: ExpiryTooBig, _) => ExpiryTooFar() @@ -112,16 +115,52 @@ class ChannelRelay private(nodeParams: NodeParams, private val forwardFailureAdapter = context.messageAdapter[Register.ForwardFailure[CMD_ADD_HTLC]](WrappedForwardFailure) private val addResponseAdapter = context.messageAdapter[CommandResponse[CMD_ADD_HTLC]](WrappedAddResponse) + private val nextNodeId_opt = r.payload.outgoing match { + case Left(walletNodeId) => Some(walletNodeId) + // All the channels point to the same next node, we take the first one. + case Right(_) => channels.headOption.map(_._2.nextNodeId) + } + + /** Channel id explicitly requested in the onion payload. */ + private val requestedChannelId_opt = r.payload.outgoing match { + case Left(_) => None + case Right(outgoingChannelId) => channels.collectFirst { + case (channelId, channel) if channel.shortIds.localAlias == outgoingChannelId => channelId + case (channelId, channel) if channel.shortIds.real.toOption.contains(outgoingChannelId) => channelId + } + } + private case class PreviouslyTried(channelId: ByteVector32, failure: RES_ADD_FAILED[ChannelException]) - def relay(previousFailures: Seq[PreviouslyTried]): Behavior[Command] = { + def start(): Behavior[Command] = { + r.payload.outgoing match { + case Left(walletNodeId) => wakeUp(walletNodeId) + case Right(requestedShortChannelId) => relay(Some(requestedShortChannelId), Seq.empty) + } + } + + private def wakeUp(walletNodeId: PublicKey): Behavior[Command] = { + context.log.info("trying to wake up channel peer (nodeId={})", walletNodeId) + val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(walletNodeId, timeout_opt = Some(Left(30 seconds)))).onFailure(SupervisorStrategy.stop)) + notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter[PeerReadyNotifier.Result](WrappedPeerReadyResult)) + Behaviors.receiveMessagePartial { + case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) => + Metrics.recordPaymentRelayFailed(Tags.FailureType.WakeUp, Tags.RelayType.Channel) + context.log.info("rejecting htlc: failed to wake-up remote peer") + safeSendAndStop(r.add.channelId, CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true)) + case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerReady) => + relay(None, Seq.empty) + } + } + + def relay(requestedShortChannelId_opt: Option[ShortChannelId], previousFailures: Seq[PreviouslyTried]): Behavior[Command] = { Behaviors.receiveMessagePartial { case DoRelay => if (previousFailures.isEmpty) { - context.log.info("relaying htlc #{} from channelId={} to requestedShortChannelId={} nextNode={}", r.add.id, r.add.channelId, r.payload.outgoingChannelId, nextNodeId_opt.getOrElse("")) + context.log.info("relaying htlc #{} from channelId={} to requestedShortChannelId={} nextNode={}", r.add.id, r.add.channelId, requestedShortChannelId_opt, nextNodeId_opt.getOrElse("")) } context.log.debug("attempting relay previousAttempts={}", previousFailures.size) - handleRelay(previousFailures) match { + handleRelay(requestedShortChannelId_opt, previousFailures) match { case RelayFailure(cmdFail) => Metrics.recordPaymentRelayFailed(Tags.FailureType(cmdFail), Tags.RelayType.Channel) context.log.info("rejecting htlc reason={}", cmdFail.reason) @@ -129,12 +168,12 @@ class ChannelRelay private(nodeParams: NodeParams, case RelaySuccess(selectedChannelId, cmdAdd) => context.log.info("forwarding htlc #{} from channelId={} to channelId={}", r.add.id, r.add.channelId, selectedChannelId) register ! Register.Forward(forwardFailureAdapter, selectedChannelId, cmdAdd) - waitForAddResponse(selectedChannelId, previousFailures) + waitForAddResponse(selectedChannelId, requestedShortChannelId_opt, previousFailures) } } } - def waitForAddResponse(selectedChannelId: ByteVector32, previousFailures: Seq[PreviouslyTried]): Behavior[Command] = + private def waitForAddResponse(selectedChannelId: ByteVector32, requestedShortChannelId_opt: Option[ShortChannelId], previousFailures: Seq[PreviouslyTried]): Behavior[Command] = Behaviors.receiveMessagePartial { case WrappedForwardFailure(Register.ForwardFailure(Register.Forward(_, channelId, CMD_ADD_HTLC(_, _, _, _, _, _, o: Origin.ChannelRelayedHot, _)))) => context.log.warn(s"couldn't resolve downstream channel $channelId, failing htlc #${o.add.id}") @@ -145,14 +184,14 @@ class ChannelRelay private(nodeParams: NodeParams, case WrappedAddResponse(addFailed@RES_ADD_FAILED(CMD_ADD_HTLC(_, _, _, _, _, _, _: Origin.ChannelRelayedHot, _), _, _)) => context.log.info("attempt failed with reason={}", addFailed.t.getClass.getSimpleName) context.self ! DoRelay - relay(previousFailures :+ PreviouslyTried(selectedChannelId, addFailed)) + relay(requestedShortChannelId_opt, previousFailures :+ PreviouslyTried(selectedChannelId, addFailed)) case WrappedAddResponse(_: RES_SUCCESS[_]) => context.log.debug("sent htlc to the downstream channel") waitForAddSettled() } - def waitForAddSettled(): Behavior[Command] = + private def waitForAddSettled(): Behavior[Command] = Behaviors.receiveMessagePartial { case WrappedAddResponse(RES_ADD_SETTLED(o: Origin.ChannelRelayedHot, htlc, fulfill: HtlcResult.Fulfill)) => context.log.debug("relaying fulfill to upstream") @@ -169,7 +208,7 @@ class ChannelRelay private(nodeParams: NodeParams, safeSendAndStop(o.originChannelId, cmd) } - def safeSendAndStop(channelId: ByteVector32, cmd: channel.HtlcSettlementCommand): Behavior[Command] = { + private def safeSendAndStop(channelId: ByteVector32, cmd: channel.HtlcSettlementCommand): Behavior[Command] = { val toSend = cmd match { case _: CMD_FULFILL_HTLC => cmd case _: CMD_FAIL_HTLC | _: CMD_FAIL_MALFORMED_HTLC => r.payload match { @@ -200,9 +239,9 @@ class ChannelRelay private(nodeParams: NodeParams, * - a CMD_FAIL_HTLC to be sent back upstream * - a CMD_ADD_HTLC to propagate downstream */ - def handleRelay(previousFailures: Seq[PreviouslyTried]): RelayResult = { + private def handleRelay(requestedShortChannelId_opt: Option[ShortChannelId], previousFailures: Seq[PreviouslyTried]): RelayResult = { val alreadyTried = previousFailures.map(_.channelId) - selectPreferredChannel(alreadyTried) match { + selectPreferredChannel(requestedShortChannelId_opt, alreadyTried) match { case None if previousFailures.nonEmpty => // no more channels to try val error = previousFailures @@ -217,24 +256,14 @@ class ChannelRelay private(nodeParams: NodeParams, } } - /** all the channels point to the same next node, we take the first one */ - private val nextNodeId_opt = channels.headOption.map(_._2.nextNodeId) - - /** channel id explicitly requested in the onion payload */ - private val requestedChannelId_opt = channels.collectFirst { - case (channelId, channel) if channel.shortIds.localAlias == r.payload.outgoingChannelId => channelId - case (channelId, channel) if channel.shortIds.real.toOption.contains(r.payload.outgoingChannelId) => channelId - } - /** * Select a channel to the same node to relay the payment to, that has the lowest capacity and balance and is * compatible in terms of fees, expiry_delta, etc. * * If no suitable channel is found we default to the originally requested channel. */ - def selectPreferredChannel(alreadyTried: Seq[ByteVector32]): Option[OutgoingChannel] = { - val requestedShortChannelId = r.payload.outgoingChannelId - context.log.debug("selecting next channel with requestedShortChannelId={}", requestedShortChannelId) + private def selectPreferredChannel(requestedShortChannelId_opt: Option[ShortChannelId], alreadyTried: Seq[ByteVector32]): Option[OutgoingChannel] = { + context.log.debug("selecting next channel with requestedShortChannelId={}", requestedShortChannelId_opt) // we filter out channels that we have already tried val candidateChannels: Map[ByteVector32, OutgoingChannel] = channels -- alreadyTried // and we filter again to keep the ones that are compatible with this payment (mainly fees, expiry delta) @@ -242,7 +271,8 @@ class ChannelRelay private(nodeParams: NodeParams, .values .map { channel => val relayResult = relayOrFail(Some(channel)) - context.log.debug(s"candidate channel: channelId=${channel.channelId} availableForSend={} capacity={} channelUpdate={} result={}", + context.log.debug("candidate channel: channelId={} availableForSend={} capacity={} channelUpdate={} result={}", + channel.channelId, channel.commitments.availableBalanceForSend, channel.commitments.latest.capacity, channel.channelUpdate, @@ -268,7 +298,7 @@ class ChannelRelay private(nodeParams: NodeParams, context.log.debug("requested short channel id is our preferred channel") Some(channel) } else { - context.log.debug("replacing requestedShortChannelId={} by preferredShortChannelId={} with availableBalanceMsat={}", requestedShortChannelId, channel.channelUpdate.shortChannelId, channel.commitments.availableBalanceForSend) + context.log.debug("replacing requestedShortChannelId={} by preferredShortChannelId={} with availableBalanceMsat={}", requestedShortChannelId_opt, channel.channelUpdate.shortChannelId, channel.commitments.availableBalanceForSend) Some(channel) } case None => @@ -289,7 +319,7 @@ class ChannelRelay private(nodeParams: NodeParams, * channel, because some parameters don't match with our settings for that channel. In that case we directly fail the * htlc. */ - def relayOrFail(outgoingChannel_opt: Option[OutgoingChannelParams]): RelayResult = { + private def relayOrFail(outgoingChannel_opt: Option[OutgoingChannelParams]): RelayResult = { outgoingChannel_opt match { case None => RelayFailure(CMD_FAIL_HTLC(r.add.id, Right(UnknownNextPeer()), commit = true)) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala index 59eb7b58b0..355ea83c22 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelayer.scala @@ -24,7 +24,7 @@ import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.channel._ import fr.acinq.eclair.payment.IncomingPaymentPacket -import fr.acinq.eclair.{SubscriptionsComplete, Logs, NodeParams, ShortChannelId} +import fr.acinq.eclair.{Logs, NodeParams, ShortChannelId, SubscriptionsComplete} import java.util.UUID import scala.collection.mutable @@ -71,9 +71,12 @@ object ChannelRelayer { Behaviors.receiveMessage { case Relay(channelRelayPacket) => val relayId = UUID.randomUUID() - val nextNodeId_opt: Option[PublicKey] = scid2channels.get(channelRelayPacket.payload.outgoingChannelId) match { - case Some(channelId) => channels.get(channelId).map(_.nextNodeId) - case None => None + val nextNodeId_opt: Option[PublicKey] = channelRelayPacket.payload.outgoing match { + case Left(walletNodeId) => Some(walletNodeId) + case Right(outgoingChannelId) => scid2channels.get(outgoingChannelId) match { + case Some(channelId) => channels.get(channelId).map(_.nextNodeId) + case None => None + } } val nextChannels: Map[ByteVector32, Relayer.OutgoingChannel] = nextNodeId_opt match { case Some(nextNodeId) => node2channels.get(nextNodeId).flatMap(channels.get).map(c => c.channelId -> c).toMap diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index 8acd5db157..c48c2b7248 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -42,7 +42,7 @@ import fr.acinq.eclair.router.Router.RouteParams import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound} import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, Features, Logs, MilliSatoshi, NodeParams, TimestampMilli, UInt64, nodeFee, randomBytes32} +import fr.acinq.eclair.{CltvExpiry, EncodedNodeId, Features, Logs, MilliSatoshi, NodeParams, TimestampMilli, UInt64, nodeFee, randomBytes32} import java.util.UUID import java.util.concurrent.TimeUnit @@ -145,8 +145,14 @@ object NodeRelay { } } - private def shouldWakeUpNextNode(nodeParams: NodeParams): Boolean = { - false + private def shouldWakeUpNextNode(nodeParams: NodeParams, recipient: Recipient): Boolean = { + recipient match { + case r: BlindedRecipient => r.blindedHops.head.resolved.route match { + case BlindedPathsResolver.PartialBlindedRoute(_: EncodedNodeId.WithPublicKey.Wallet, _, _) => true + case _ => false + } + case _ => false + } } /** When we have identified that the next node is one of our peers, return their (real) nodeId. */ @@ -297,7 +303,7 @@ class NodeRelay private(nodeParams: NodeParams, private def wakeUpIfNeeded(upstream: Upstream.Trampoline, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, nextPacket_opt: Option[OnionRoutingPacket]): Behavior[Command] = { if (isAsyncPayment(nodeParams, nextPayload)) { waitForTrigger(upstream, recipient, nextPayload, nextPacket_opt) - } else if (shouldWakeUpNextNode(nodeParams)) { + } else if (shouldWakeUpNextNode(nodeParams, recipient)) { waitForPeerReady(upstream, recipient, nextPayload, nextPacket_opt) } else { relay(upstream, recipient, nextPayload, nextPacket_opt) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala index a12799d8d2..cb91693e7c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/BlindedPathsResolver.scala @@ -14,7 +14,7 @@ import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.OfferTypes.PaymentInfo import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.RouteBlindingDecryptedData import fr.acinq.eclair.wire.protocol.{BlindedRouteData, OfferTypes, RouteBlindingEncryptedDataCodecs} -import fr.acinq.eclair.{EncodedNodeId, Logs, NodeParams} +import fr.acinq.eclair.{EncodedNodeId, Logs, NodeParams, ShortChannelId} import scodec.bits.ByteVector import scala.annotation.tailrec @@ -45,8 +45,8 @@ object BlindedPathsResolver { override val firstNodeId: PublicKey = introductionNodeId } /** A partially unwrapped blinded route that started at our node: it only contains the part of the route after our node. */ - case class PartialBlindedRoute(nextNodeId: PublicKey, nextBlinding: PublicKey, blindedNodes: Seq[BlindedNode]) extends ResolvedBlindedRoute { - override val firstNodeId: PublicKey = nextNodeId + case class PartialBlindedRoute(nextNodeId: EncodedNodeId.WithPublicKey, nextBlinding: PublicKey, blindedNodes: Seq[BlindedNode]) extends ResolvedBlindedRoute { + override val firstNodeId: PublicKey = nextNodeId.publicKey } // @formatter:on @@ -111,8 +111,15 @@ private class BlindedPathsResolver(nodeParams: NodeParams, feeProportionalMillionths = nextFeeProportionalMillionths, cltvExpiryDelta = nextCltvExpiryDelta ) - register ! Register.GetNextNodeId(context.messageAdapter(WrappedNodeId), paymentRelayData.outgoingChannelId) - waitForNextNodeId(nextPaymentInfo, paymentRelayData, nextBlinding, paymentRoute.route.subsequentNodes, toResolve.tail, resolved) + paymentRelayData.outgoing match { + case Left(outgoingNodeId) => + // The next node is a wallet node directly connected to us. + val path = ResolvedPath(PartialBlindedRoute(EncodedNodeId.WithPublicKey.Wallet(outgoingNodeId), nextBlinding, paymentRoute.route.subsequentNodes), nextPaymentInfo) + resolveBlindedPaths(toResolve.tail, resolved :+ path) + case Right(outgoingChannelId) => + register ! Register.GetNextNodeId(context.messageAdapter(WrappedNodeId), outgoingChannelId) + waitForNextNodeId(outgoingChannelId, nextPaymentInfo, paymentRelayData, nextBlinding, paymentRoute.route.subsequentNodes, toResolve.tail, resolved) + } } } case encodedNodeId: EncodedNodeId.WithPublicKey => @@ -129,7 +136,8 @@ private class BlindedPathsResolver(nodeParams: NodeParams, } /** Resolve the next node in the blinded path when we are the introduction node. */ - private def waitForNextNodeId(nextPaymentInfo: OfferTypes.PaymentInfo, + private def waitForNextNodeId(outgoingChannelId: ShortChannelId, + nextPaymentInfo: OfferTypes.PaymentInfo, paymentRelayData: BlindedRouteData.PaymentRelayData, nextBlinding: PublicKey, nextBlindedNodes: Seq[RouteBlinding.BlindedNode], @@ -137,7 +145,7 @@ private class BlindedPathsResolver(nodeParams: NodeParams, resolved: Seq[ResolvedPath]): Behavior[Command] = Behaviors.receiveMessagePartial { case WrappedNodeId(None) => - context.log.warn("ignoring blinded path starting at our node: could not resolve outgoingChannelId={}", paymentRelayData.outgoingChannelId) + context.log.warn("ignoring blinded path starting at our node: could not resolve outgoingChannelId={}", outgoingChannelId) resolveBlindedPaths(toResolve, resolved) case WrappedNodeId(Some(nodeId)) if nodeId == nodeParams.nodeId => // The next node in the route is also our node: this is fishy, there is not reason to include us in the route twice. @@ -152,7 +160,7 @@ private class BlindedPathsResolver(nodeParams: NodeParams, paymentRelayData.paymentRelay.cltvExpiryDelta >= nodeParams.channelConf.expiryDelta if (shouldRelay) { context.log.debug("unwrapped blinded path starting at our node: next_node={}", nodeId) - val path = ResolvedPath(PartialBlindedRoute(nodeId, nextBlinding, nextBlindedNodes), nextPaymentInfo) + val path = ResolvedPath(PartialBlindedRoute(EncodedNodeId.WithPublicKey.Plain(nodeId), nextBlinding, nextBlindedNodes), nextPaymentInfo) resolveBlindedPaths(toResolve, resolved :+ path) } else { context.log.warn("ignoring blinded path starting at our node: allocated fees are too low (base={}, proportional={}, expiryDelta={})", paymentRelayData.paymentRelay.feeBase, paymentRelayData.paymentRelay.feeProportionalMillionths, paymentRelayData.paymentRelay.cltvExpiryDelta) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala index 10bd99f443..4468ed7172 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PaymentOnion.scala @@ -23,7 +23,7 @@ import fr.acinq.eclair.wire.protocol.BlindedRouteData.PaymentRelayData import fr.acinq.eclair.wire.protocol.CommonCodecs._ import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.TlvCodecs._ -import fr.acinq.eclair.{CltvExpiry, Features, MilliSatoshi, MilliSatoshiLong, ShortChannelId, UInt64} +import fr.acinq.eclair.{CltvExpiry, Features, MilliSatoshi, ShortChannelId, UInt64} import scodec.bits.{BitVector, ByteVector} /** @@ -227,7 +227,8 @@ object PaymentOnion { object IntermediatePayload { sealed trait ChannelRelay extends IntermediatePayload { // @formatter:off - def outgoingChannelId: ShortChannelId + /** The outgoing channel, or the nodeId of one of our peers. */ + def outgoing: Either[PublicKey, ShortChannelId] def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry // @formatter:on @@ -238,7 +239,7 @@ object PaymentOnion { // @formatter:off val amountOut = records.get[AmountToForward].get.amount val cltvOut = records.get[OutgoingCltv].get.cltv - override val outgoingChannelId = records.get[OutgoingChannelId].get.shortChannelId + override val outgoing = Right(records.get[OutgoingChannelId].get.shortChannelId) override def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi = amountOut override def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry = cltvOut // @formatter:on @@ -258,12 +259,12 @@ object PaymentOnion { } /** - * @param blindedRecords decrypted tlv stream from the encrypted_recipient_data tlv. - * @param nextBlinding blinding point that must be forwarded to the next hop. + * @param paymentRelayData decrypted relaying data from the encrypted_recipient_data tlv. + * @param nextBlinding blinding point that must be forwarded to the next hop. */ case class Blinded(records: TlvStream[OnionPaymentPayloadTlv], paymentRelayData: PaymentRelayData, nextBlinding: PublicKey) extends ChannelRelay { // @formatter:off - override val outgoingChannelId = paymentRelayData.outgoingChannelId + override val outgoing = paymentRelayData.outgoing override def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi = paymentRelayData.amountToForward(incomingAmount) override def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry = paymentRelayData.outgoingCltv(incomingCltv) // @formatter:on diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala index 096c1f8edf..be53f4aab0 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala @@ -98,7 +98,11 @@ object BlindedRouteData { } case class PaymentRelayData(records: TlvStream[RouteBlindingEncryptedDataTlv]) { - val outgoingChannelId: ShortChannelId = records.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId].get.shortChannelId + // This is usually a channel, unless the next node is a mobile wallet connected to our node. + val outgoing: Either[PublicKey, ShortChannelId] = records.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId] match { + case Some(r) => Right(r.shortChannelId) + case None => Left(records.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId].get.nodeId.asInstanceOf[EncodedNodeId.WithPublicKey.Wallet].publicKey) + } val paymentRelay: PaymentRelay = records.get[RouteBlindingEncryptedDataTlv.PaymentRelay].get val paymentConstraints: PaymentConstraints = records.get[RouteBlindingEncryptedDataTlv.PaymentConstraints].get val allowedFeatures: Features[Feature] = records.get[RouteBlindingEncryptedDataTlv.AllowedFeatures].map(_.features).getOrElse(Features.empty) @@ -110,7 +114,9 @@ object BlindedRouteData { } def validatePaymentRelayData(records: TlvStream[RouteBlindingEncryptedDataTlv]): Either[InvalidTlvPayload, PaymentRelayData] = { - if (records.get[OutgoingChannelId].isEmpty) return Left(MissingRequiredTlv(UInt64(2))) + // Note that the BOLTs require using an OutgoingChannelId, but we optionally support a wallet node_id. + if (records.get[OutgoingChannelId].isEmpty && records.get[OutgoingNodeId].isEmpty) return Left(MissingRequiredTlv(UInt64(2))) + if (records.get[OutgoingNodeId].nonEmpty && !records.get[OutgoingNodeId].get.nodeId.isInstanceOf[EncodedNodeId.WithPublicKey.Wallet]) return Left(ForbiddenTlv(UInt64(4))) if (records.get[PaymentRelay].isEmpty) return Left(MissingRequiredTlv(UInt64(10))) if (records.get[PaymentConstraints].isEmpty) return Left(MissingRequiredTlv(UInt64(12))) if (records.get[PathId].nonEmpty) return Left(ForbiddenTlv(UInt64(6))) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala index 5f62668804..7a1673ff60 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala @@ -505,7 +505,7 @@ class SphinxSpec extends AnyFunSuite { val Right(decryptedPayloadBob) = RouteBlindingEncryptedDataCodecs.decode(bob, blinding, tlvsBob.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) val blindingEphemeralKeyForCarol = decryptedPayloadBob.nextBlinding val Right(payloadBob) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsBob, decryptedPayloadBob.tlvs, blindingEphemeralKeyForCarol) - assert(payloadBob.outgoingChannelId == ShortChannelId(1)) + assert(payloadBob.outgoing.contains(ShortChannelId(1))) assert(payloadBob.amountToForward(110_125 msat) == 100_125.msat) assert(payloadBob.outgoingCltv(CltvExpiry(749150)) == CltvExpiry(749100)) assert(payloadBob.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(50), 0, 10_000 msat)) @@ -523,7 +523,7 @@ class SphinxSpec extends AnyFunSuite { val Right(decryptedPayloadCarol) = RouteBlindingEncryptedDataCodecs.decode(carol, blindingEphemeralKeyForCarol, tlvsCarol.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) val blindingEphemeralKeyForDave = decryptedPayloadCarol.nextBlinding val Right(payloadCarol) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsCarol, decryptedPayloadCarol.tlvs, blindingEphemeralKeyForDave) - assert(payloadCarol.outgoingChannelId == ShortChannelId(2)) + assert(payloadCarol.outgoing.contains(ShortChannelId(2))) assert(payloadCarol.amountToForward(100_125 msat) == 100_010.msat) assert(payloadCarol.outgoingCltv(CltvExpiry(749100)) == CltvExpiry(749025)) assert(payloadCarol.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(75), 150, 100 msat)) @@ -544,7 +544,7 @@ class SphinxSpec extends AnyFunSuite { val Right(decryptedPayloadDave) = RouteBlindingEncryptedDataCodecs.decode(dave, blindingOverride, tlvsDave.get[OnionPaymentPayloadTlv.EncryptedRecipientData].get.data) val blindingEphemeralKeyForEve = decryptedPayloadDave.nextBlinding val Right(payloadDave) = PaymentOnion.IntermediatePayload.ChannelRelay.Blinded.validate(tlvsDave, decryptedPayloadDave.tlvs, blindingEphemeralKeyForEve) - assert(payloadDave.outgoingChannelId == ShortChannelId(3)) + assert(payloadDave.outgoing.contains(ShortChannelId(3))) assert(payloadDave.amountToForward(100_010 msat) == 100_000.msat) assert(payloadDave.outgoingCltv(CltvExpiry(749025)) == CltvExpiry(749000)) assert(payloadDave.paymentRelayData.paymentRelay == RouteBlindingEncryptedDataTlv.PaymentRelay(CltvExpiryDelta(25), 100, 0 msat)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala index 11f63749e7..2f5d722fda 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala @@ -86,7 +86,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_b.amountToForward == amount_bc) assert(relay_b.outgoingCltv == expiry_bc) - assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId) + assert(payload_b.outgoing.contains(channelUpdate_bc.shortChannelId)) assert(relay_b.relayFeeMsat == fee_b) assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) @@ -96,7 +96,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_d.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_c.amountToForward == amount_cd) assert(relay_c.outgoingCltv == expiry_cd) - assert(payload_c.outgoingChannelId == channelUpdate_cd.shortChannelId) + assert(payload_c.outgoing.contains(channelUpdate_cd.shortChannelId)) assert(relay_c.relayFeeMsat == fee_c) assert(relay_c.expiryDelta == channelUpdate_cd.cltvExpiryDelta) @@ -106,7 +106,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_e.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_d.amountToForward == amount_de) assert(relay_d.outgoingCltv == expiry_de) - assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId)) assert(relay_d.relayFeeMsat == fee_d) assert(relay_d.expiryDelta == channelUpdate_de.cltvExpiryDelta) @@ -176,7 +176,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_b.amountToForward >= amount_bc) assert(relay_b.outgoingCltv == expiry_bc) - assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId) + assert(payload_b.outgoing.contains(channelUpdate_bc.shortChannelId)) assert(relay_b.relayFeeMsat == fee_b) assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) assert(payload_b.isInstanceOf[IntermediatePayload.ChannelRelay.Standard]) @@ -186,7 +186,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_d.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_c.amountToForward >= amount_cd) assert(relay_c.outgoingCltv == expiry_cd) - assert(payload_c.outgoingChannelId == channelUpdate_cd.shortChannelId) + assert(payload_c.outgoing.contains(channelUpdate_cd.shortChannelId)) assert(relay_c.relayFeeMsat == fee_c) assert(relay_c.expiryDelta == channelUpdate_cd.cltvExpiryDelta) assert(payload_c.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) @@ -197,7 +197,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_e.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_d.amountToForward >= amount_de) assert(relay_d.outgoingCltv == expiry_de) - assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId)) assert(relay_d.relayFeeMsat == fee_d) assert(relay_d.expiryDelta == channelUpdate_de.cltvExpiryDelta) assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) @@ -239,7 +239,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(packet_c.payload.length == PaymentOnionCodecs.paymentOnionPayloadLength) assert(relay_b.amountToForward >= amount_bc) assert(relay_b.outgoingCltv == expiry_bc) - assert(payload_b.outgoingChannelId == channelUpdate_bc.shortChannelId) + assert(payload_b.outgoing.contains(channelUpdate_bc.shortChannelId)) assert(relay_b.relayFeeMsat == fee_b) assert(relay_b.expiryDelta == channelUpdate_bc.cltvExpiryDelta) assert(payload_b.isInstanceOf[IntermediatePayload.ChannelRelay.Standard]) @@ -548,7 +548,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { // A smaller amount is sent to d, who doesn't know that it's invalid. val add_d = UpdateAddHtlc(randomBytes32(), 0, amount_de, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional)) - assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId)) assert(relay_d.amountToForward < amount_de) assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding @@ -570,7 +570,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val invalidExpiry = payment.cmd.cltvExpiry - Channel.MIN_CLTV_EXPIRY_DELTA - CltvExpiryDelta(1) val add_d = UpdateAddHtlc(randomBytes32(), 0, payment.cmd.amount, paymentHash, invalidExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt) val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional)) - assert(payload_d.outgoingChannelId == channelUpdate_de.shortChannelId) + assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId)) assert(relay_d.outgoingCltv < CltvExpiry(currentBlockCount)) assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) val blinding_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextBlinding diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala index 13e22eca5e..95cc870acc 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/PaymentOnionSpec.scala @@ -89,7 +89,7 @@ class PaymentOnionSpec extends AnyFunSuite { val Right(payload) = IntermediatePayload.ChannelRelay.Standard.validate(decoded) assert(payload.amountOut == 561.msat) assert(payload.cltvOut == CltvExpiry(42)) - assert(payload.outgoingChannelId == ShortChannelId(1105)) + assert(payload.outgoing.contains(ShortChannelId(1105))) val encoded = perHopPayloadCodec.encode(expected).require.bytes assert(encoded == bin) } @@ -110,7 +110,7 @@ class PaymentOnionSpec extends AnyFunSuite { val decoded = perHopPayloadCodec.decode(bin.bits).require.value assert(decoded == expected) val Right(payload) = IntermediatePayload.ChannelRelay.Blinded.validate(decoded, blindedTlvs, randomKey().publicKey) - assert(payload.outgoingChannelId == ShortChannelId(42)) + assert(payload.outgoing.contains(ShortChannelId(42))) assert(payload.amountToForward(10_000 msat) == 9990.msat) assert(payload.outgoingCltv(CltvExpiry(1000)) == CltvExpiry(856)) assert(payload.paymentRelayData.allowedFeatures.isEmpty)