diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala index c9886b031e..b668ff62b7 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala @@ -275,6 +275,11 @@ object Features { val mandatory = 38 } + case object ProvideStorage extends Feature with InitFeature with NodeFeature { + val rfcName = "option_provide_storage" + val mandatory = 42 + } + case object ChannelType extends Feature with InitFeature with NodeFeature { val rfcName = "option_channel_type" val mandatory = 44 @@ -358,6 +363,7 @@ object Features { DualFunding, Quiescence, OnionMessages, + ProvideStorage, ChannelType, ScidAlias, PaymentMetadata, diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala index 55a0b6fc8b..4c7af077c9 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala @@ -92,7 +92,8 @@ case class NodeParams(nodeKeyManager: NodeKeyManager, revokedHtlcInfoCleanerConfig: RevokedHtlcInfoCleaner.Config, willFundRates_opt: Option[LiquidityAds.WillFundRates], peerWakeUpConfig: PeerReadyNotifier.WakeUpConfig, - onTheFlyFundingConfig: OnTheFlyFunding.Config) { + onTheFlyFundingConfig: OnTheFlyFunding.Config, + peerStorageWriteDelayMax: FiniteDuration) { val privateKey: Crypto.PrivateKey = nodeKeyManager.nodeKey.privateKey val nodeId: PublicKey = nodeKeyManager.nodeId @@ -678,6 +679,7 @@ object NodeParams extends Logging { onTheFlyFundingConfig = OnTheFlyFunding.Config( proposalTimeout = FiniteDuration(config.getDuration("on-the-fly-funding.proposal-timeout").getSeconds, TimeUnit.SECONDS), ), + peerStorageWriteDelayMax = 1 minute, ) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala index d1ff3487ea..05219d7eec 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala @@ -15,6 +15,7 @@ import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, NodeAddress, NodeAnnouncement} import fr.acinq.eclair.{CltvExpiry, MilliSatoshi, Paginated, RealShortChannelId, ShortChannelId, TimestampMilli} import grizzled.slf4j.Logging +import scodec.bits.ByteVector import java.io.File import java.util.UUID @@ -292,6 +293,16 @@ case class DualPeersDb(primary: PeersDb, secondary: PeersDb) extends PeersDb { runAsync(secondary.getRelayFees(nodeId)) primary.getRelayFees(nodeId) } + + override def updateStorage(nodeId: PublicKey, data: ByteVector): Unit = { + runAsync(secondary.updateStorage(nodeId, data)) + primary.updateStorage(nodeId, data) + } + + override def getStorage(nodeId: PublicKey): Option[ByteVector] = { + runAsync(secondary.getStorage(nodeId)) + primary.getStorage(nodeId) + } } case class DualPaymentsDb(primary: PaymentsDb, secondary: PaymentsDb) extends PaymentsDb { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PeersDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PeersDb.scala index ea10f348e8..b1824071d0 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PeersDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PeersDb.scala @@ -19,6 +19,7 @@ package fr.acinq.eclair.db import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.payment.relay.Relayer.RelayFees import fr.acinq.eclair.wire.protocol.NodeAddress +import scodec.bits.ByteVector trait PeersDb { @@ -34,4 +35,8 @@ trait PeersDb { def getRelayFees(nodeId: PublicKey): Option[RelayFees] + def updateStorage(nodeId: PublicKey, data: ByteVector): Unit + + def getStorage(nodeId: PublicKey): Option[ByteVector] + } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala index 45094dd4a9..76102f9482 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala @@ -26,14 +26,14 @@ import fr.acinq.eclair.db.pg.PgUtils.PgLock import fr.acinq.eclair.payment.relay.Relayer.RelayFees import fr.acinq.eclair.wire.protocol._ import grizzled.slf4j.Logging -import scodec.bits.BitVector +import scodec.bits.{BitVector, ByteVector} import java.sql.Statement import javax.sql.DataSource object PgPeersDb { val DB_NAME = "peers" - val CURRENT_VERSION = 3 + val CURRENT_VERSION = 4 } class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logging { @@ -54,13 +54,18 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logg statement.executeUpdate("CREATE TABLE local.relay_fees (node_id TEXT NOT NULL PRIMARY KEY, fee_base_msat BIGINT NOT NULL, fee_proportional_millionths BIGINT NOT NULL)") } + def migration34(statement: Statement): Unit = { + statement.executeUpdate("CREATE TABLE local.peer_storage (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)") + } + using(pg.createStatement()) { statement => getVersion(statement, DB_NAME) match { case None => statement.executeUpdate("CREATE SCHEMA IF NOT EXISTS local") - statement.executeUpdate("CREATE TABLE local.peers (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)") + statement.executeUpdate("CREATE TABLE local.peers (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, storage BYTEA)") statement.executeUpdate("CREATE TABLE local.relay_fees (node_id TEXT NOT NULL PRIMARY KEY, fee_base_msat BIGINT NOT NULL, fee_proportional_millionths BIGINT NOT NULL)") - case Some(v@(1 | 2)) => + statement.executeUpdate("CREATE TABLE local.peer_storage (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)") + case Some(v@(1 | 2 | 3)) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") if (v < 2) { migration12(statement) @@ -68,6 +73,9 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logg if (v < 3) { migration23(statement) } + if (v < 4) { + migration34(statement) + } case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -98,6 +106,10 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logg statement.setString(1, nodeId.value.toHex) statement.executeUpdate() } + using(pg.prepareStatement("DELETE FROM local.peer_storage WHERE node_id = ?")) { statement => + statement.setString(1, nodeId.value.toHex) + statement.executeUpdate() + } } } @@ -155,4 +167,31 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logg } } } + + override def updateStorage(nodeId: PublicKey, data: ByteVector): Unit = withMetrics("peers/update-storage", DbBackends.Postgres) { + withLock { pg => + using(pg.prepareStatement( + """ + INSERT INTO local.peer_storage (node_id, data) + VALUES (?, ?) + ON CONFLICT (node_id) + DO UPDATE SET data = EXCLUDED.data + """)) { statement => + statement.setString(1, nodeId.value.toHex) + statement.setBytes(2, data.toArray) + statement.executeUpdate() + } + } + } + + override def getStorage(nodeId: PublicKey): Option[ByteVector] = withMetrics("peers/get-storage", DbBackends.Postgres) { + withLock { pg => + using(pg.prepareStatement("SELECT data FROM local.peer_storage WHERE node_id = ?")) { statement => + statement.setString(1, nodeId.value.toHex) + statement.executeQuery() + .headOption + .map(rs => ByteVector(rs.getBytes("data"))) + } + } + } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala index 610bb07909..ba99c46ad1 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala @@ -26,13 +26,13 @@ import fr.acinq.eclair.db.sqlite.SqliteUtils.{getVersion, setVersion, using} import fr.acinq.eclair.payment.relay.Relayer.RelayFees import fr.acinq.eclair.wire.protocol._ import grizzled.slf4j.Logging -import scodec.bits.BitVector +import scodec.bits.{BitVector, ByteVector} import java.sql.{Connection, Statement} object SqlitePeersDb { val DB_NAME = "peers" - val CURRENT_VERSION = 2 + val CURRENT_VERSION = 3 } class SqlitePeersDb(val sqlite: Connection) extends PeersDb with Logging { @@ -46,13 +46,23 @@ class SqlitePeersDb(val sqlite: Connection) extends PeersDb with Logging { statement.executeUpdate("CREATE TABLE relay_fees (node_id BLOB NOT NULL PRIMARY KEY, fee_base_msat INTEGER NOT NULL, fee_proportional_millionths INTEGER NOT NULL)") } + def migration23(statement: Statement): Unit = { + statement.executeUpdate("CREATE TABLE peer_storage (node_id BLOB NOT NULL PRIMARY KEY, data NOT NULL)") + } + getVersion(statement, DB_NAME) match { case None => statement.executeUpdate("CREATE TABLE peers (node_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)") statement.executeUpdate("CREATE TABLE relay_fees (node_id BLOB NOT NULL PRIMARY KEY, fee_base_msat INTEGER NOT NULL, fee_proportional_millionths INTEGER NOT NULL)") - case Some(v@1) => + statement.executeUpdate("CREATE TABLE peer_storage (node_id BLOB NOT NULL PRIMARY KEY, data NOT NULL)") + case Some(v@(1 | 2)) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") - migration12(statement) + if (v < 2) { + migration12(statement) + } + if (v < 3) { + migration23(statement) + } case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -128,4 +138,27 @@ class SqlitePeersDb(val sqlite: Connection) extends PeersDb with Logging { ) } } + + override def updateStorage(nodeId: PublicKey, data: ByteVector): Unit = withMetrics("peers/update-storage", DbBackends.Sqlite) { + using(sqlite.prepareStatement("UPDATE peer_storage SET data = ? WHERE node_id = ?")) { update => + update.setBytes(1, data.toArray) + update.setBytes(2, nodeId.value.toArray) + if (update.executeUpdate() == 0) { + using(sqlite.prepareStatement("INSERT INTO peer_storage VALUES (?, ?)")) { statement => + statement.setBytes(1, nodeId.value.toArray) + statement.setBytes(2, data.toArray) + statement.executeUpdate() + } + } + } + } + + override def getStorage(nodeId: PublicKey): Option[ByteVector] = withMetrics("peers/get-storage", DbBackends.Sqlite) { + using(sqlite.prepareStatement("SELECT data FROM peer_storage WHERE node_id = ?")) { statement => + statement.setBytes(1, nodeId.value.toArray) + statement.executeQuery() + .headOption + .map(rs => ByteVector(rs.getBytes("data"))) + } + } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala index 68e31f63ba..e2a77a33ca 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala @@ -44,7 +44,8 @@ import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol import fr.acinq.eclair.wire.protocol.FailureMessageCodecs.createBadOnionFailure -import fr.acinq.eclair.wire.protocol.{AddFeeCredit, ChannelTlv, CurrentFeeCredit, Error, HasChannelId, HasTemporaryChannelId, LightningMessage, LiquidityAds, NodeAddress, OnTheFlyFundingFailureMessage, OnionMessage, OnionRoutingPacket, RoutingMessage, SpliceInit, TlvStream, UnknownMessage, Warning, WillAddHtlc, WillFailHtlc, WillFailMalformedHtlc} +import fr.acinq.eclair.wire.protocol.{AddFeeCredit, ChannelTlv, CurrentFeeCredit, Error, HasChannelId, HasTemporaryChannelId, LightningMessage, LiquidityAds, NodeAddress, OnTheFlyFundingFailureMessage, OnionMessage, OnionRoutingPacket, PeerStorageRetrieval, PeerStorageStore, RoutingMessage, SpliceInit, TlvStream, UnknownMessage, Warning, WillAddHtlc, WillFailHtlc, WillFailMalformedHtlc} +import scodec.bits.ByteVector /** * This actor represents a logical peer. There is one [[Peer]] per unique remote node id at all time. @@ -84,7 +85,7 @@ class Peer(val nodeParams: NodeParams, FinalChannelId(state.channelId) -> channel }.toMap context.system.eventStream.publish(PeerCreated(self, remoteNodeId)) - goto(DISCONNECTED) using DisconnectedData(channels) // when we restart, we will attempt to reconnect right away, but then we'll wait + goto(DISCONNECTED) using DisconnectedData(channels, PeerStorage(nodeParams.db.peers.getStorage(remoteNodeId), written = true, TimestampMilli.min)) // when we restart, we will attempt to reconnect right away, but then we'll wait } when(DISCONNECTED) { @@ -93,7 +94,7 @@ class Peer(val nodeParams: NodeParams, stay() case Event(connectionReady: PeerConnection.ConnectionReady, d: DisconnectedData) => - gotoConnected(connectionReady, d.channels.map { case (k: ChannelId, v) => (k, v) }) + gotoConnected(connectionReady, d.channels.map { case (k: ChannelId, v) => (k, v) }, d.peerStorage) case Event(Terminated(actor), d: DisconnectedData) if d.channels.values.toSet.contains(actor) => // we have at most 2 ids: a TemporaryChannelId and a FinalChannelId @@ -454,7 +455,7 @@ class Peer(val nodeParams: NodeParams, stopPeer() } else { d.channels.values.toSet[ActorRef].foreach(_ ! INPUT_DISCONNECTED) // we deduplicate with toSet because there might be two entries per channel (tmp id and final id) - goto(DISCONNECTED) using DisconnectedData(d.channels.collect { case (k: FinalChannelId, v) => (k, v) }) + goto(DISCONNECTED) using DisconnectedData(d.channels.collect { case (k: FinalChannelId, v) => (k, v) }, d.peerStorage) } case Event(Terminated(actor), d: ConnectedData) if d.channels.values.toSet.contains(actor) => @@ -473,7 +474,7 @@ class Peer(val nodeParams: NodeParams, log.debug(s"got new connection, killing current one and switching") d.peerConnection ! PeerConnection.Kill(KillReason.ConnectionReplaced) d.channels.values.toSet[ActorRef].foreach(_ ! INPUT_DISCONNECTED) // we deduplicate with toSet because there might be two entries per channel (tmp id and final id) - gotoConnected(connectionReady, d.channels) + gotoConnected(connectionReady, d.channels, d.peerStorage) case Event(msg: OnionMessage, _: ConnectedData) => OnionMessages.process(nodeParams.privateKey, msg) match { @@ -506,6 +507,21 @@ class Peer(val nodeParams: NodeParams, d.peerConnection forward unknownMsg stay() + case Event(store: PeerStorageStore, d: ConnectedData) if nodeParams.features.hasFeature(Features.ProvideStorage) && d.channels.nonEmpty => + val timeSinceLastWrite = TimestampMilli.now() - d.peerStorage.lastWrite + val peerStorage = if (timeSinceLastWrite >= nodeParams.peerStorageWriteDelayMax) { + nodeParams.db.peers.updateStorage(remoteNodeId, store.blob) + PeerStorage(Some(store.blob), written = true, TimestampMilli.now()) + } else { + startSingleTimer("peer-storage-write", WritePeerStorage, nodeParams.peerStorageWriteDelayMax - timeSinceLastWrite) + PeerStorage(Some(store.blob), written = false, d.peerStorage.lastWrite) + } + stay() using d.copy(peerStorage = peerStorage) + + case Event(WritePeerStorage, d: ConnectedData) => + d.peerStorage.data.foreach(nodeParams.db.peers.updateStorage(remoteNodeId, _)) + stay() using d.copy(peerStorage = PeerStorage(d.peerStorage.data, written = true, TimestampMilli.now())) + case Event(unhandledMsg: LightningMessage, _) => log.warning("ignoring message {}", unhandledMsg) stay() @@ -716,7 +732,7 @@ class Peer(val nodeParams: NodeParams, context.system.eventStream.publish(PeerDisconnected(self, remoteNodeId)) } - private def gotoConnected(connectionReady: PeerConnection.ConnectionReady, channels: Map[ChannelId, ActorRef]): State = { + private def gotoConnected(connectionReady: PeerConnection.ConnectionReady, channels: Map[ChannelId, ActorRef], peerStorage: PeerStorage): State = { require(remoteNodeId == connectionReady.remoteNodeId, s"invalid nodeId: $remoteNodeId != ${connectionReady.remoteNodeId}") log.debug("got authenticated connection to address {}", connectionReady.address) @@ -726,6 +742,9 @@ class Peer(val nodeParams: NodeParams, nodeParams.db.peers.addOrUpdatePeer(remoteNodeId, connectionReady.address) } + // If we have some data stored from our peer, we send it to them before doing anything else. + peerStorage.data.foreach(connectionReady.peerConnection ! PeerStorageRetrieval(_)) + // let's bring existing/requested channels online channels.values.toSet[ActorRef].foreach(_ ! INPUT_RECONNECTED(connectionReady.peerConnection, connectionReady.localInit, connectionReady.remoteInit)) // we deduplicate with toSet because there might be two entries per channel (tmp id and final id) @@ -742,7 +761,7 @@ class Peer(val nodeParams: NodeParams, connectionReady.peerConnection ! CurrentFeeCredit(nodeParams.chainHash, feeCredit.getOrElse(0 msat)) } - goto(CONNECTED) using ConnectedData(connectionReady.address, connectionReady.peerConnection, connectionReady.localInit, connectionReady.remoteInit, channels) + goto(CONNECTED) using ConnectedData(connectionReady.address, connectionReady.peerConnection, connectionReady.localInit, connectionReady.remoteInit, channels, peerStorage) } /** @@ -877,12 +896,18 @@ object Peer { case class TemporaryChannelId(id: ByteVector32) extends ChannelId case class FinalChannelId(id: ByteVector32) extends ChannelId + case class PeerStorage(data: Option[ByteVector], written: Boolean, lastWrite: TimestampMilli) + sealed trait Data { def channels: Map[_ <: ChannelId, ActorRef] // will be overridden by Map[FinalChannelId, ActorRef] or Map[ChannelId, ActorRef] + def peerStorage: PeerStorage } - case object Nothing extends Data { override def channels = Map.empty } - case class DisconnectedData(channels: Map[FinalChannelId, ActorRef]) extends Data - case class ConnectedData(address: NodeAddress, peerConnection: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init, channels: Map[ChannelId, ActorRef]) extends Data { + case object Nothing extends Data { + override def channels = Map.empty + override def peerStorage: PeerStorage = PeerStorage(None, written = true, TimestampMilli.min) + } + case class DisconnectedData(channels: Map[FinalChannelId, ActorRef], peerStorage: PeerStorage) extends Data + case class ConnectedData(address: NodeAddress, peerConnection: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init, channels: Map[ChannelId, ActorRef], peerStorage: PeerStorage) extends Data { val connectionInfo: ConnectionInfo = ConnectionInfo(address, peerConnection, localInit, remoteInit) def localFeatures: Features[InitFeature] = localInit.features def remoteFeatures: Features[InitFeature] = remoteInit.features @@ -993,5 +1018,7 @@ object Peer { case class RelayOnionMessage(messageId: ByteVector32, msg: OnionMessage, replyTo_opt: Option[typed.ActorRef[Status]]) case class RelayUnknownMessage(unknownMessage: UnknownMessage) + + case object WritePeerStorage // @formatter:on } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecs.scala index b422d8598c..1c9c5869a5 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecs.scala @@ -22,7 +22,7 @@ import fr.acinq.eclair.wire.protocol.CommonCodecs._ import fr.acinq.eclair.{Features, InitFeature, KamonExt} import scodec.bits.{BinStringSyntax, BitVector, ByteVector} import scodec.codecs._ -import scodec.{Attempt, Codec} +import scodec.{Attempt, Codec, Err} /** * Created by PM on 15/11/2016. @@ -389,6 +389,17 @@ object LightningMessageCodecs { ("onionPacket" | MessageOnionCodecs.messageOnionPacketCodec) :: ("tlvStream" | OnionMessageTlv.onionMessageTlvCodec)).as[OnionMessage] + private def isAcceptableBlobLength(length: Int) = + if (length <= 65531) Attempt.Successful(length) else Attempt.failure(Err(s"length $length is larger than 65531")) + + val peerStorageStore: Codec[PeerStorageStore] = ( + ("blob" | variableSizeBytes(uint16.exmap(isAcceptableBlobLength, isAcceptableBlobLength), bytes)) :: + ("tlvStream" | PeerStorageTlv.peerStorageTlvCodec)).as[PeerStorageStore] + + val peerStorageRetrieval: Codec[PeerStorageRetrieval] = ( + ("blob" | variableSizeBytes(uint16.exmap(isAcceptableBlobLength, isAcceptableBlobLength), bytes)) :: + ("tlvStream" | PeerStorageTlv.peerStorageTlvCodec)).as[PeerStorageRetrieval] + // NB: blank lines to minimize merge conflicts // @@ -476,6 +487,8 @@ object LightningMessageCodecs { val lightningMessageCodec = discriminated[LightningMessage].by(uint16) .typecase(1, warningCodec) .typecase(2, stfuCodec) + .typecase(7, peerStorageStore) + .typecase(9, peerStorageRetrieval) .typecase(16, initCodec) .typecase(17, errorCodec) .typecase(18, pingCodec) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala index e045da1475..4d6a985ccb 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala @@ -600,6 +600,10 @@ case class GossipTimestampFilter(chainHash: BlockHash, firstTimestamp: Timestamp case class OnionMessage(blindingKey: PublicKey, onionRoutingPacket: OnionRoutingPacket, tlvStream: TlvStream[OnionMessageTlv] = TlvStream.empty) extends LightningMessage +case class PeerStorageStore(blob: ByteVector, tlvStream: TlvStream[PeerStorageTlv] = TlvStream.empty) extends LightningMessage + +case class PeerStorageRetrieval(blob: ByteVector, tlvStream: TlvStream[PeerStorageTlv] = TlvStream.empty) extends LightningMessage + // NB: blank lines to minimize merge conflicts // diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PeerStorageTlv.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PeerStorageTlv.scala new file mode 100644 index 0000000000..4ebb3fec39 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/PeerStorageTlv.scala @@ -0,0 +1,32 @@ +/* + * Copyright 2021 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.wire.protocol + +import fr.acinq.eclair.wire.protocol.CommonCodecs.varint +import fr.acinq.eclair.wire.protocol.TlvCodecs.tlvStream +import scodec.Codec +import scodec.codecs.discriminated + +/** + * Created by thomash on July 2024. + */ + +sealed trait PeerStorageTlv extends Tlv + +object PeerStorageTlv { + val peerStorageTlvCodec: Codec[TlvStream[PeerStorageTlv]] = tlvStream(discriminated[PeerStorageTlv].by(varint)) +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala index 6f56850eb0..8a27a0753e 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala @@ -106,6 +106,7 @@ object TestConstants { Features.PaymentMetadata -> FeatureSupport.Optional, Features.RouteBlinding -> FeatureSupport.Optional, Features.StaticRemoteKey -> FeatureSupport.Mandatory, + Features.ProvideStorage -> FeatureSupport.Optional, ), unknown = Set(UnknownFeature(TestFeature.optional)) ), @@ -238,6 +239,7 @@ object TestConstants { willFundRates_opt = Some(defaultLiquidityRates), peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(enabled = false, timeout = 30 seconds), onTheFlyFundingConfig = OnTheFlyFunding.Config(proposalTimeout = 90 seconds), + peerStorageWriteDelayMax = 5 seconds, ) def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams( @@ -412,6 +414,7 @@ object TestConstants { willFundRates_opt = Some(defaultLiquidityRates), peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(enabled = false, timeout = 30 seconds), onTheFlyFundingConfig = OnTheFlyFunding.Config(proposalTimeout = 90 seconds), + peerStorageWriteDelayMax = 5 seconds, ) def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams( diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/PeersDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PeersDbSpec.scala index 848b946f04..76f4719014 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/PeersDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/PeersDbSpec.scala @@ -24,6 +24,7 @@ import fr.acinq.eclair.payment.relay.Relayer.RelayFees import fr.acinq.eclair._ import fr.acinq.eclair.wire.protocol.{NodeAddress, Tor2, Tor3} import org.scalatest.funsuite.AnyFunSuite +import scodec.bits.HexStringSyntax import java.util.concurrent.Executors import scala.concurrent.duration._ @@ -107,4 +108,24 @@ class PeersDbSpec extends AnyFunSuite { } } + test("peer storage") { + forAllDbs { dbs => + val db = dbs.peers + + val a = randomKey().publicKey + val b = randomKey().publicKey + + assert(db.getStorage(a) == None) + assert(db.getStorage(b) == None) + db.updateStorage(a, hex"012345") + assert(db.getStorage(a) == Some(hex"012345")) + assert(db.getStorage(b) == None) + db.updateStorage(a, hex"6789") + assert(db.getStorage(a) == Some(hex"6789")) + assert(db.getStorage(b) == None) + db.updateStorage(b, hex"abcd") + assert(db.getStorage(a) == Some(hex"6789")) + assert(db.getStorage(b) == Some(hex"abcd")) + } + } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala index 7420825a2b..a16a93cb16 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala @@ -37,7 +37,7 @@ import fr.acinq.eclair.wire.internal.channel.ChannelCodecsSpec.localParams import fr.acinq.eclair.wire.protocol import fr.acinq.eclair.wire.protocol._ import org.scalatest.{Tag, TestData} -import scodec.bits.ByteVector +import scodec.bits.{ByteVector, HexStringSyntax} import java.net.InetSocketAddress import java.nio.channels.ServerSocketChannel @@ -107,11 +107,14 @@ class PeerSpec extends FixtureSpec { def cleanupFixture(fixture: FixtureParam): Unit = fixture.cleanup() - def connect(remoteNodeId: PublicKey, peer: TestFSMRef[Peer.State, Peer.Data, Peer], peerConnection: TestProbe, switchboard: TestProbe, channels: Set[PersistentChannelData] = Set.empty, remoteInit: protocol.Init = protocol.Init(Bob.nodeParams.features.initFeatures()))(implicit system: ActorSystem): Unit = { + def connect(remoteNodeId: PublicKey, peer: TestFSMRef[Peer.State, Peer.Data, Peer], peerConnection: TestProbe, switchboard: TestProbe, channels: Set[PersistentChannelData] = Set.empty, remoteInit: protocol.Init = protocol.Init(Bob.nodeParams.features.initFeatures()), sendInit: Boolean = true, peerStorage: Option[ByteVector] = None)(implicit system: ActorSystem): Unit = { // let's simulate a connection - switchboard.send(peer, Peer.Init(channels, Map.empty)) + if (sendInit) { + switchboard.send(peer, Peer.Init(channels, Map.empty)) + } val localInit = protocol.Init(peer.underlyingActor.nodeParams.features.initFeatures()) switchboard.send(peer, PeerConnection.ConnectionReady(peerConnection.ref, remoteNodeId, fakeIPAddress, outgoing = true, localInit, remoteInit)) + peerStorage.foreach(data => peerConnection.expectMsg(PeerStorageRetrieval(data))) peerConnection.expectMsgType[RecommendedFeerates] val probe = TestProbe() probe.send(peer, Peer.GetPeerInfo(Some(probe.ref.toTyped))) @@ -675,6 +678,27 @@ class PeerSpec extends FixtureSpec { probe.expectTerminated(peer) } + test("peer storage") { f => + import f._ + + val peerConnection1 = peerConnection + val peerConnection2 = TestProbe() + val peerConnection3 = TestProbe() + + nodeParams.db.peers.updateStorage(remoteNodeId, hex"abcdef") + connect(remoteNodeId, peer, peerConnection1, switchboard, channels = Set(ChannelCodecsSpec.normal), peerStorage = Some(hex"abcdef")) + peerConnection1.send(peer, PeerStorageStore(hex"c0ffee")) + peerConnection1.send(peer, PeerStorageStore(hex"0123456789")) + Thread.sleep(1000) + peer ! Peer.Disconnect(f.remoteNodeId) + connect(remoteNodeId, peer, peerConnection2, switchboard, channels = Set(ChannelCodecsSpec.normal), sendInit = false, peerStorage = Some(hex"0123456789")) + peerConnection2.send(peer, PeerStorageStore(hex"1111")) + connect(remoteNodeId, peer, peerConnection3, switchboard, channels = Set(ChannelCodecsSpec.normal), sendInit = false, peerStorage = Some(hex"1111")) + assert(nodeParams.db.peers.getStorage(remoteNodeId).contains(hex"c0ffee")) // Only the first update was written because of the rate limit. + Thread.sleep(5_000) + assert(nodeParams.db.peers.getStorage(remoteNodeId).contains(hex"1111")) + } + } object PeerSpec { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala index 1e4721819b..22a934ef99 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala @@ -19,7 +19,7 @@ package fr.acinq.eclair.io import akka.testkit.{TestFSMRef, TestProbe} import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair._ -import fr.acinq.eclair.io.Peer.ChannelId +import fr.acinq.eclair.io.Peer.{ChannelId, PeerStorage} import fr.acinq.eclair.io.ReconnectionTask.WaitingData import fr.acinq.eclair.tor.Socks5ProxyParams import fr.acinq.eclair.wire.protocol.{Color, IPv4, NodeAddress, NodeAnnouncement} @@ -37,8 +37,8 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike private val channels = Map(Peer.FinalChannelId(randomBytes32()) -> system.deadLetters) private val PeerNothingData = Peer.Nothing - private val PeerDisconnectedData = Peer.DisconnectedData(channels) - private val PeerConnectedData = Peer.ConnectedData(fakeIPAddress, system.deadLetters, null, null, channels.map { case (k: ChannelId, v) => (k, v) }) + private val PeerDisconnectedData = Peer.DisconnectedData(channels, PeerStorage(None, written = true, TimestampMilli.min)) + private val PeerConnectedData = Peer.ConnectedData(fakeIPAddress, system.deadLetters, null, null, channels.map { case (k: ChannelId, v) => (k, v) }, PeerStorage(None, written = true, TimestampMilli.min)) case class FixtureParam(nodeParams: NodeParams, remoteNodeId: PublicKey, reconnectionTask: TestFSMRef[ReconnectionTask.State, ReconnectionTask.Data, ReconnectionTask], monitor: TestProbe) @@ -81,7 +81,7 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike import f._ val peer = TestProbe() - peer.send(reconnectionTask, Peer.Transition(PeerNothingData, Peer.DisconnectedData(Map.empty))) + peer.send(reconnectionTask, Peer.Transition(PeerNothingData, Peer.DisconnectedData(Map.empty, PeerStorage(None, written = true, TimestampMilli.min)))) monitor.expectNoMessage() }