Skip to content

Commit

Permalink
fix: improve performance for background jobs in multitenancy mode (#749)
Browse files Browse the repository at this point in the history
Signed-off-by: mineme0110 <[email protected]>
  • Loading branch information
mineme0110 authored Oct 10, 2023
1 parent 110eb2d commit 17def3f
Show file tree
Hide file tree
Showing 34 changed files with 853 additions and 369 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import io.iohk.atala.connect.core.model.ConnectionRecord.ProtocolState
import io.iohk.atala.mercury.protocol.connection.*
import io.iohk.atala.shared.models.WalletAccessContext
import zio.RIO

import zio.Task
import java.util.UUID

trait ConnectionRepository {
Expand All @@ -19,6 +19,12 @@ trait ConnectionRepository {
states: ConnectionRecord.ProtocolState*
): RIO[WalletAccessContext, Seq[ConnectionRecord]]

def getConnectionRecordsByStatesForAllWallets(
ignoreWithZeroRetries: Boolean,
limit: Int,
states: ConnectionRecord.ProtocolState*
): Task[Seq[ConnectionRecord]]

def getConnectionRecord(recordId: UUID): RIO[WalletAccessContext, Option[ConnectionRecord]]

def deleteConnectionRecord(recordId: UUID): RIO[WalletAccessContext, Int]
Expand Down Expand Up @@ -50,4 +56,5 @@ trait ConnectionRepository {
recordId: UUID,
failReason: Option[String],
): RIO[WalletAccessContext, Int]

}
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,30 @@ class ConnectionRepositoryInMemory(walletRefs: Ref[Map[WalletId, Ref[Map[UUID, C
.getOrElse(ZIO.succeed(0))
} yield count

def updateAfterFailForAllWallets(
recordId: UUID,
failReason: Option[String],
): Task[Int] = walletRefs.get.flatMap { wallets =>
ZIO.foldLeft(wallets.values)(0) { (acc, walletRef) =>
for {
records <- walletRef.get
count <- records.get(recordId) match {
case Some(record) =>
walletRef.update { r =>
r.updated(
recordId,
record.copy(
metaRetries = record.metaRetries - 1,
metaLastFailure = failReason
)
)
} *> ZIO.succeed(1) // Record updated, count as 1 update
case None => ZIO.succeed(0) // No record updated
}
} yield acc + count
}
}

override def getConnectionRecordByThreadId(thid: String): RIO[WalletAccessContext, Option[ConnectionRecord]] = {
for {
storeRef <- walletStoreRef
Expand Down Expand Up @@ -183,6 +207,27 @@ class ConnectionRepositoryInMemory(walletRefs: Ref[Map[WalletId, Ref[Map[UUID, C
.toSeq
}

override def getConnectionRecordsByStatesForAllWallets(
ignoreWithZeroRetries: Boolean,
limit: Int,
states: ConnectionRecord.ProtocolState*
): Task[Seq[ConnectionRecord]] = {

for {
refs <- walletRefs.get
stores <- ZIO.foreach(refs.values.toList)(_.get)
} yield {
stores
.flatMap(_.values)
.filter { rec =>
(!ignoreWithZeroRetries || rec.metaRetries > 0) &&
states.contains(rec.protocolState)
}
.take(limit)
.toSeq
}
}

override def createConnectionRecord(record: ConnectionRecord): RIO[WalletAccessContext, Int] = {
for {
_ <- for {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ trait ConnectionService {
states: ConnectionRecord.ProtocolState*
): ZIO[WalletAccessContext, ConnectionServiceError, Seq[ConnectionRecord]]

def getConnectionRecordsByStatesForAllWallets(
ignoreWithZeroRetries: Boolean,
limit: Int,
states: ConnectionRecord.ProtocolState*
): IO[ConnectionServiceError, Seq[ConnectionRecord]]

def getConnectionRecord(recordId: UUID): ZIO[WalletAccessContext, ConnectionServiceError, Option[ConnectionRecord]]

def getConnectionRecordByThreadId(
Expand All @@ -65,5 +71,4 @@ trait ConnectionService {
recordId: UUID,
failReason: Option[String]
): ZIO[WalletAccessContext, ConnectionServiceError, Unit]

}
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,18 @@ private class ConnectionServiceImpl(
} yield records
}

override def getConnectionRecordsByStatesForAllWallets(
ignoreWithZeroRetries: Boolean,
limit: Int,
states: ProtocolState*
): IO[ConnectionServiceError, Seq[ConnectionRecord]] = {
for {
records <- connectionRepository
.getConnectionRecordsByStatesForAllWallets(ignoreWithZeroRetries, limit, states: _*)
.mapError(RepositoryError.apply)
} yield records
}

override def getConnectionRecord(
recordId: UUID
): ZIO[WalletAccessContext, ConnectionServiceError, Option[ConnectionRecord]] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import io.iohk.atala.mercury.model.DidId
import io.iohk.atala.mercury.protocol.connection.{ConnectionRequest, ConnectionResponse}
import io.iohk.atala.shared.models.WalletAccessContext
import zio.{URLayer, ZIO, ZLayer}

import zio.IO
import java.time.Duration
import java.util.UUID

Expand Down Expand Up @@ -109,6 +109,13 @@ class ConnectionServiceNotifier(
states: ConnectionRecord.ProtocolState*
): ZIO[WalletAccessContext, ConnectionServiceError, Seq[ConnectionRecord]] =
svc.getConnectionRecordsByStates(ignoreWithZeroRetries, limit, states: _*)

override def getConnectionRecordsByStatesForAllWallets(
ignoreWithZeroRetries: Boolean,
limit: Int,
states: ConnectionRecord.ProtocolState*
): IO[ConnectionServiceError, Seq[ConnectionRecord]] =
svc.getConnectionRecordsByStatesForAllWallets(ignoreWithZeroRetries, limit, states: _*)
}

object ConnectionServiceNotifier {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ object MockConnectionService extends Mock[ConnectionService] {
states: ConnectionRecord.ProtocolState*
): IO[ConnectionServiceError, Seq[ConnectionRecord]] = ???

override def getConnectionRecordsByStatesForAllWallets(
ignoreWithZeroRetries: Boolean,
limit: Int,
states: ConnectionRecord.ProtocolState*
): IO[ConnectionServiceError, Seq[ConnectionRecord]] = ???

override def getConnectionRecord(recordId: UUID): IO[ConnectionServiceError, Option[ConnectionRecord]] = ???

override def getConnectionRecordByThreadId(thid: String): IO[ConnectionServiceError, Option[ConnectionRecord]] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,46 @@ object ConnectionRepositorySpecSuite {
wallet1Record <- repo.getConnectionRecord(record.id).provide(wac1)
wallet2Record <- repo.getConnectionRecord(record.id).provide(wac2)
} yield assertTrue(wallet1Record.isDefined) && assertTrue(wallet2Record.isEmpty)
}
},
test("getConnectionRecordsByStatesForAllWallets returns correct records for all wallets") {
val walletId1 = WalletId.random
val walletId2 = WalletId.random
for {
repo <- ZIO.service[ConnectionRepository]

wac1 = ZLayer.succeed(WalletAccessContext(walletId1))
wac2 = ZLayer.succeed(WalletAccessContext(walletId2))
aRecordWallet1 = connectionRecord
bRecordWallet2 = connectionRecord
_ <- repo.createConnectionRecord(aRecordWallet1).provide(wac1)
_ <- repo.createConnectionRecord(bRecordWallet2).provide(wac2)
_ <- repo
.updateConnectionProtocolState(
aRecordWallet1.id,
ProtocolState.InvitationGenerated,
ProtocolState.ConnectionRequestReceived,
1
)
.provide(wac1)
_ <- repo
.updateConnectionProtocolState(
bRecordWallet2.id,
ProtocolState.InvitationGenerated,
ProtocolState.ConnectionResponsePending,
1
)
.provide(wac2)
allWalletRecords <- repo.getConnectionRecordsByStatesForAllWallets(
ignoreWithZeroRetries = true,
limit = 10,
ProtocolState.ConnectionRequestReceived,
ProtocolState.ConnectionResponsePending
)
} yield {
assertTrue(allWalletRecords.size == 2) &&
assertTrue(allWalletRecords.exists(_.id == aRecordWallet1.id)) &&
assertTrue(allWalletRecords.exists(_.id == bRecordWallet2.id))
}
},
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ import io.iohk.atala.shared.db.Implicits.*
import io.iohk.atala.shared.models.WalletAccessContext
import org.postgresql.util.PSQLException
import zio.*

import zio.interop.catz.*
import java.time.Instant
import java.util.UUID
import doobie.free.connection

class JdbcConnectionRepository(xa: Transactor[ContextAwareTask]) extends ConnectionRepository {
class JdbcConnectionRepository(xa: Transactor[ContextAwareTask], xb: Transactor[Task]) extends ConnectionRepository {

// given logHandler: LogHandler = LogHandler.jdkLogHandler

Expand Down Expand Up @@ -114,9 +115,25 @@ class JdbcConnectionRepository(xa: Transactor[ContextAwareTask]) extends Connect
limit: Int,
states: ConnectionRecord.ProtocolState*
): RIO[WalletAccessContext, Seq[ConnectionRecord]] = {
getRecordsByStates(ignoreWithZeroRetries, limit, states: _*).transactWallet(xa)
}

override def getConnectionRecordsByStatesForAllWallets(
ignoreWithZeroRetries: Boolean,
limit: Int,
states: ConnectionRecord.ProtocolState*
): Task[Seq[ConnectionRecord]] = {
getRecordsByStates(ignoreWithZeroRetries, limit, states: _*).transact(xb)
}

private def getRecordsByStates(
ignoreWithZeroRetries: Boolean,
limit: Int,
states: ConnectionRecord.ProtocolState*
): ConnectionIO[Seq[ConnectionRecord]] = {
states match
case Nil =>
ZIO.succeed(Nil)
connection.pure(Nil)
case head +: tail =>
val nel = NonEmptyList.of(head, tail: _*)
val inClauseFragment = Fragments.in(fr"protocol_state", nel)
Expand Down Expand Up @@ -148,7 +165,6 @@ class JdbcConnectionRepository(xa: Transactor[ContextAwareTask]) extends Connect
.to[Seq]

cxnIO
.transactWallet(xa)
}

override def getConnectionRecord(recordId: UUID): RIO[WalletAccessContext, Option[ConnectionRecord]] = {
Expand Down Expand Up @@ -298,10 +314,8 @@ class JdbcConnectionRepository(xa: Transactor[ContextAwareTask]) extends Connect
""".stripMargin.update
cxnIO.run.transactWallet(xa)
}

}

object JdbcConnectionRepository {
val layer: URLayer[Transactor[ContextAwareTask], ConnectionRepository] =
ZLayer.fromFunction(new JdbcConnectionRepository(_))
val layer: URLayer[Transactor[ContextAwareTask] & Transactor[Task], ConnectionRepository] =
ZLayer.fromFunction(new JdbcConnectionRepository(_, _))
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ object JdbcConnectionRepositorySpec extends ZIOSpecDefault, PostgresTestContaine
Migrations.layer,
dbConfig,
pgContainerLayer,
contextAwareTransactorLayer
contextAwareTransactorLayer,
systemTransactorLayer
)

override def spec: Spec[TestEnvironment with Scope, Any] =
Expand Down
4 changes: 2 additions & 2 deletions infrastructure/shared/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ services:
AGENT_DB_NAME: agent
AGENT_DB_USER: postgres
AGENT_DB_PASSWORD: postgres
DIDCOMM_SERVICE_URL: ${DIDCOMM_SERVICE_URL:-http://$${DOCKERHOST}:$${PORT}/didcomm}
REST_SERVICE_URL: ${REST_SERVICE_URL:-http://$${DOCKERHOST}:$${PORT}/prism-agent}
DIDCOMM_SERVICE_URL: ${DIDCOMM_SERVICE_URL:-http://${DOCKERHOST}:${PORT}/didcomm}
REST_SERVICE_URL: ${REST_SERVICE_URL:-http://${DOCKERHOST}:${PORT}/prism-agent}
PRISM_NODE_HOST: prism-node
PRISM_NODE_PORT: 50053
VAULT_ADDR: ${VAULT_ADDR:-http://vault-server:8200}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ trait CredentialRepository {
states: IssueCredentialRecord.ProtocolState*
): RIO[WalletAccessContext, Seq[IssueCredentialRecord]]

def getIssueCredentialRecordsByStatesForAllWallets(
ignoreWithZeroRetries: Boolean,
limit: Int,
states: IssueCredentialRecord.ProtocolState*
): Task[Seq[IssueCredentialRecord]]

def getIssueCredentialRecordByThreadId(
thid: DidCommID,
ignoreWithZeroRetries: Boolean,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,26 @@ class CredentialRepositoryInMemory(
.toSeq
}

override def getIssueCredentialRecordsByStatesForAllWallets(
ignoreWithZeroRetries: Boolean,
limit: Int,
states: ProtocolState*
): Task[Seq[IssueCredentialRecord]] = {
for {
refs <- walletRefs.get
stores <- ZIO.foreach(refs.values.toList)(_.get)
} yield {
stores
.flatMap(_.values)
.filter { rec =>
(!ignoreWithZeroRetries || rec.metaRetries > 0) &&
states.contains(rec.protocolState)
}
.take(limit)
.toSeq
}
}

override def getIssueCredentialRecordByThreadId(
thid: DidCommID,
ignoreWithZeroRetries: Boolean,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ trait PresentationRepository {
limit: Int,
states: PresentationRecord.ProtocolState*
): RIO[WalletAccessContext, Seq[PresentationRecord]]

def getPresentationRecordsByStatesForAllWallets(
ignoreWithZeroRetries: Boolean,
limit: Int,
states: PresentationRecord.ProtocolState*
): Task[Seq[PresentationRecord]]

def getPresentationRecordByThreadId(thid: DidCommID): RIO[WalletAccessContext, Option[PresentationRecord]]

def updatePresentationRecordProtocolState(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,26 @@ class PresentationRepositoryInMemory(
.toSeq
}

override def getPresentationRecordsByStatesForAllWallets(
ignoreWithZeroRetries: Boolean,
limit: Int,
states: ProtocolState*
): Task[Seq[PresentationRecord]] = {
for {
refs <- walletRefs.get
stores <- ZIO.foreach(refs.values.toList)(_.get)
} yield {
stores
.flatMap(_.values)
.filter { rec =>
(!ignoreWithZeroRetries || rec.metaRetries > 0) &&
states.contains(rec.protocolState)
}
.take(limit)
.toSeq
}
}

override def getPresentationRecordByThreadId(
thid: DidCommID,
): RIO[WalletAccessContext, Option[PresentationRecord]] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ trait CredentialService {
states: IssueCredentialRecord.ProtocolState*
): ZIO[WalletAccessContext, CredentialServiceError, Seq[IssueCredentialRecord]]

def getIssueCredentialRecordsByStatesForAllWallets(
ignoreWithZeroRetries: Boolean,
limit: Int,
states: IssueCredentialRecord.ProtocolState*
): IO[CredentialServiceError, Seq[IssueCredentialRecord]]

def getIssueCredentialRecord(
recordId: DidCommID
): ZIO[WalletAccessContext, CredentialServiceError, Option[IssueCredentialRecord]]
Expand Down
Loading

0 comments on commit 17def3f

Please sign in to comment.