Skip to content

Commit

Permalink
Add wake-up step to channel and message relay
Browse files Browse the repository at this point in the history
We allow relaying data to contain a wallet `node_id` instead of an scid.
When that's the case, we start by waking up that wallet node before we
try relaying onion messages or payments.
  • Loading branch information
t-bast committed Jul 31, 2024
1 parent c6e19ae commit eae11e8
Show file tree
Hide file tree
Showing 20 changed files with 576 additions and 267 deletions.
6 changes: 4 additions & 2 deletions eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ case class NodeParams(nodeKeyManager: NodeKeyManager,
blockchainWatchdogSources: Seq[String],
onionMessageConfig: OnionMessageConfig,
purgeInvoicesInterval: Option[FiniteDuration],
revokedHtlcInfoCleanerConfig: RevokedHtlcInfoCleaner.Config) {
revokedHtlcInfoCleanerConfig: RevokedHtlcInfoCleaner.Config,
wakeUpTimeout: FiniteDuration) {
val privateKey: Crypto.PrivateKey = nodeKeyManager.nodeKey.privateKey

val nodeId: PublicKey = nodeKeyManager.nodeId
Expand Down Expand Up @@ -611,7 +612,8 @@ object NodeParams extends Logging {
revokedHtlcInfoCleanerConfig = RevokedHtlcInfoCleaner.Config(
batchSize = config.getInt("db.revoked-htlc-info-cleaner.batch-size"),
interval = FiniteDuration(config.getDuration("db.revoked-htlc-info-cleaner.interval").getSeconds, TimeUnit.SECONDS)
)
),
wakeUpTimeout = 30 seconds,
)
}
}
105 changes: 57 additions & 48 deletions eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

package fr.acinq.eclair.io

import akka.actor.typed.Behavior
import akka.actor.typed.eventstream.EventStream
import akka.actor.typed.scaladsl.adapter.TypedActorRefOps
import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import akka.actor.typed.{Behavior, SupervisorStrategy}
import akka.actor.{ActorRef, typed}
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
Expand All @@ -34,6 +34,8 @@ import fr.acinq.eclair.router.Router
import fr.acinq.eclair.wire.protocol.OnionMessage
import fr.acinq.eclair.{EncodedNodeId, Logs, NodeParams, ShortChannelId}

import scala.concurrent.duration.DurationInt

object MessageRelay {
// @formatter:off
sealed trait Command
Expand All @@ -44,29 +46,18 @@ object MessageRelay {
policy: RelayPolicy,
replyTo_opt: Option[typed.ActorRef[Status]]) extends Command
case class WrappedPeerInfo(peerInfo: PeerInfoResponse) extends Command
case class WrappedConnectionResult(result: PeerConnection.ConnectionResult) extends Command
case class WrappedOptionalNodeId(nodeId_opt: Option[PublicKey]) extends Command
private case class WrappedConnectionResult(result: PeerConnection.ConnectionResult) extends Command
private case class WrappedOptionalNodeId(nodeId_opt: Option[PublicKey]) extends Command
private case class WrappedPeerReadyResult(result: PeerReadyNotifier.Result) extends Command

sealed trait Status {
val messageId: ByteVector32
}
sealed trait Status { val messageId: ByteVector32 }
case class Sent(messageId: ByteVector32) extends Status
sealed trait Failure extends Status
case class AgainstPolicy(messageId: ByteVector32, policy: RelayPolicy) extends Failure {
override def toString: String = s"Relay prevented by policy $policy"
}
case class ConnectionFailure(messageId: ByteVector32, failure: PeerConnection.ConnectionResult.Failure) extends Failure {
override def toString: String = s"Can't connect to peer: ${failure.toString}"
}
case class Disconnected(messageId: ByteVector32) extends Failure {
override def toString: String = "Peer is not connected"
}
case class UnknownChannel(messageId: ByteVector32, channelId: ShortChannelId) extends Failure {
override def toString: String = s"Unknown channel: $channelId"
}
case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure {
override def toString: String = s"Message dropped: $reason"
}
case class AgainstPolicy(messageId: ByteVector32, policy: RelayPolicy) extends Failure { override def toString: String = s"Relay prevented by policy $policy" }
case class ConnectionFailure(messageId: ByteVector32, failure: PeerConnection.ConnectionResult.Failure) extends Failure { override def toString: String = s"Can't connect to peer: ${failure.toString}" }
case class Disconnected(messageId: ByteVector32) extends Failure { override def toString: String = "Peer is not connected" }
case class UnknownChannel(messageId: ByteVector32, channelId: ShortChannelId) extends Failure { override def toString: String = s"Unknown channel: $channelId" }
case class DroppedMessage(messageId: ByteVector32, reason: DropReason) extends Failure { override def toString: String = s"Message dropped: $reason" }

sealed trait RelayPolicy
case object RelayChannelsOnly extends RelayPolicy
Expand Down Expand Up @@ -106,15 +97,15 @@ private class MessageRelay(nodeParams: NodeParams,
def queryNextNodeId(msg: OnionMessage, nextNode: Either[ShortChannelId, EncodedNodeId]): Behavior[Command] = {
nextNode match {
case Left(outgoingChannelId) if outgoingChannelId == ShortChannelId.toSelf =>
withNextNodeId(msg, nodeParams.nodeId)
withNextNodeId(msg, EncodedNodeId.WithPublicKey.Plain(nodeParams.nodeId))
case Left(outgoingChannelId) =>
register ! Register.GetNextNodeId(context.messageAdapter(WrappedOptionalNodeId), outgoingChannelId)
waitForNextNodeId(msg, outgoingChannelId)
case Right(EncodedNodeId.ShortChannelIdDir(isNode1, scid)) =>
router ! Router.GetNodeId(context.messageAdapter(WrappedOptionalNodeId), scid, isNode1)
waitForNextNodeId(msg, scid)
case Right(encodedNodeId: EncodedNodeId.WithPublicKey) =>
withNextNodeId(msg, encodedNodeId.publicKey)
withNextNodeId(msg, encodedNodeId)
}
}

Expand All @@ -127,34 +118,39 @@ private class MessageRelay(nodeParams: NodeParams,
Behaviors.stopped
case WrappedOptionalNodeId(Some(nextNodeId)) =>
log.info("found outgoing node {} for channel {}", nextNodeId, channelId)
withNextNodeId(msg, nextNodeId)
withNextNodeId(msg, EncodedNodeId.WithPublicKey.Plain(nextNodeId))
}
}

private def withNextNodeId(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = {
if (nextNodeId == nodeParams.nodeId) {
OnionMessages.process(nodeParams.privateKey, msg) match {
case OnionMessages.DropMessage(reason) =>
Metrics.OnionMessagesNotRelayed.withTag(Tags.Reason, reason.getClass.getSimpleName).increment()
replyTo_opt.foreach(_ ! DroppedMessage(messageId, reason))
Behaviors.stopped
case OnionMessages.SendMessage(nextNode, nextMessage) =>
// We need to repeat the process until we identify the (real) next node, or find out that we're the recipient.
queryNextNodeId(nextMessage, nextNode)
case received: OnionMessages.ReceiveMessage =>
context.system.eventStream ! EventStream.Publish(received)
replyTo_opt.foreach(_ ! Sent(messageId))
Behaviors.stopped
}
} else {
policy match {
case RelayChannelsOnly =>
switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId)
waitForPreviousPeerForPolicyCheck(msg, nextNodeId)
case RelayAll =>
switchboard ! Peer.Connect(nextNodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false)
waitForConnection(msg, nextNodeId)
}
private def withNextNodeId(msg: OnionMessage, nextNodeId: EncodedNodeId.WithPublicKey): Behavior[Command] = {
nextNodeId match {
case EncodedNodeId.WithPublicKey.Plain(nodeId) if nodeId == nodeParams.nodeId =>
OnionMessages.process(nodeParams.privateKey, msg) match {
case OnionMessages.DropMessage(reason) =>
Metrics.OnionMessagesNotRelayed.withTag(Tags.Reason, reason.getClass.getSimpleName).increment()
replyTo_opt.foreach(_ ! DroppedMessage(messageId, reason))
Behaviors.stopped
case OnionMessages.SendMessage(nextNode, nextMessage) =>
// We need to repeat the process until we identify the (real) next node, or find out that we're the recipient.
queryNextNodeId(nextMessage, nextNode)
case received: OnionMessages.ReceiveMessage =>
context.system.eventStream ! EventStream.Publish(received)
replyTo_opt.foreach(_ ! Sent(messageId))
Behaviors.stopped
}
case EncodedNodeId.WithPublicKey.Plain(nodeId) =>
policy match {
case RelayChannelsOnly =>
switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId)
waitForPreviousPeerForPolicyCheck(msg, nodeId)
case RelayAll =>
switchboard ! Peer.Connect(nodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false)
waitForConnection(msg, nodeId)
}
case EncodedNodeId.WithPublicKey.Wallet(nodeId) =>
val notifier = context.spawnAnonymous(Behaviors.supervise(PeerReadyNotifier(nodeId, timeout_opt = Some(Left(nodeParams.wakeUpTimeout)))).onFailure(SupervisorStrategy.stop))
notifier ! PeerReadyNotifier.NotifyWhenPeerReady(context.messageAdapter(WrappedPeerReadyResult))
waitForWalletNodeUp(msg, nodeId)
}
}

Expand Down Expand Up @@ -197,4 +193,17 @@ private class MessageRelay(nodeParams: NodeParams,
Behaviors.stopped
}
}

private def waitForWalletNodeUp(msg: OnionMessage, nextNodeId: PublicKey): Behavior[Command] = {
Behaviors.receiveMessagePartial {
case WrappedPeerReadyResult(r: PeerReadyNotifier.PeerReady) =>
log.info("successfully woke up {}: relaying onion message", nextNodeId)
r.peer ! Peer.RelayOnionMessage(messageId, msg, replyTo_opt)
Behaviors.stopped
case WrappedPeerReadyResult(_: PeerReadyNotifier.PeerUnavailable) =>
log.info("could not wake up {}: onion message cannot be relayed", nextNodeId)
replyTo_opt.foreach(_ ! Disconnected(messageId))
Behaviors.stopped
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ object PeerReadyNotifier {
case WrappedListing(Switchboard.SwitchboardServiceKey.Listing(listings)) =>
listings.headOption match {
case Some(switchboard) =>
waitForPeerConnected(replyTo, remoteNodeId, switchboard, context, timers)
waitForPeerConnected(replyTo, remoteNodeId, switchboard, context, timers)
case None =>
context.log.error("no switchboard found")
replyTo ! PeerUnavailable(remoteNodeId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ object Monitoring {
val Failure = "failure"

object FailureType {
val WakeUp = "WakeUp"
val Remote = "Remote"
val Malformed = "MalformedHtlc"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ object IncomingPaymentPacket {
decryptEncryptedRecipientData(add, privateKey, payload, encrypted.data).flatMap {
case DecodedEncryptedRecipientData(blindedPayload, nextBlinding) =>
validateBlindedChannelRelayPayload(add, payload, blindedPayload, nextBlinding, nextPacket).flatMap {
case ChannelRelayPacket(_, payload, nextPacket) if payload.outgoingChannelId == ShortChannelId.toSelf =>
case ChannelRelayPacket(_, payload, nextPacket) if payload.outgoing == Right(ShortChannelId.toSelf) =>
decrypt(add.copy(onionRoutingPacket = nextPacket, tlvStream = add.tlvStream.copy(records = Set(UpdateAddHtlcTlv.BlindingPoint(nextBlinding)))), privateKey, features)
case relayPacket => Right(relayPacket)
}
Expand Down
Loading

0 comments on commit eae11e8

Please sign in to comment.