Skip to content

Commit

Permalink
test: support multi-tenant tests on wallet-api (#623)
Browse files Browse the repository at this point in the history
* fix: dynamic seed instead of global seed for walletapi

* test: add WalletManagementService tests

* [wip] dummy test for WalletSecretStorage

* test: implement WalletSecretStorage spec

* test: make test db aware of app user and migration user

* [wip]: multitenancy spec for managed DID

* test: add multitenant managed DID index test
  • Loading branch information
patlo-iog authored Aug 11, 2023
1 parent ceb7827 commit 85c0bc3
Show file tree
Hide file tree
Showing 11 changed files with 380 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import io.iohk.atala.agent.walletapi.model.*
import io.iohk.atala.agent.walletapi.model.error.{*, given}
import io.iohk.atala.agent.walletapi.service.ManagedDIDService.DEFAULT_MASTER_KEY_ID
import io.iohk.atala.agent.walletapi.service.handler.{DIDCreateHandler, DIDUpdateHandler, PublicationHandler}
import io.iohk.atala.agent.walletapi.storage.WalletSecretStorage
import io.iohk.atala.agent.walletapi.storage.{DIDNonSecretStorage, DIDSecretStorage}
import io.iohk.atala.agent.walletapi.util.*
import io.iohk.atala.castor.core.model.did.*
Expand All @@ -28,19 +29,19 @@ class ManagedDIDServiceImpl private[walletapi] (
didOpValidator: DIDOperationValidator,
private[walletapi] val secretStorage: DIDSecretStorage,
override private[walletapi] val nonSecretStorage: DIDNonSecretStorage,
walletSecretStorage: WalletSecretStorage,
apollo: Apollo,
seed: WalletSeed, // TODO: support dynamic seed lookup
createDIDSem: Semaphore
) extends ManagedDIDService {

private val AGREEMENT_KEY_ID = "agreement"
private val AUTHENTICATION_KEY_ID = "authentication"

private val keyResolver = KeyResolver(apollo, nonSecretStorage)(seed)

// TODO: implement seed caching & TTL in dispatching layer
private val keyResolver = KeyResolver(apollo, nonSecretStorage, walletSecretStorage)
private val publicationHandler = PublicationHandler(didService, keyResolver)(DEFAULT_MASTER_KEY_ID)
private val didCreateHandler = DIDCreateHandler(apollo, nonSecretStorage)(seed, DEFAULT_MASTER_KEY_ID)
private val didUpdateHandler = DIDUpdateHandler(apollo, nonSecretStorage, publicationHandler)(seed)
private val didCreateHandler = DIDCreateHandler(apollo, nonSecretStorage, walletSecretStorage)(DEFAULT_MASTER_KEY_ID)
private val didUpdateHandler = DIDUpdateHandler(apollo, nonSecretStorage, walletSecretStorage, publicationHandler)

def syncManagedDIDState: ZIO[WalletAccessContext, GetManagedDIDError, Unit] = nonSecretStorage
.listManagedDID(offset = None, limit = None)
Expand Down Expand Up @@ -370,7 +371,7 @@ class ManagedDIDServiceImpl private[walletapi] (
object ManagedDIDServiceImpl {

val layer: RLayer[
DIDOperationValidator & DIDService & DIDSecretStorage & DIDNonSecretStorage & Apollo & SeedResolver,
DIDOperationValidator & DIDService & DIDSecretStorage & DIDNonSecretStorage & WalletSecretStorage & Apollo,
ManagedDIDService
] = {
ZLayer.fromZIO {
Expand All @@ -379,16 +380,16 @@ object ManagedDIDServiceImpl {
didOpValidator <- ZIO.service[DIDOperationValidator]
secretStorage <- ZIO.service[DIDSecretStorage]
nonSecretStorage <- ZIO.service[DIDNonSecretStorage]
walletSecretStorage <- ZIO.service[WalletSecretStorage]
apollo <- ZIO.service[Apollo]
seed <- ZIO.serviceWithZIO[SeedResolver](_.resolve)
createDIDSem <- Semaphore.make(1)
} yield ManagedDIDServiceImpl(
didService,
didOpValidator,
secretStorage,
nonSecretStorage,
walletSecretStorage,
apollo,
seed,
createDIDSem
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@ package io.iohk.atala.agent.walletapi.service

import io.iohk.atala.agent.walletapi.crypto.Apollo
import io.iohk.atala.agent.walletapi.model.ManagedDIDDetail
import io.iohk.atala.agent.walletapi.model.WalletSeed
import io.iohk.atala.agent.walletapi.model.error.CommonWalletStorageError
import io.iohk.atala.agent.walletapi.storage.WalletSecretStorage
import io.iohk.atala.agent.walletapi.storage.{DIDNonSecretStorage, DIDSecretStorage}
import io.iohk.atala.agent.walletapi.util.SeedResolver
import io.iohk.atala.castor.core.model.did.CanonicalPrismDID
import io.iohk.atala.castor.core.model.error
import io.iohk.atala.castor.core.model.error.DIDOperationError
Expand All @@ -20,17 +19,17 @@ class ManagedDIDServiceWithEventNotificationImpl(
didOpValidator: DIDOperationValidator,
override private[walletapi] val secretStorage: DIDSecretStorage,
override private[walletapi] val nonSecretStorage: DIDNonSecretStorage,
walletSecretStorage: WalletSecretStorage,
apollo: Apollo,
seed: WalletSeed,
createDIDSem: Semaphore,
eventNotificationService: EventNotificationService
) extends ManagedDIDServiceImpl(
didService,
didOpValidator,
secretStorage,
nonSecretStorage,
walletSecretStorage,
apollo,
seed,
createDIDSem
) {

Expand Down Expand Up @@ -59,7 +58,7 @@ class ManagedDIDServiceWithEventNotificationImpl(

object ManagedDIDServiceWithEventNotificationImpl {
val layer: RLayer[
DIDOperationValidator & DIDService & DIDSecretStorage & DIDNonSecretStorage & Apollo & SeedResolver &
DIDOperationValidator & DIDService & DIDSecretStorage & DIDNonSecretStorage & WalletSecretStorage & Apollo &
EventNotificationService,
ManagedDIDService
] = ZLayer.fromZIO {
Expand All @@ -68,17 +67,17 @@ object ManagedDIDServiceWithEventNotificationImpl {
didOpValidator <- ZIO.service[DIDOperationValidator]
secretStorage <- ZIO.service[DIDSecretStorage]
nonSecretStorage <- ZIO.service[DIDNonSecretStorage]
walletSecretStorage <- ZIO.service[WalletSecretStorage]
apollo <- ZIO.service[Apollo]
seed <- ZIO.serviceWithZIO[SeedResolver](_.resolve)
createDIDSem <- Semaphore.make(1)
eventNotificationService <- ZIO.service[EventNotificationService]
} yield ManagedDIDServiceWithEventNotificationImpl(
didService,
didOpValidator,
secretStorage,
nonSecretStorage,
walletSecretStorage,
apollo,
seed,
createDIDSem,
eventNotificationService
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ class WalletManagementServiceImpl(

override def createWallet(seed: Option[WalletSeed]): Task[WalletId] =
for {
seed <- seed.fold(apollo.ecKeyFactory.randomBip32Seed().map(_._1).map(WalletSeed.fromByteArray))(ZIO.succeed)
seed <- seed.fold(
apollo.ecKeyFactory
.randomBip32Seed()
.map(_._1)
.map(WalletSeed.fromByteArray)
)(ZIO.succeed)
walletId <- nonSecretStorage.createWallet
_ <- secretStorage
.setWalletSeed(seed)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,28 @@ import io.iohk.atala.agent.walletapi.model.PublicationState
import io.iohk.atala.agent.walletapi.model.WalletSeed
import io.iohk.atala.agent.walletapi.model.error.CreateManagedDIDError
import io.iohk.atala.agent.walletapi.storage.DIDNonSecretStorage
import io.iohk.atala.agent.walletapi.storage.WalletSecretStorage
import io.iohk.atala.agent.walletapi.util.OperationFactory
import io.iohk.atala.castor.core.model.did.PrismDIDOperation
import io.iohk.atala.shared.models.WalletAccessContext
import zio.*

private[walletapi] class DIDCreateHandler(
apollo: Apollo,
nonSecretStorage: DIDNonSecretStorage
nonSecretStorage: DIDNonSecretStorage,
walletSecretStorage: WalletSecretStorage,
)(
seed: WalletSeed,
masterKeyId: String
) {
def materialize(
didTemplate: ManagedDIDTemplate
): ZIO[WalletAccessContext, CreateManagedDIDError, DIDCreateMaterial] = {
val operationFactory = OperationFactory(apollo)
for {
walletId <- ZIO.serviceWith[WalletAccessContext](_.walletId)
seed <- walletSecretStorage.getWalletSeed
.someOrElseZIO(ZIO.dieMessage(s"Wallet seed for wallet $walletId does not exist"))
.mapError(CreateManagedDIDError.WalletStorageError.apply)
didIndex <- nonSecretStorage
.getMaxDIDIndex()
.mapBoth(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import io.iohk.atala.agent.walletapi.model.WalletSeed
import io.iohk.atala.agent.walletapi.model.error.UpdateManagedDIDError
import io.iohk.atala.agent.walletapi.model.error.{*, given}
import io.iohk.atala.agent.walletapi.storage.DIDNonSecretStorage
import io.iohk.atala.agent.walletapi.storage.WalletSecretStorage
import io.iohk.atala.agent.walletapi.util.OperationFactory
import io.iohk.atala.castor.core.model.did.PrismDIDOperation
import io.iohk.atala.castor.core.model.did.PrismDIDOperation.Update
Expand All @@ -22,9 +23,8 @@ import zio.*
private[walletapi] class DIDUpdateHandler(
apollo: Apollo,
nonSecretStorage: DIDNonSecretStorage,
walletSecretStorage: WalletSecretStorage,
publicationHandler: PublicationHandler
)(
seed: WalletSeed
) {
def materialize(
state: ManagedDIDState,
Expand All @@ -36,6 +36,10 @@ private[walletapi] class DIDUpdateHandler(
state.keyMode match {
case KeyManagementMode.HD =>
for {
walletId <- ZIO.serviceWith[WalletAccessContext](_.walletId)
seed <- walletSecretStorage.getWalletSeed
.someOrElseZIO(ZIO.dieMessage(s"Wallet seed for wallet $walletId does not exist"))
.mapError(UpdateManagedDIDError.WalletStorageError.apply)
keyCounter <- nonSecretStorage
.getHdKeyCounter(did)
.mapError(UpdateManagedDIDError.WalletStorageError.apply)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,16 @@ class JdbcWalletSecretStorage(xa: Transactor[ContextAwareTask]) extends WalletSe
}

override def getWalletSeed: RIO[WalletAccessContext, Option[WalletSeed]] = {
val cxnIO = (walletId: WalletId) =>
val cxnIO =
sql"""
| SELECT seed
| FROM public.wallet_seed
| WHERE wallet_id = $walletId
""".stripMargin
.query[Array[Byte]]
.option

ZIO
.serviceWithZIO[WalletAccessContext](ctx => cxnIO(ctx.walletId).transactWallet(xa))
cxnIO
.transactWallet(xa)
.map(_.map(WalletSeed.fromByteArray))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@ import io.iohk.atala.agent.walletapi.model.KeyManagementMode
import io.iohk.atala.agent.walletapi.model.ManagedDIDState
import io.iohk.atala.agent.walletapi.model.WalletSeed
import io.iohk.atala.agent.walletapi.storage.DIDNonSecretStorage
import io.iohk.atala.agent.walletapi.storage.WalletSecretStorage
import io.iohk.atala.castor.core.model.did.EllipticCurve
import io.iohk.atala.castor.core.model.did.PrismDID
import io.iohk.atala.shared.models.WalletAccessContext
import zio.*

class KeyResolver(apollo: Apollo, nonSecretStorage: DIDNonSecretStorage)(
seed: WalletSeed
) {
class KeyResolver(apollo: Apollo, nonSecretStorage: DIDNonSecretStorage, walletSecretStorage: WalletSecretStorage) {
def getKey(state: ManagedDIDState, keyId: String): RIO[WalletAccessContext, Option[ECKeyPair]] = {
val did = state.createOperation.did
getKey(did, state.keyMode, keyId)
Expand All @@ -26,14 +25,19 @@ class KeyResolver(apollo: Apollo, nonSecretStorage: DIDNonSecretStorage)(
}

private def resolveHdKey(did: PrismDID, keyId: String): RIO[WalletAccessContext, Option[ECKeyPair]] = {
nonSecretStorage
.getHdKeyPath(did, keyId)
.flatMap {
case None => ZIO.none
case Some(path) =>
apollo.ecKeyFactory
.deriveKeyPair(EllipticCurve.SECP256K1, seed.toByteArray)(path.derivationPath: _*)
.asSome
for {
maybeSeed <- walletSecretStorage.getWalletSeed
maybeKeyPair <- maybeSeed.fold(ZIO.none) { seed =>
nonSecretStorage
.getHdKeyPath(did, keyId)
.flatMap {
case None => ZIO.none
case Some(path) =>
apollo.ecKeyFactory
.deriveKeyPair(EllipticCurve.SECP256K1, seed.toByteArray)(path.derivationPath: _*)
.asSome
}
}
} yield maybeKeyPair
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import io.iohk.atala.agent.walletapi.storage.DIDSecretStorage
import io.iohk.atala.agent.walletapi.storage.StorageSpecHelper
import io.iohk.atala.agent.walletapi.storage.WalletNonSecretStorage
import io.iohk.atala.agent.walletapi.storage.WalletSecretStorage
import io.iohk.atala.agent.walletapi.util.SeedResolver
import io.iohk.atala.agent.walletapi.vault.VaultDIDSecretStorage
import io.iohk.atala.agent.walletapi.vault.VaultWalletSecretStorage
import io.iohk.atala.castor.core.model.did.InternalKeyPurpose
Expand Down Expand Up @@ -44,6 +43,7 @@ import scala.collection.immutable.ArraySeq
import zio.*
import zio.test.*
import zio.test.Assertion.*
import io.iohk.atala.agent.walletapi.model.error.DIDSecretStorageError

object ManagedDIDServiceSpec
extends ZIOSpecDefault,
Expand Down Expand Up @@ -110,7 +110,6 @@ object ManagedDIDServiceSpec
DIDOperationValidator.layer(),
JdbcDIDNonSecretStorage.layer,
JdbcWalletNonSecretStorage.layer,
SeedResolver.layer(isDevMode = true),
transactorLayer,
testDIDServiceLayer,
apolloLayer
Expand Down Expand Up @@ -156,11 +155,12 @@ object ManagedDIDServiceSpec
override def spec = {
def testSuite(name: String) =
suite(name)(
publishStoredDIDSpec,
createAndStoreDIDSpec,
updateManagedDIDSpec,
deactivateManagedDIDSpec
).globalWallet
publishStoredDIDSpec.globalWallet,
createAndStoreDIDSpec.globalWallet,
updateManagedDIDSpec.globalWallet,
deactivateManagedDIDSpec.globalWallet,
multitenantSpec
)
@@ TestAspect.before(DBTestUtils.runMigrationAgentDB)
@@ TestAspect.sequential

Expand Down Expand Up @@ -471,4 +471,81 @@ object ManagedDIDServiceSpec
}
)

private val multitenantSpec = suite("multi-tenant managed DID")(
test("do not see Prism DID outside of the wallet") {
val template = generateDIDTemplate()
for {
walletSvc <- ZIO.service[WalletManagementService]
walletId1 <- walletSvc.createWallet()
walletId2 <- walletSvc.createWallet()
ctx1 = ZLayer.succeed(WalletAccessContext(walletId1))
ctx2 = ZLayer.succeed(WalletAccessContext(walletId2))
svc <- ZIO.service[ManagedDIDService]
dids1 <- ZIO.foreach(1 to 3)(_ => svc.createAndStoreDID(template).map(_.asCanonical)).provide(ctx1)
dids2 <- ZIO.foreach(1 to 3)(_ => svc.createAndStoreDID(template).map(_.asCanonical)).provide(ctx2)
ownWalletDids1 <- svc.listManagedDIDPage(0, 1000).map(_._1.map(_.did)).provide(ctx1)
ownWalletDids2 <- svc.listManagedDIDPage(0, 1000).map(_._1.map(_.did)).provide(ctx2)
crossWalletDids1 <- ZIO.foreach(dids1)(did => svc.getManagedDIDState(did)).provide(ctx2)
crossWalletDids2 <- ZIO.foreach(dids2)(did => svc.getManagedDIDState(did)).provide(ctx1)
} yield assert(dids1)(hasSameElements(ownWalletDids1)) &&
assert(dids2)(hasSameElements(ownWalletDids2)) &&
assert(crossWalletDids1)(forall(isNone)) &&
assert(crossWalletDids2)(forall(isNone))
},
test("do not see Peer DID outside of the wallet") {
for {
walletSvc <- ZIO.service[WalletManagementService]
walletId1 <- walletSvc.createWallet()
walletId2 <- walletSvc.createWallet()
ctx1 = ZLayer.succeed(WalletAccessContext(walletId1))
ctx2 = ZLayer.succeed(WalletAccessContext(walletId2))
svc <- ZIO.service[ManagedDIDService]
dids1 <- ZIO.foreach(1 to 3)(_ => svc.createAndStorePeerDID("http://example.com")).provide(ctx1)
dids2 <- ZIO.foreach(1 to 3)(_ => svc.createAndStorePeerDID("http://example.com")).provide(ctx2)
ownWalletDids1 <- ZIO.foreach(dids1)(d => svc.getPeerDID(d.did).exit).provide(ctx1)
ownWalletDids2 <- ZIO.foreach(dids2)(d => svc.getPeerDID(d.did).exit).provide(ctx2)
crossWalletDids1 <- ZIO.foreach(dids1)(d => svc.getPeerDID(d.did).exit).provide(ctx2)
crossWalletDids2 <- ZIO.foreach(dids2)(d => svc.getPeerDID(d.did).exit).provide(ctx1)
} yield assert(ownWalletDids1)(forall(succeeds(anything))) &&
assert(ownWalletDids2)(forall(succeeds(anything))) &&
assert(crossWalletDids1)(forall(failsWithA[DIDSecretStorageError.KeyNotFoundError])) &&
assert(crossWalletDids2)(forall(failsWithA[DIDSecretStorageError.KeyNotFoundError]))
},
test("increment DID index based on count only on its wallet") {
val template = generateDIDTemplate()
for {
walletSvc <- ZIO.service[WalletManagementService]
walletId1 <- walletSvc.createWallet()
walletId2 <- walletSvc.createWallet()
ctx1 = ZLayer.succeed(WalletAccessContext(walletId1))
ctx2 = ZLayer.succeed(WalletAccessContext(walletId2))
svc <- ZIO.service[ManagedDIDService]
wallet1Counter1 <- svc.nonSecretStorage.getMaxDIDIndex().provide(ctx1)
wallet2Counter1 <- svc.nonSecretStorage.getMaxDIDIndex().provide(ctx2)
_ <- svc.createAndStoreDID(template).provide(ctx1)
wallet1Counter2 <- svc.nonSecretStorage.getMaxDIDIndex().provide(ctx1)
wallet2Counter2 <- svc.nonSecretStorage.getMaxDIDIndex().provide(ctx2)
_ <- svc.createAndStoreDID(template).provide(ctx1)
wallet1Counter3 <- svc.nonSecretStorage.getMaxDIDIndex().provide(ctx1)
wallet2Counter3 <- svc.nonSecretStorage.getMaxDIDIndex().provide(ctx2)
_ <- svc.createAndStoreDID(template).provide(ctx2)
wallet1Counter4 <- svc.nonSecretStorage.getMaxDIDIndex().provide(ctx1)
wallet2Counter4 <- svc.nonSecretStorage.getMaxDIDIndex().provide(ctx2)
} yield {
// initial counter
assert(wallet1Counter1)(isNone) &&
assert(wallet2Counter1)(isNone) &&
// add DID to wallet 1
assert(wallet1Counter2)(isSome(equalTo(0))) &&
assert(wallet2Counter2)(isNone) &&
// add DID to wallet 1
assert(wallet1Counter3)(isSome(equalTo(1))) &&
assert(wallet2Counter3)(isNone) &&
// add DID to wallet 2
assert(wallet1Counter4)(isSome(equalTo(1))) &&
assert(wallet2Counter4)(isSome(equalTo(0)))
}
}
)

}
Loading

0 comments on commit 85c0bc3

Please sign in to comment.