Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Peer storage #2888

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions eclair-core/src/main/scala/fr/acinq/eclair/Features.scala
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,11 @@ object Features {
val mandatory = 38
}

case object ProvideStorage extends Feature with InitFeature with NodeFeature {
val rfcName = "option_provide_storage"
val mandatory = 42
}

case object ChannelType extends Feature with InitFeature with NodeFeature {
val rfcName = "option_channel_type"
val mandatory = 44
Expand Down Expand Up @@ -358,6 +363,7 @@ object Features {
DualFunding,
Quiescence,
OnionMessages,
ProvideStorage,
ChannelType,
ScidAlias,
PaymentMetadata,
Expand Down
4 changes: 3 additions & 1 deletion eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ case class NodeParams(nodeKeyManager: NodeKeyManager,
revokedHtlcInfoCleanerConfig: RevokedHtlcInfoCleaner.Config,
willFundRates_opt: Option[LiquidityAds.WillFundRates],
peerWakeUpConfig: PeerReadyNotifier.WakeUpConfig,
onTheFlyFundingConfig: OnTheFlyFunding.Config) {
onTheFlyFundingConfig: OnTheFlyFunding.Config,
peerStorageWriteDelayMax: FiniteDuration) {
val privateKey: Crypto.PrivateKey = nodeKeyManager.nodeKey.privateKey

val nodeId: PublicKey = nodeKeyManager.nodeId
Expand Down Expand Up @@ -678,6 +679,7 @@ object NodeParams extends Logging {
onTheFlyFundingConfig = OnTheFlyFunding.Config(
proposalTimeout = FiniteDuration(config.getDuration("on-the-fly-funding.proposal-timeout").getSeconds, TimeUnit.SECONDS),
),
peerStorageWriteDelayMax = 1 minute,
)
}
}
11 changes: 11 additions & 0 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import fr.acinq.eclair.router.Router
import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, NodeAddress, NodeAnnouncement}
import fr.acinq.eclair.{CltvExpiry, MilliSatoshi, Paginated, RealShortChannelId, ShortChannelId, TimestampMilli}
import grizzled.slf4j.Logging
import scodec.bits.ByteVector

import java.io.File
import java.util.UUID
Expand Down Expand Up @@ -292,6 +293,16 @@ case class DualPeersDb(primary: PeersDb, secondary: PeersDb) extends PeersDb {
runAsync(secondary.getRelayFees(nodeId))
primary.getRelayFees(nodeId)
}

override def updateStorage(nodeId: PublicKey, data: ByteVector): Unit = {
runAsync(secondary.updateStorage(nodeId, data))
primary.updateStorage(nodeId, data)
}

override def getStorage(nodeId: PublicKey): Option[ByteVector] = {
runAsync(secondary.getStorage(nodeId))
primary.getStorage(nodeId)
}
}

case class DualPaymentsDb(primary: PaymentsDb, secondary: PaymentsDb) extends PaymentsDb {
Expand Down
5 changes: 5 additions & 0 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/PeersDb.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package fr.acinq.eclair.db
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.payment.relay.Relayer.RelayFees
import fr.acinq.eclair.wire.protocol.NodeAddress
import scodec.bits.ByteVector

trait PeersDb {

Expand All @@ -34,4 +35,8 @@ trait PeersDb {

def getRelayFees(nodeId: PublicKey): Option[RelayFees]

def updateStorage(nodeId: PublicKey, data: ByteVector): Unit

def getStorage(nodeId: PublicKey): Option[ByteVector]

}
47 changes: 43 additions & 4 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ import fr.acinq.eclair.db.pg.PgUtils.PgLock
import fr.acinq.eclair.payment.relay.Relayer.RelayFees
import fr.acinq.eclair.wire.protocol._
import grizzled.slf4j.Logging
import scodec.bits.BitVector
import scodec.bits.{BitVector, ByteVector}

import java.sql.Statement
import javax.sql.DataSource

object PgPeersDb {
val DB_NAME = "peers"
val CURRENT_VERSION = 3
val CURRENT_VERSION = 4
}

class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logging {
Expand All @@ -54,20 +54,28 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logg
statement.executeUpdate("CREATE TABLE local.relay_fees (node_id TEXT NOT NULL PRIMARY KEY, fee_base_msat BIGINT NOT NULL, fee_proportional_millionths BIGINT NOT NULL)")
}

def migration34(statement: Statement): Unit = {
statement.executeUpdate("CREATE TABLE local.peer_storage (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)")
}

using(pg.createStatement()) { statement =>
getVersion(statement, DB_NAME) match {
case None =>
statement.executeUpdate("CREATE SCHEMA IF NOT EXISTS local")
statement.executeUpdate("CREATE TABLE local.peers (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)")
statement.executeUpdate("CREATE TABLE local.peers (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, storage BYTEA)")
statement.executeUpdate("CREATE TABLE local.relay_fees (node_id TEXT NOT NULL PRIMARY KEY, fee_base_msat BIGINT NOT NULL, fee_proportional_millionths BIGINT NOT NULL)")
case Some(v@(1 | 2)) =>
statement.executeUpdate("CREATE TABLE local.peer_storage (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)")
case Some(v@(1 | 2 | 3)) =>
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
if (v < 2) {
migration12(statement)
}
if (v < 3) {
migration23(statement)
}
if (v < 4) {
migration34(statement)
}
case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do
case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
}
Expand Down Expand Up @@ -98,6 +106,10 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logg
statement.setString(1, nodeId.value.toHex)
statement.executeUpdate()
}
using(pg.prepareStatement("DELETE FROM local.peer_storage WHERE node_id = ?")) { statement =>
statement.setString(1, nodeId.value.toHex)
statement.executeUpdate()
}
}
}

Expand Down Expand Up @@ -155,4 +167,31 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logg
}
}
}

override def updateStorage(nodeId: PublicKey, data: ByteVector): Unit = withMetrics("peers/update-storage", DbBackends.Postgres) {
withLock { pg =>
using(pg.prepareStatement(
"""
INSERT INTO local.peer_storage (node_id, data)
VALUES (?, ?)
ON CONFLICT (node_id)
DO UPDATE SET data = EXCLUDED.data
""")) { statement =>
statement.setString(1, nodeId.value.toHex)
statement.setBytes(2, data.toArray)
statement.executeUpdate()
}
}
}

override def getStorage(nodeId: PublicKey): Option[ByteVector] = withMetrics("peers/get-storage", DbBackends.Postgres) {
withLock { pg =>
using(pg.prepareStatement("SELECT data FROM local.peer_storage WHERE node_id = ?")) { statement =>
statement.setString(1, nodeId.value.toHex)
statement.executeQuery()
.headOption
.map(rs => ByteVector(rs.getBytes("data")))
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ import fr.acinq.eclair.db.sqlite.SqliteUtils.{getVersion, setVersion, using}
import fr.acinq.eclair.payment.relay.Relayer.RelayFees
import fr.acinq.eclair.wire.protocol._
import grizzled.slf4j.Logging
import scodec.bits.BitVector
import scodec.bits.{BitVector, ByteVector}

import java.sql.{Connection, Statement}

object SqlitePeersDb {
val DB_NAME = "peers"
val CURRENT_VERSION = 2
val CURRENT_VERSION = 3
}

class SqlitePeersDb(val sqlite: Connection) extends PeersDb with Logging {
Expand All @@ -46,13 +46,23 @@ class SqlitePeersDb(val sqlite: Connection) extends PeersDb with Logging {
statement.executeUpdate("CREATE TABLE relay_fees (node_id BLOB NOT NULL PRIMARY KEY, fee_base_msat INTEGER NOT NULL, fee_proportional_millionths INTEGER NOT NULL)")
}

def migration23(statement: Statement): Unit = {
statement.executeUpdate("CREATE TABLE peer_storage (node_id BLOB NOT NULL PRIMARY KEY, data NOT NULL)")
}

getVersion(statement, DB_NAME) match {
case None =>
statement.executeUpdate("CREATE TABLE peers (node_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)")
statement.executeUpdate("CREATE TABLE relay_fees (node_id BLOB NOT NULL PRIMARY KEY, fee_base_msat INTEGER NOT NULL, fee_proportional_millionths INTEGER NOT NULL)")
case Some(v@1) =>
statement.executeUpdate("CREATE TABLE peer_storage (node_id BLOB NOT NULL PRIMARY KEY, data NOT NULL)")
case Some(v@(1 | 2)) =>
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
migration12(statement)
if (v < 2) {
migration12(statement)
}
if (v < 3) {
migration23(statement)
}
case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do
case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
}
Expand Down Expand Up @@ -128,4 +138,27 @@ class SqlitePeersDb(val sqlite: Connection) extends PeersDb with Logging {
)
}
}

override def updateStorage(nodeId: PublicKey, data: ByteVector): Unit = withMetrics("peers/update-storage", DbBackends.Sqlite) {
using(sqlite.prepareStatement("UPDATE peer_storage SET data = ? WHERE node_id = ?")) { update =>
update.setBytes(1, data.toArray)
update.setBytes(2, nodeId.value.toArray)
if (update.executeUpdate() == 0) {
using(sqlite.prepareStatement("INSERT INTO peer_storage VALUES (?, ?)")) { statement =>
statement.setBytes(1, nodeId.value.toArray)
statement.setBytes(2, data.toArray)
statement.executeUpdate()
}
}
}
}

override def getStorage(nodeId: PublicKey): Option[ByteVector] = withMetrics("peers/get-storage", DbBackends.Sqlite) {
using(sqlite.prepareStatement("SELECT data FROM peer_storage WHERE node_id = ?")) { statement =>
statement.setBytes(1, nodeId.value.toArray)
statement.executeQuery()
.headOption
.map(rs => ByteVector(rs.getBytes("data")))
}
}
}
47 changes: 37 additions & 10 deletions eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes
import fr.acinq.eclair.router.Router
import fr.acinq.eclair.wire.protocol
import fr.acinq.eclair.wire.protocol.FailureMessageCodecs.createBadOnionFailure
import fr.acinq.eclair.wire.protocol.{AddFeeCredit, ChannelTlv, CurrentFeeCredit, Error, HasChannelId, HasTemporaryChannelId, LightningMessage, LiquidityAds, NodeAddress, OnTheFlyFundingFailureMessage, OnionMessage, OnionRoutingPacket, RoutingMessage, SpliceInit, TlvStream, UnknownMessage, Warning, WillAddHtlc, WillFailHtlc, WillFailMalformedHtlc}
import fr.acinq.eclair.wire.protocol.{AddFeeCredit, ChannelTlv, CurrentFeeCredit, Error, HasChannelId, HasTemporaryChannelId, LightningMessage, LiquidityAds, NodeAddress, OnTheFlyFundingFailureMessage, OnionMessage, OnionRoutingPacket, PeerStorageRetrieval, PeerStorageStore, RoutingMessage, SpliceInit, TlvStream, UnknownMessage, Warning, WillAddHtlc, WillFailHtlc, WillFailMalformedHtlc}
import scodec.bits.ByteVector

/**
* This actor represents a logical peer. There is one [[Peer]] per unique remote node id at all time.
Expand Down Expand Up @@ -84,7 +85,7 @@ class Peer(val nodeParams: NodeParams,
FinalChannelId(state.channelId) -> channel
}.toMap
context.system.eventStream.publish(PeerCreated(self, remoteNodeId))
goto(DISCONNECTED) using DisconnectedData(channels) // when we restart, we will attempt to reconnect right away, but then we'll wait
goto(DISCONNECTED) using DisconnectedData(channels, PeerStorage(nodeParams.db.peers.getStorage(remoteNodeId), written = true, TimestampMilli.min)) // when we restart, we will attempt to reconnect right away, but then we'll wait
}

when(DISCONNECTED) {
Expand All @@ -93,7 +94,7 @@ class Peer(val nodeParams: NodeParams,
stay()

case Event(connectionReady: PeerConnection.ConnectionReady, d: DisconnectedData) =>
gotoConnected(connectionReady, d.channels.map { case (k: ChannelId, v) => (k, v) })
gotoConnected(connectionReady, d.channels.map { case (k: ChannelId, v) => (k, v) }, d.peerStorage)

case Event(Terminated(actor), d: DisconnectedData) if d.channels.values.toSet.contains(actor) =>
// we have at most 2 ids: a TemporaryChannelId and a FinalChannelId
Expand Down Expand Up @@ -454,7 +455,7 @@ class Peer(val nodeParams: NodeParams,
stopPeer()
} else {
d.channels.values.toSet[ActorRef].foreach(_ ! INPUT_DISCONNECTED) // we deduplicate with toSet because there might be two entries per channel (tmp id and final id)
goto(DISCONNECTED) using DisconnectedData(d.channels.collect { case (k: FinalChannelId, v) => (k, v) })
goto(DISCONNECTED) using DisconnectedData(d.channels.collect { case (k: FinalChannelId, v) => (k, v) }, d.peerStorage)
}

case Event(Terminated(actor), d: ConnectedData) if d.channels.values.toSet.contains(actor) =>
Expand All @@ -473,7 +474,7 @@ class Peer(val nodeParams: NodeParams,
log.debug(s"got new connection, killing current one and switching")
d.peerConnection ! PeerConnection.Kill(KillReason.ConnectionReplaced)
d.channels.values.toSet[ActorRef].foreach(_ ! INPUT_DISCONNECTED) // we deduplicate with toSet because there might be two entries per channel (tmp id and final id)
gotoConnected(connectionReady, d.channels)
gotoConnected(connectionReady, d.channels, d.peerStorage)

case Event(msg: OnionMessage, _: ConnectedData) =>
OnionMessages.process(nodeParams.privateKey, msg) match {
Expand Down Expand Up @@ -506,6 +507,21 @@ class Peer(val nodeParams: NodeParams,
d.peerConnection forward unknownMsg
stay()

case Event(store: PeerStorageStore, d: ConnectedData) if nodeParams.features.hasFeature(Features.ProvideStorage) && d.channels.nonEmpty =>
val timeSinceLastWrite = TimestampMilli.now() - d.peerStorage.lastWrite
val peerStorage = if (timeSinceLastWrite >= nodeParams.peerStorageWriteDelayMax) {
nodeParams.db.peers.updateStorage(remoteNodeId, store.blob)
PeerStorage(Some(store.blob), written = true, TimestampMilli.now())
} else {
startSingleTimer("peer-storage-write", WritePeerStorage, nodeParams.peerStorageWriteDelayMax - timeSinceLastWrite)
PeerStorage(Some(store.blob), written = false, d.peerStorage.lastWrite)
}
stay() using d.copy(peerStorage = peerStorage)

case Event(WritePeerStorage, d: ConnectedData) =>
d.peerStorage.data.foreach(nodeParams.db.peers.updateStorage(remoteNodeId, _))
stay() using d.copy(peerStorage = PeerStorage(d.peerStorage.data, written = true, TimestampMilli.now()))

case Event(unhandledMsg: LightningMessage, _) =>
log.warning("ignoring message {}", unhandledMsg)
stay()
Expand Down Expand Up @@ -716,7 +732,7 @@ class Peer(val nodeParams: NodeParams,
context.system.eventStream.publish(PeerDisconnected(self, remoteNodeId))
}

private def gotoConnected(connectionReady: PeerConnection.ConnectionReady, channels: Map[ChannelId, ActorRef]): State = {
private def gotoConnected(connectionReady: PeerConnection.ConnectionReady, channels: Map[ChannelId, ActorRef], peerStorage: PeerStorage): State = {
require(remoteNodeId == connectionReady.remoteNodeId, s"invalid nodeId: $remoteNodeId != ${connectionReady.remoteNodeId}")
log.debug("got authenticated connection to address {}", connectionReady.address)

Expand All @@ -726,6 +742,9 @@ class Peer(val nodeParams: NodeParams,
nodeParams.db.peers.addOrUpdatePeer(remoteNodeId, connectionReady.address)
}

// If we have some data stored from our peer, we send it to them before doing anything else.
peerStorage.data.foreach(connectionReady.peerConnection ! PeerStorageRetrieval(_))

// let's bring existing/requested channels online
channels.values.toSet[ActorRef].foreach(_ ! INPUT_RECONNECTED(connectionReady.peerConnection, connectionReady.localInit, connectionReady.remoteInit)) // we deduplicate with toSet because there might be two entries per channel (tmp id and final id)

Expand All @@ -742,7 +761,7 @@ class Peer(val nodeParams: NodeParams,
connectionReady.peerConnection ! CurrentFeeCredit(nodeParams.chainHash, feeCredit.getOrElse(0 msat))
}

goto(CONNECTED) using ConnectedData(connectionReady.address, connectionReady.peerConnection, connectionReady.localInit, connectionReady.remoteInit, channels)
goto(CONNECTED) using ConnectedData(connectionReady.address, connectionReady.peerConnection, connectionReady.localInit, connectionReady.remoteInit, channels, peerStorage)
}

/**
Expand Down Expand Up @@ -877,12 +896,18 @@ object Peer {
case class TemporaryChannelId(id: ByteVector32) extends ChannelId
case class FinalChannelId(id: ByteVector32) extends ChannelId

case class PeerStorage(data: Option[ByteVector], written: Boolean, lastWrite: TimestampMilli)

sealed trait Data {
def channels: Map[_ <: ChannelId, ActorRef] // will be overridden by Map[FinalChannelId, ActorRef] or Map[ChannelId, ActorRef]
def peerStorage: PeerStorage
}
case object Nothing extends Data { override def channels = Map.empty }
case class DisconnectedData(channels: Map[FinalChannelId, ActorRef]) extends Data
case class ConnectedData(address: NodeAddress, peerConnection: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init, channels: Map[ChannelId, ActorRef]) extends Data {
case object Nothing extends Data {
override def channels = Map.empty
override def peerStorage: PeerStorage = PeerStorage(None, written = true, TimestampMilli.min)
}
case class DisconnectedData(channels: Map[FinalChannelId, ActorRef], peerStorage: PeerStorage) extends Data
case class ConnectedData(address: NodeAddress, peerConnection: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init, channels: Map[ChannelId, ActorRef], peerStorage: PeerStorage) extends Data {
val connectionInfo: ConnectionInfo = ConnectionInfo(address, peerConnection, localInit, remoteInit)
def localFeatures: Features[InitFeature] = localInit.features
def remoteFeatures: Features[InitFeature] = remoteInit.features
Expand Down Expand Up @@ -993,5 +1018,7 @@ object Peer {
case class RelayOnionMessage(messageId: ByteVector32, msg: OnionMessage, replyTo_opt: Option[typed.ActorRef[Status]])

case class RelayUnknownMessage(unknownMessage: UnknownMessage)

case object WritePeerStorage
// @formatter:on
}
Loading
Loading