From 26118a73e8ec586a3abc2d98e3cab1affa018acf Mon Sep 17 00:00:00 2001 From: Thomas HUET Date: Wed, 2 Oct 2024 11:41:15 +0200 Subject: [PATCH] Write peer storage to DB --- .../scala/fr/acinq/eclair/NodeParams.scala | 4 +- .../fr/acinq/eclair/db/DualDatabases.scala | 11 +++++ .../scala/fr/acinq/eclair/db/PeersDb.scala | 5 ++ .../fr/acinq/eclair/db/pg/PgPeersDb.scala | 47 +++++++++++++++++-- .../eclair/db/sqlite/SqlitePeersDb.scala | 41 ++++++++++++++-- .../main/scala/fr/acinq/eclair/io/Peer.scala | 36 +++++++++----- .../scala/fr/acinq/eclair/TestConstants.scala | 3 ++ .../fr/acinq/eclair/db/PeersDbSpec.scala | 21 +++++++++ .../scala/fr/acinq/eclair/io/PeerSpec.scala | 30 ++++++++++-- .../eclair/io/ReconnectionTaskSpec.scala | 8 ++-- 10 files changed, 179 insertions(+), 27 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala index 55a0b6fc8b..4c7af077c9 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala @@ -92,7 +92,8 @@ case class NodeParams(nodeKeyManager: NodeKeyManager, revokedHtlcInfoCleanerConfig: RevokedHtlcInfoCleaner.Config, willFundRates_opt: Option[LiquidityAds.WillFundRates], peerWakeUpConfig: PeerReadyNotifier.WakeUpConfig, - onTheFlyFundingConfig: OnTheFlyFunding.Config) { + onTheFlyFundingConfig: OnTheFlyFunding.Config, + peerStorageWriteDelayMax: FiniteDuration) { val privateKey: Crypto.PrivateKey = nodeKeyManager.nodeKey.privateKey val nodeId: PublicKey = nodeKeyManager.nodeId @@ -678,6 +679,7 @@ object NodeParams extends Logging { onTheFlyFundingConfig = OnTheFlyFunding.Config( proposalTimeout = FiniteDuration(config.getDuration("on-the-fly-funding.proposal-timeout").getSeconds, TimeUnit.SECONDS), ), + peerStorageWriteDelayMax = 1 minute, ) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala index d1ff3487ea..05219d7eec 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala @@ -15,6 +15,7 @@ import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, NodeAddress, NodeAnnouncement} import fr.acinq.eclair.{CltvExpiry, MilliSatoshi, Paginated, RealShortChannelId, ShortChannelId, TimestampMilli} import grizzled.slf4j.Logging +import scodec.bits.ByteVector import java.io.File import java.util.UUID @@ -292,6 +293,16 @@ case class DualPeersDb(primary: PeersDb, secondary: PeersDb) extends PeersDb { runAsync(secondary.getRelayFees(nodeId)) primary.getRelayFees(nodeId) } + + override def updateStorage(nodeId: PublicKey, data: ByteVector): Unit = { + runAsync(secondary.updateStorage(nodeId, data)) + primary.updateStorage(nodeId, data) + } + + override def getStorage(nodeId: PublicKey): Option[ByteVector] = { + runAsync(secondary.getStorage(nodeId)) + primary.getStorage(nodeId) + } } case class DualPaymentsDb(primary: PaymentsDb, secondary: PaymentsDb) extends PaymentsDb { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PeersDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PeersDb.scala index ea10f348e8..b1824071d0 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PeersDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PeersDb.scala @@ -19,6 +19,7 @@ package fr.acinq.eclair.db import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.payment.relay.Relayer.RelayFees import fr.acinq.eclair.wire.protocol.NodeAddress +import scodec.bits.ByteVector trait PeersDb { @@ -34,4 +35,8 @@ trait PeersDb { def getRelayFees(nodeId: PublicKey): Option[RelayFees] + def updateStorage(nodeId: PublicKey, data: ByteVector): Unit + + def getStorage(nodeId: PublicKey): Option[ByteVector] + } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala index 45094dd4a9..76102f9482 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala @@ -26,14 +26,14 @@ import fr.acinq.eclair.db.pg.PgUtils.PgLock import fr.acinq.eclair.payment.relay.Relayer.RelayFees import fr.acinq.eclair.wire.protocol._ import grizzled.slf4j.Logging -import scodec.bits.BitVector +import scodec.bits.{BitVector, ByteVector} import java.sql.Statement import javax.sql.DataSource object PgPeersDb { val DB_NAME = "peers" - val CURRENT_VERSION = 3 + val CURRENT_VERSION = 4 } class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logging { @@ -54,13 +54,18 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logg statement.executeUpdate("CREATE TABLE local.relay_fees (node_id TEXT NOT NULL PRIMARY KEY, fee_base_msat BIGINT NOT NULL, fee_proportional_millionths BIGINT NOT NULL)") } + def migration34(statement: Statement): Unit = { + statement.executeUpdate("CREATE TABLE local.peer_storage (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)") + } + using(pg.createStatement()) { statement => getVersion(statement, DB_NAME) match { case None => statement.executeUpdate("CREATE SCHEMA IF NOT EXISTS local") - statement.executeUpdate("CREATE TABLE local.peers (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)") + statement.executeUpdate("CREATE TABLE local.peers (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, storage BYTEA)") statement.executeUpdate("CREATE TABLE local.relay_fees (node_id TEXT NOT NULL PRIMARY KEY, fee_base_msat BIGINT NOT NULL, fee_proportional_millionths BIGINT NOT NULL)") - case Some(v@(1 | 2)) => + statement.executeUpdate("CREATE TABLE local.peer_storage (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)") + case Some(v@(1 | 2 | 3)) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") if (v < 2) { migration12(statement) @@ -68,6 +73,9 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logg if (v < 3) { migration23(statement) } + if (v < 4) { + migration34(statement) + } case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -98,6 +106,10 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logg statement.setString(1, nodeId.value.toHex) statement.executeUpdate() } + using(pg.prepareStatement("DELETE FROM local.peer_storage WHERE node_id = ?")) { statement => + statement.setString(1, nodeId.value.toHex) + statement.executeUpdate() + } } } @@ -155,4 +167,31 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logg } } } + + override def updateStorage(nodeId: PublicKey, data: ByteVector): Unit = withMetrics("peers/update-storage", DbBackends.Postgres) { + withLock { pg => + using(pg.prepareStatement( + """ + INSERT INTO local.peer_storage (node_id, data) + VALUES (?, ?) + ON CONFLICT (node_id) + DO UPDATE SET data = EXCLUDED.data + """)) { statement => + statement.setString(1, nodeId.value.toHex) + statement.setBytes(2, data.toArray) + statement.executeUpdate() + } + } + } + + override def getStorage(nodeId: PublicKey): Option[ByteVector] = withMetrics("peers/get-storage", DbBackends.Postgres) { + withLock { pg => + using(pg.prepareStatement("SELECT data FROM local.peer_storage WHERE node_id = ?")) { statement => + statement.setString(1, nodeId.value.toHex) + statement.executeQuery() + .headOption + .map(rs => ByteVector(rs.getBytes("data"))) + } + } + } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala index 610bb07909..ba99c46ad1 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala @@ -26,13 +26,13 @@ import fr.acinq.eclair.db.sqlite.SqliteUtils.{getVersion, setVersion, using} import fr.acinq.eclair.payment.relay.Relayer.RelayFees import fr.acinq.eclair.wire.protocol._ import grizzled.slf4j.Logging -import scodec.bits.BitVector +import scodec.bits.{BitVector, ByteVector} import java.sql.{Connection, Statement} object SqlitePeersDb { val DB_NAME = "peers" - val CURRENT_VERSION = 2 + val CURRENT_VERSION = 3 } class SqlitePeersDb(val sqlite: Connection) extends PeersDb with Logging { @@ -46,13 +46,23 @@ class SqlitePeersDb(val sqlite: Connection) extends PeersDb with Logging { statement.executeUpdate("CREATE TABLE relay_fees (node_id BLOB NOT NULL PRIMARY KEY, fee_base_msat INTEGER NOT NULL, fee_proportional_millionths INTEGER NOT NULL)") } + def migration23(statement: Statement): Unit = { + statement.executeUpdate("CREATE TABLE peer_storage (node_id BLOB NOT NULL PRIMARY KEY, data NOT NULL)") + } + getVersion(statement, DB_NAME) match { case None => statement.executeUpdate("CREATE TABLE peers (node_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)") statement.executeUpdate("CREATE TABLE relay_fees (node_id BLOB NOT NULL PRIMARY KEY, fee_base_msat INTEGER NOT NULL, fee_proportional_millionths INTEGER NOT NULL)") - case Some(v@1) => + statement.executeUpdate("CREATE TABLE peer_storage (node_id BLOB NOT NULL PRIMARY KEY, data NOT NULL)") + case Some(v@(1 | 2)) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") - migration12(statement) + if (v < 2) { + migration12(statement) + } + if (v < 3) { + migration23(statement) + } case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -128,4 +138,27 @@ class SqlitePeersDb(val sqlite: Connection) extends PeersDb with Logging { ) } } + + override def updateStorage(nodeId: PublicKey, data: ByteVector): Unit = withMetrics("peers/update-storage", DbBackends.Sqlite) { + using(sqlite.prepareStatement("UPDATE peer_storage SET data = ? WHERE node_id = ?")) { update => + update.setBytes(1, data.toArray) + update.setBytes(2, nodeId.value.toArray) + if (update.executeUpdate() == 0) { + using(sqlite.prepareStatement("INSERT INTO peer_storage VALUES (?, ?)")) { statement => + statement.setBytes(1, nodeId.value.toArray) + statement.setBytes(2, data.toArray) + statement.executeUpdate() + } + } + } + } + + override def getStorage(nodeId: PublicKey): Option[ByteVector] = withMetrics("peers/get-storage", DbBackends.Sqlite) { + using(sqlite.prepareStatement("SELECT data FROM peer_storage WHERE node_id = ?")) { statement => + statement.setBytes(1, nodeId.value.toArray) + statement.executeQuery() + .headOption + .map(rs => ByteVector(rs.getBytes("data"))) + } + } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala index ce0f1dc12a..e2a77a33ca 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala @@ -44,9 +44,7 @@ import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol import fr.acinq.eclair.wire.protocol.FailureMessageCodecs.createBadOnionFailure -import fr.acinq.eclair.wire.protocol.{AddFeeCredit, ChannelTlv, CurrentFeeCredit, Error, HasChannelId, HasTemporaryChannelId, LightningMessage, LiquidityAds, NodeAddress, OnTheFlyFundingFailureMessage, OnionMessage, OnionRoutingPacket, RoutingMessage, SpliceInit, TlvStream, UnknownMessage, Warning, WillAddHtlc, WillFailHtlc, WillFailMalformedHtlc} -import fr.acinq.eclair.wire.protocol.LiquidityAds.PaymentDetails -import fr.acinq.eclair.wire.protocol.{Error, HasChannelId, HasTemporaryChannelId, LightningMessage, LiquidityAds, NodeAddress, OnTheFlyFundingFailureMessage, OnionMessage, OnionRoutingPacket, PeerStorageRetrieval, PeerStorageStore, RoutingMessage, SpliceInit, UnknownMessage, Warning, WillAddHtlc, WillFailHtlc, WillFailMalformedHtlc} +import fr.acinq.eclair.wire.protocol.{AddFeeCredit, ChannelTlv, CurrentFeeCredit, Error, HasChannelId, HasTemporaryChannelId, LightningMessage, LiquidityAds, NodeAddress, OnTheFlyFundingFailureMessage, OnionMessage, OnionRoutingPacket, PeerStorageRetrieval, PeerStorageStore, RoutingMessage, SpliceInit, TlvStream, UnknownMessage, Warning, WillAddHtlc, WillFailHtlc, WillFailMalformedHtlc} import scodec.bits.ByteVector /** @@ -87,7 +85,7 @@ class Peer(val nodeParams: NodeParams, FinalChannelId(state.channelId) -> channel }.toMap context.system.eventStream.publish(PeerCreated(self, remoteNodeId)) - goto(DISCONNECTED) using DisconnectedData(channels, None) // when we restart, we will attempt to reconnect right away, but then we'll wait + goto(DISCONNECTED) using DisconnectedData(channels, PeerStorage(nodeParams.db.peers.getStorage(remoteNodeId), written = true, TimestampMilli.min)) // when we restart, we will attempt to reconnect right away, but then we'll wait } when(DISCONNECTED) { @@ -510,7 +508,19 @@ class Peer(val nodeParams: NodeParams, stay() case Event(store: PeerStorageStore, d: ConnectedData) if nodeParams.features.hasFeature(Features.ProvideStorage) && d.channels.nonEmpty => - stay() using d.copy(peerStorage = Some(store.blob)) + val timeSinceLastWrite = TimestampMilli.now() - d.peerStorage.lastWrite + val peerStorage = if (timeSinceLastWrite >= nodeParams.peerStorageWriteDelayMax) { + nodeParams.db.peers.updateStorage(remoteNodeId, store.blob) + PeerStorage(Some(store.blob), written = true, TimestampMilli.now()) + } else { + startSingleTimer("peer-storage-write", WritePeerStorage, nodeParams.peerStorageWriteDelayMax - timeSinceLastWrite) + PeerStorage(Some(store.blob), written = false, d.peerStorage.lastWrite) + } + stay() using d.copy(peerStorage = peerStorage) + + case Event(WritePeerStorage, d: ConnectedData) => + d.peerStorage.data.foreach(nodeParams.db.peers.updateStorage(remoteNodeId, _)) + stay() using d.copy(peerStorage = PeerStorage(d.peerStorage.data, written = true, TimestampMilli.now())) case Event(unhandledMsg: LightningMessage, _) => log.warning("ignoring message {}", unhandledMsg) @@ -722,7 +732,7 @@ class Peer(val nodeParams: NodeParams, context.system.eventStream.publish(PeerDisconnected(self, remoteNodeId)) } - private def gotoConnected(connectionReady: PeerConnection.ConnectionReady, channels: Map[ChannelId, ActorRef], peerStorage: Option[ByteVector]): State = { + private def gotoConnected(connectionReady: PeerConnection.ConnectionReady, channels: Map[ChannelId, ActorRef], peerStorage: PeerStorage): State = { require(remoteNodeId == connectionReady.remoteNodeId, s"invalid nodeId: $remoteNodeId != ${connectionReady.remoteNodeId}") log.debug("got authenticated connection to address {}", connectionReady.address) @@ -733,7 +743,7 @@ class Peer(val nodeParams: NodeParams, } // If we have some data stored from our peer, we send it to them before doing anything else. - peerStorage.foreach(connectionReady.peerConnection ! PeerStorageRetrieval(_)) + peerStorage.data.foreach(connectionReady.peerConnection ! PeerStorageRetrieval(_)) // let's bring existing/requested channels online channels.values.toSet[ActorRef].foreach(_ ! INPUT_RECONNECTED(connectionReady.peerConnection, connectionReady.localInit, connectionReady.remoteInit)) // we deduplicate with toSet because there might be two entries per channel (tmp id and final id) @@ -886,16 +896,18 @@ object Peer { case class TemporaryChannelId(id: ByteVector32) extends ChannelId case class FinalChannelId(id: ByteVector32) extends ChannelId + case class PeerStorage(data: Option[ByteVector], written: Boolean, lastWrite: TimestampMilli) + sealed trait Data { def channels: Map[_ <: ChannelId, ActorRef] // will be overridden by Map[FinalChannelId, ActorRef] or Map[ChannelId, ActorRef] - def peerStorage: Option[ByteVector] + def peerStorage: PeerStorage } case object Nothing extends Data { override def channels = Map.empty - override def peerStorage: Option[ByteVector] = None + override def peerStorage: PeerStorage = PeerStorage(None, written = true, TimestampMilli.min) } - case class DisconnectedData(channels: Map[FinalChannelId, ActorRef], peerStorage: Option[ByteVector]) extends Data - case class ConnectedData(address: NodeAddress, peerConnection: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init, channels: Map[ChannelId, ActorRef], peerStorage: Option[ByteVector]) extends Data { + case class DisconnectedData(channels: Map[FinalChannelId, ActorRef], peerStorage: PeerStorage) extends Data + case class ConnectedData(address: NodeAddress, peerConnection: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init, channels: Map[ChannelId, ActorRef], peerStorage: PeerStorage) extends Data { val connectionInfo: ConnectionInfo = ConnectionInfo(address, peerConnection, localInit, remoteInit) def localFeatures: Features[InitFeature] = localInit.features def remoteFeatures: Features[InitFeature] = remoteInit.features @@ -1006,5 +1018,7 @@ object Peer { case class RelayOnionMessage(messageId: ByteVector32, msg: OnionMessage, replyTo_opt: Option[typed.ActorRef[Status]]) case class RelayUnknownMessage(unknownMessage: UnknownMessage) + + case object WritePeerStorage // @formatter:on } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala index 6f56850eb0..8a27a0753e 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala @@ -106,6 +106,7 @@ object TestConstants { Features.PaymentMetadata -> FeatureSupport.Optional, Features.RouteBlinding -> FeatureSupport.Optional, Features.StaticRemoteKey -> FeatureSupport.Mandatory, + Features.ProvideStorage -> FeatureSupport.Optional, ), unknown = Set(UnknownFeature(TestFeature.optional)) ), @@ -238,6 +239,7 @@ object TestConstants { willFundRates_opt = Some(defaultLiquidityRates), peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(enabled = false, timeout = 30 seconds), onTheFlyFundingConfig = OnTheFlyFunding.Config(proposalTimeout = 90 seconds), + peerStorageWriteDelayMax = 5 seconds, ) def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams( @@ -412,6 +414,7 @@ object TestConstants { willFundRates_opt = Some(defaultLiquidityRates), peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(enabled = false, timeout = 30 seconds), onTheFlyFundingConfig = OnTheFlyFunding.Config(proposalTimeout = 90 seconds), + peerStorageWriteDelayMax = 5 seconds, ) def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams( diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/PeersDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PeersDbSpec.scala index 848b946f04..76f4719014 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/PeersDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/PeersDbSpec.scala @@ -24,6 +24,7 @@ import fr.acinq.eclair.payment.relay.Relayer.RelayFees import fr.acinq.eclair._ import fr.acinq.eclair.wire.protocol.{NodeAddress, Tor2, Tor3} import org.scalatest.funsuite.AnyFunSuite +import scodec.bits.HexStringSyntax import java.util.concurrent.Executors import scala.concurrent.duration._ @@ -107,4 +108,24 @@ class PeersDbSpec extends AnyFunSuite { } } + test("peer storage") { + forAllDbs { dbs => + val db = dbs.peers + + val a = randomKey().publicKey + val b = randomKey().publicKey + + assert(db.getStorage(a) == None) + assert(db.getStorage(b) == None) + db.updateStorage(a, hex"012345") + assert(db.getStorage(a) == Some(hex"012345")) + assert(db.getStorage(b) == None) + db.updateStorage(a, hex"6789") + assert(db.getStorage(a) == Some(hex"6789")) + assert(db.getStorage(b) == None) + db.updateStorage(b, hex"abcd") + assert(db.getStorage(a) == Some(hex"6789")) + assert(db.getStorage(b) == Some(hex"abcd")) + } + } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala index 7420825a2b..a16a93cb16 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala @@ -37,7 +37,7 @@ import fr.acinq.eclair.wire.internal.channel.ChannelCodecsSpec.localParams import fr.acinq.eclair.wire.protocol import fr.acinq.eclair.wire.protocol._ import org.scalatest.{Tag, TestData} -import scodec.bits.ByteVector +import scodec.bits.{ByteVector, HexStringSyntax} import java.net.InetSocketAddress import java.nio.channels.ServerSocketChannel @@ -107,11 +107,14 @@ class PeerSpec extends FixtureSpec { def cleanupFixture(fixture: FixtureParam): Unit = fixture.cleanup() - def connect(remoteNodeId: PublicKey, peer: TestFSMRef[Peer.State, Peer.Data, Peer], peerConnection: TestProbe, switchboard: TestProbe, channels: Set[PersistentChannelData] = Set.empty, remoteInit: protocol.Init = protocol.Init(Bob.nodeParams.features.initFeatures()))(implicit system: ActorSystem): Unit = { + def connect(remoteNodeId: PublicKey, peer: TestFSMRef[Peer.State, Peer.Data, Peer], peerConnection: TestProbe, switchboard: TestProbe, channels: Set[PersistentChannelData] = Set.empty, remoteInit: protocol.Init = protocol.Init(Bob.nodeParams.features.initFeatures()), sendInit: Boolean = true, peerStorage: Option[ByteVector] = None)(implicit system: ActorSystem): Unit = { // let's simulate a connection - switchboard.send(peer, Peer.Init(channels, Map.empty)) + if (sendInit) { + switchboard.send(peer, Peer.Init(channels, Map.empty)) + } val localInit = protocol.Init(peer.underlyingActor.nodeParams.features.initFeatures()) switchboard.send(peer, PeerConnection.ConnectionReady(peerConnection.ref, remoteNodeId, fakeIPAddress, outgoing = true, localInit, remoteInit)) + peerStorage.foreach(data => peerConnection.expectMsg(PeerStorageRetrieval(data))) peerConnection.expectMsgType[RecommendedFeerates] val probe = TestProbe() probe.send(peer, Peer.GetPeerInfo(Some(probe.ref.toTyped))) @@ -675,6 +678,27 @@ class PeerSpec extends FixtureSpec { probe.expectTerminated(peer) } + test("peer storage") { f => + import f._ + + val peerConnection1 = peerConnection + val peerConnection2 = TestProbe() + val peerConnection3 = TestProbe() + + nodeParams.db.peers.updateStorage(remoteNodeId, hex"abcdef") + connect(remoteNodeId, peer, peerConnection1, switchboard, channels = Set(ChannelCodecsSpec.normal), peerStorage = Some(hex"abcdef")) + peerConnection1.send(peer, PeerStorageStore(hex"c0ffee")) + peerConnection1.send(peer, PeerStorageStore(hex"0123456789")) + Thread.sleep(1000) + peer ! Peer.Disconnect(f.remoteNodeId) + connect(remoteNodeId, peer, peerConnection2, switchboard, channels = Set(ChannelCodecsSpec.normal), sendInit = false, peerStorage = Some(hex"0123456789")) + peerConnection2.send(peer, PeerStorageStore(hex"1111")) + connect(remoteNodeId, peer, peerConnection3, switchboard, channels = Set(ChannelCodecsSpec.normal), sendInit = false, peerStorage = Some(hex"1111")) + assert(nodeParams.db.peers.getStorage(remoteNodeId).contains(hex"c0ffee")) // Only the first update was written because of the rate limit. + Thread.sleep(5_000) + assert(nodeParams.db.peers.getStorage(remoteNodeId).contains(hex"1111")) + } + } object PeerSpec { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala index 92221ce397..22a934ef99 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala @@ -19,7 +19,7 @@ package fr.acinq.eclair.io import akka.testkit.{TestFSMRef, TestProbe} import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair._ -import fr.acinq.eclair.io.Peer.ChannelId +import fr.acinq.eclair.io.Peer.{ChannelId, PeerStorage} import fr.acinq.eclair.io.ReconnectionTask.WaitingData import fr.acinq.eclair.tor.Socks5ProxyParams import fr.acinq.eclair.wire.protocol.{Color, IPv4, NodeAddress, NodeAnnouncement} @@ -37,8 +37,8 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike private val channels = Map(Peer.FinalChannelId(randomBytes32()) -> system.deadLetters) private val PeerNothingData = Peer.Nothing - private val PeerDisconnectedData = Peer.DisconnectedData(channels, None) - private val PeerConnectedData = Peer.ConnectedData(fakeIPAddress, system.deadLetters, null, null, channels.map { case (k: ChannelId, v) => (k, v) }, None) + private val PeerDisconnectedData = Peer.DisconnectedData(channels, PeerStorage(None, written = true, TimestampMilli.min)) + private val PeerConnectedData = Peer.ConnectedData(fakeIPAddress, system.deadLetters, null, null, channels.map { case (k: ChannelId, v) => (k, v) }, PeerStorage(None, written = true, TimestampMilli.min)) case class FixtureParam(nodeParams: NodeParams, remoteNodeId: PublicKey, reconnectionTask: TestFSMRef[ReconnectionTask.State, ReconnectionTask.Data, ReconnectionTask], monitor: TestProbe) @@ -81,7 +81,7 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike import f._ val peer = TestProbe() - peer.send(reconnectionTask, Peer.Transition(PeerNothingData, Peer.DisconnectedData(Map.empty, None))) + peer.send(reconnectionTask, Peer.Transition(PeerNothingData, Peer.DisconnectedData(Map.empty, PeerStorage(None, written = true, TimestampMilli.min)))) monitor.expectNoMessage() }