Skip to content

Commit

Permalink
fix(prism-agent): update invitation expiration on connection request (#…
Browse files Browse the repository at this point in the history
…687)

Signed-off-by: Shailesh Patil <[email protected]>
  • Loading branch information
mineme0110 authored Sep 6, 2023
1 parent 854dcf9 commit 1a1702f
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ object ConnectionServiceError {
final case class UnexpectedError(msg: String) extends ConnectionServiceError
final case class InvalidFlowStateError(msg: String) extends ConnectionServiceError
final case class InvitationAlreadyReceived(msg: String) extends ConnectionServiceError
final case class InvitationExpired(msg: String) extends ConnectionServiceError

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ trait ConnectionService {

def markConnectionRequestSent(recordId: UUID): IO[ConnectionServiceError, ConnectionRecord]

def receiveConnectionRequest(request: ConnectionRequest): IO[ConnectionServiceError, ConnectionRecord]
// def receiveConnectionRequest(request: ConnectionRequest): IO[ConnectionServiceError, ConnectionRecord]
def receiveConnectionRequest(
request: ConnectionRequest,
expirationTime: Option[Duration]
): IO[ConnectionServiceError, ConnectionRecord]

def acceptConnectionRequest(recordId: UUID): IO[ConnectionServiceError, ConnectionRecord]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import zio.*
import java.rmi.UnexpectedException
import java.time.Instant
import java.util.UUID

import java.time.Duration
private class ConnectionServiceImpl(
connectionRepository: ConnectionRepository[Task],
maxRetries: Int = 5, // TODO move to config
Expand Down Expand Up @@ -176,13 +176,25 @@ private class ConnectionServiceImpl(
}

override def receiveConnectionRequest(
request: ConnectionRequest
request: ConnectionRequest,
expirationTime: Option[Duration] = None
): IO[ConnectionServiceError, ConnectionRecord] =
for {
record <- getRecordFromThreadIdAndState(
Some(request.thid.orElse(request.pthid).getOrElse(request.id)),
ProtocolState.InvitationGenerated
)
_ <- expirationTime.fold {
ZIO.unit
} { expiryDuration =>
val actualDuration = Duration.between(record.createdAt, Instant.now())
if (actualDuration > expiryDuration) {
for {
_ <- markConnectionInvitationExpired(record.id)
result <- ZIO.fail(InvitationExpired(record.id.toString))
} yield result
} else ZIO.unit
}
_ <- connectionRepository
.updateWithConnectionRequest(record.id, request, ProtocolState.ConnectionRequestReceived, maxRetries)
.flatMap {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import io.iohk.atala.mercury.model.DidId
import io.iohk.atala.mercury.protocol.connection.{ConnectionRequest, ConnectionResponse}
import zio.{IO, URLayer, ZIO, ZLayer}

import java.time.Duration
import java.util.UUID

class ConnectionServiceNotifier(
Expand Down Expand Up @@ -34,8 +35,11 @@ class ConnectionServiceNotifier(
override def markConnectionRequestSent(recordId: UUID): IO[ConnectionServiceError, ConnectionRecord] =
notifyOnSuccess(svc.markConnectionRequestSent(recordId))

override def receiveConnectionRequest(request: ConnectionRequest): IO[ConnectionServiceError, ConnectionRecord] =
notifyOnSuccess(svc.receiveConnectionRequest(request))
override def receiveConnectionRequest(
request: ConnectionRequest,
expirationTime: Option[Duration]
): IO[ConnectionServiceError, ConnectionRecord] =
notifyOnSuccess(svc.receiveConnectionRequest(request, expirationTime))

override def acceptConnectionRequest(recordId: UUID): IO[ConnectionServiceError, ConnectionRecord] =
notifyOnSuccess(svc.acceptConnectionRequest(recordId))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import io.iohk.atala.mercury.protocol.connection.{ConnectionRequest, ConnectionR
import zio.mock.{Mock, Proxy}
import zio.{IO, URLayer, ZIO, ZLayer, mock}

import java.time.Duration
import java.util.UUID

object MockConnectionService extends Mock[ConnectionService] {
Expand Down Expand Up @@ -44,7 +45,10 @@ object MockConnectionService extends Mock[ConnectionService] {
override def markConnectionRequestSent(recordId: UUID): IO[ConnectionServiceError, ConnectionRecord] =
proxy(MarkConnectionRequestSent, recordId)

override def receiveConnectionRequest(request: ConnectionRequest): IO[ConnectionServiceError, ConnectionRecord] =
override def receiveConnectionRequest(
request: ConnectionRequest,
expirationTime: Option[Duration]
): IO[ConnectionServiceError, ConnectionRecord] =
proxy(ReceiveConnectionRequest, request)

override def acceptConnectionRequest(recordId: UUID): IO[ConnectionServiceError, ConnectionRecord] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ object ConnectionServiceImplSpec extends ZIOSpecDefault {
)
// FIXME: Should the service return an Option while we have dedicated "not found" error for that case !?
connectionRequest = maybeAcceptedInvitationRecord.connectionRequest.get
maybeReceivedRequestConnectionRecord <- inviterSvc.receiveConnectionRequest(connectionRequest)
maybeReceivedRequestConnectionRecord <- inviterSvc.receiveConnectionRequest(connectionRequest, None)
allInviterRecords <- inviterSvc.getConnectionRecords()
} yield {
val updatedRecord = maybeReceivedRequestConnectionRecord
Expand All @@ -194,8 +194,8 @@ object ConnectionServiceImplSpec extends ZIOSpecDefault {
)
connectionRequest = maybeAcceptedInvitationRecord.connectionRequest.get
connectionRecordUpdated <- inviterSvc.markConnectionInvitationExpired(inviterRecord.id)

exit <- inviterSvc.receiveConnectionRequest(connectionRequest).exit
expiryTime = Duration.fromSeconds(300)
exit <- inviterSvc.receiveConnectionRequest(connectionRequest, Some(expiryTime)).exit

} yield {
assertTrue(exit match
Expand All @@ -220,7 +220,11 @@ object ConnectionServiceImplSpec extends ZIOSpecDefault {
DidId("did:peer:INVITEE")
)
connectionRequest = maybeAcceptedInvitationRecord.connectionRequest.get
maybeReceivedRequestConnectionRecord <- inviterSvc.receiveConnectionRequest(connectionRequest)
expiryTime = Duration.fromSeconds(300)
maybeReceivedRequestConnectionRecord <- inviterSvc.receiveConnectionRequest(
connectionRequest,
Some(expiryTime)
)
maybeAcceptedRequestConnectionRecord <- inviterSvc.acceptConnectionRequest(inviterRecord.id)
allInviterRecords <- inviterSvc.getConnectionRecords()
} yield {
Expand Down Expand Up @@ -250,7 +254,12 @@ object ConnectionServiceImplSpec extends ZIOSpecDefault {
)
connectionRequest = maybeAcceptedInvitationRecord.connectionRequest.get
_ <- inviteeSvc.markConnectionRequestSent(inviteeRecord.id)
maybeReceivedRequestConnectionRecord <- inviterSvc.receiveConnectionRequest(connectionRequest)
expiryTime = Duration.fromSeconds(300)

maybeReceivedRequestConnectionRecord <- inviterSvc.receiveConnectionRequest(
connectionRequest,
Some(expiryTime)
)
maybeAcceptedRequestConnectionRecord <- inviterSvc.acceptConnectionRequest(inviterRecord.id)
connectionResponseMessage <- ZIO.fromEither(
maybeAcceptedRequestConnectionRecord.connectionResponse.get.makeMessage.asJson.as[Message]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ object ConnectionServiceNotifierSpec extends ZIOSpecDefault {
thid = Some(connectionRecord.thid),
pthid = None,
body = ConnectionRequest.Body()
)
),
None
)
_ <- cs.acceptConnectionRequest(connectionRecord.id)
_ <- cs.markConnectionResponseSent(connectionRecord.id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package io.iohk.atala.agent.server
import io.circe.*
import io.circe.parser.*
import io.iohk.atala.agent.server.DidCommHttpServerError.{DIDCommMessageParsingError, RequestBodyParsingError}
import io.iohk.atala.agent.server.config.AppConfig
import io.iohk.atala.agent.walletapi.model.error.DIDSecretStorageError
import io.iohk.atala.agent.walletapi.service.ManagedDIDService
import io.iohk.atala.connect.core.model.error.ConnectionServiceError
Expand Down Expand Up @@ -42,7 +43,7 @@ object DidCommHttpServer {

private def didCommServiceEndpoint: HttpApp[
DidOps & DidAgent & CredentialService & PresentationService & ConnectionService & ManagedDIDService & HttpClient &
DidAgent & DIDResolver,
DidAgent & DIDResolver & AppConfig,
Nothing
] = Http.collectZIO[Request] {
case Method.GET -> !! / "did" =>
Expand Down Expand Up @@ -147,17 +148,24 @@ object DidCommHttpServer {
/*
* Connect
*/
private val handleConnect
: PartialFunction[Message, ZIO[ConnectionService, DIDCommMessageParsingError | ConnectionServiceError, Unit]] = {
private val handleConnect: PartialFunction[Message, ZIO[
ConnectionService & AppConfig,
DIDCommMessageParsingError | ConnectionServiceError,
Unit
]] = {
case msg if msg.piuri == ConnectionRequest.`type` =>
for {
connectionRequest <- ZIO
.fromEither(ConnectionRequest.fromMessage(msg))
.mapError(DIDCommMessageParsingError.apply)
_ <- ZIO.logInfo("As an Inviter in connect got ConnectionRequest: " + connectionRequest)
connectionService <- ZIO.service[ConnectionService]
maybeRecord <- connectionService.receiveConnectionRequest(connectionRequest)
_ <- connectionService.acceptConnectionRequest(maybeRecord.id)
config <- ZIO.service[AppConfig]
record <- connectionService.receiveConnectionRequest(
connectionRequest,
Some(config.connect.connectInvitationExpiry)
)
_ <- connectionService.acceptConnectionRequest(record.id)
} yield ()
case msg if msg.piuri == ConnectionResponse.`type` =>
for {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ import io.iohk.atala.shared.utils.aspects.CustomMetricsAspect
import io.iohk.atala.shared.utils.DurationOps.toMetricsSeconds
import zio.*
import zio.metrics.*
import java.time.{Instant, Duration}

object ConnectBackgroundJobs {

val didCommExchanges = {
Expand All @@ -29,7 +27,6 @@ object ConnectBackgroundJobs {
.getConnectionRecordsByStates(
ignoreWithZeroRetries = true,
limit = config.connect.connectBgJobRecordsLimit,
ConnectionRecord.ProtocolState.InvitationGenerated,
ConnectionRecord.ProtocolState.ConnectionRequestPending,
ConnectionRecord.ProtocolState.ConnectionResponsePending
)
Expand All @@ -40,7 +37,7 @@ object ConnectBackgroundJobs {

private[this] def performExchange(
record: ConnectionRecord
): URIO[DidOps & DIDResolver & HttpClient & ConnectionService & ManagedDIDService & AppConfig, Unit] = {
): URIO[DidOps & DIDResolver & HttpClient & ConnectionService & ManagedDIDService, Unit] = {
import ProtocolState.*
import Role.*

Expand All @@ -60,9 +57,6 @@ object ConnectBackgroundJobs {
val InviterConnectionResponseMsgSuccess = counterMetric(
"connection_flow_inviter_connection_response_msg_success_counter"
)
val InviterConnectionInvitationExpiredSuccess = counterMetric(
"connection_flow_inviter_connection_invitation_expired_success_counter"
)
val InviteeProcessConnectionRecordPendingSuccess = counterMetric(
"connection_flow_invitee_process_connection_record_success_counter"
)
Expand Down Expand Up @@ -160,42 +154,15 @@ object ConnectBackgroundJobs {
else ZIO.fail(ErrorResponseReceivedFromPeerAgent(resp)) @@ InviterConnectionResponseMsgFailed
}
} yield ()

inviterProcessFlow
@@ InviterProcessConnectionRecordPendingSuccess.trackSuccess
@@ InviterProcessConnectionRecordPendingFailed.trackError
@@ InviterProcessConnectionRecordPendingTotal
@@ Metric
.gauge("connection_flow_inviter_process_connection_record_ms_gauge")
.trackDurationWith(_.toMetricsSeconds)
case ConnectionRecord(
id,
createdAt,
_,
_,
_,
Inviter,
InvitationGenerated,
_,
_,
_,
metaRetries,
_,
_
) if metaRetries > 0 =>
for {
connectionService <- ZIO.service[ConnectionService]
config <- ZIO.service[AppConfig]
expired <- ZIO.succeed {
val expiryDuration = config.connect.connectInvitationExpiry
val actualDuration = Duration.between(createdAt, Instant.now())
actualDuration > expiryDuration
}
_ <-
if (expired) {
connectionService.markConnectionInvitationExpired(id)
@@ InviterConnectionInvitationExpiredSuccess
} else ZIO.unit
} yield ()

case _ => ZIO.unit
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,6 @@ object ConnectionController {
ErrorResponse.badRequest(title = "InvalidFlowState", detail = Some(msg))
case ConnectionServiceError.InvitationAlreadyReceived(msg) =>
ErrorResponse.badRequest(title = "InvitationAlreadyReceived", detail = Some(msg))
case ConnectionServiceError.InvitationExpired(msg) =>
ErrorResponse.badRequest(title = "InvitationExpired", detail = Some(msg))
}

0 comments on commit 1a1702f

Please sign in to comment.