Skip to content

Commit

Permalink
Merge pull request #10 from bitovi/pr-feedback
Browse files Browse the repository at this point in the history
PR feedback
  • Loading branch information
phillipskevin authored Oct 1, 2024
2 parents 55f18b2 + daa374b commit 53c03b3
Show file tree
Hide file tree
Showing 11 changed files with 1,565 additions and 1,164 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
.venv
__pycache__
_certs
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Some examples require extra dependencies. See each sample's directory for specif
* [custom_decorator](custom_decorator) - Custom decorator to auto-heartbeat a long-running activity.
* [dsl](dsl) - DSL workflow that executes steps defined in a YAML file.
* [encryption](encryption) - Apply end-to-end encryption for all input/output.
* [encryption_jwt](encryption_jwt) - Apply end-to-end encryption for all input/output using a KMS and per-namespace JWT-based auth.
* [gevent_async](gevent_async) - Combine gevent and Temporal.
* [langchain](langchain) - Orchestrate workflows for LangChain.
* [open_telemetry](open_telemetry) - Trace workflows with OpenTelemetry.
Expand Down
1 change: 1 addition & 0 deletions encryption_jwt/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
_certs
12 changes: 6 additions & 6 deletions encryption_jwt/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ The Codec Server uses the [Operations API](https://docs.temporal.io/ops) to get

## Install

For this sample, the optional `encryption` and `bedrock` dependency groups must be included. To include, run:
For this sample, the optional `encryption_jwt` and `bedrock` dependency groups must be included. To include, run:

```sh
poetry install --with encryption,bedrock
poetry install --with encryption_jwt,bedrock
```

## Setup
Expand All @@ -31,17 +31,17 @@ Alternately replace the key management portion with your own implementation.
### Self-signed certificates

The codec server will need to use HTTPS; self-signed certificates will work in the development
environment. Run the following command in a `_certs` directory that's a subdirectory of the
repository root, it will create certificate files that are good for 10 years.
environment. Run the following command in a `_certs` directory that's a subdirectory of this one.
It will create certificate files that are good for 10 years.

```sh
openssl req -x509 -newkey rsa:4096 -sha256 -days 3650 -nodes -keyout localhost.key -out localhost.pem -subj "/CN=localhost"
```

In the projects you can access the files using the following relative paths.

- `../_certs/localhost.pem`
- `../_certs/localhost.key`
- `./_certs/localhost.pem`
- `./_certs/localhost.key`

## Run

Expand Down
17 changes: 10 additions & 7 deletions encryption_jwt/codec.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
from typing import Iterable, List

from temporalio.api.common.v1 import Payload
from temporalio.converter import PayloadCodec

from encryption_jwt.encryptor import KMSEncryptor


class EncryptionCodec(PayloadCodec):

def __init__(self, namespace: str):
self._encryptor = KMSEncryptor(namespace)

async def encode(self, payloads: Iterable[Payload]) -> List[Payload]:
# We blindly encode all payloads with the key and set the metadata with the key that was
# used (base64 encoded).

def encrypt_payload(p: Payload):
data, key = self._encryptor.encrypt(p.SerializeToString())
async def encrypt_payload(p: Payload):
data, key = await self._encryptor.encrypt(p.SerializeToString())
return Payload(
metadata={
"encoding": b"binary/encrypted",
Expand All @@ -23,12 +24,14 @@ def encrypt_payload(p: Payload):
data=data,
)

return list(map(encrypt_payload, payloads))
# return list(map(encrypt_payload, payloads))
return [await encrypt_payload(payload) for payload in payloads]

async def decode(self, payloads: Iterable[Payload]) -> List[Payload]:
def decrypt_payload(p: Payload):
async def decrypt_payload(p: Payload):
data_key_encrypted_base64 = p.metadata.get("data_key_encrypted", b"")
data = self._encryptor.decrypt(data_key_encrypted_base64, p.data)
data = await self._encryptor.decrypt(data_key_encrypted_base64, p.data)
return Payload.FromString(data)

return list(map(decrypt_payload, payloads))
# return list(map(decrypt_payload, payloads))
return [await decrypt_payload(payload) for payload in payloads]
133 changes: 89 additions & 44 deletions encryption_jwt/codec_server.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import logging
import os
import ssl
import logging

import jwt
import grpc
import requests
from aiohttp import hdrs, web

from temporalio.api.common.v1 import Payload, Payloads
from temporalio.api.cloud.cloudservice.v1 import request_response_pb2, service_pb2_grpc
from google.protobuf import json_format
from jwt.algorithms import RSAAlgorithm
from temporalio.api.cloud.cloudservice.v1 import GetUsersRequest
from temporalio.api.common.v1 import Payloads
from temporalio.client import CloudOperationsClient

from encryption_jwt.codec import EncryptionCodec

AUTHORIZED_ACCOUNT_ACCESS_ROLES = ["admin"]
AUTHORIZED_ACCOUNT_ACCESS_ROLES = ["owner", "admin"]
AUTHORIZED_NAMESPACE_ACCESS_ROLES = ["read", "write", "admin"]

TEMPORAL_CLIENT_CLOUD_API_VERSION = "2024-05-13-00"

temporal_ops_address = "saas-api.tmprl.cloud:443"
if os.environ.get("TEMPORAL_OPS_ADDRESS"):
temporal_ops_address = os.environ.get("TEMPORAL_OPS_ADDRESS")
Expand Down Expand Up @@ -42,52 +47,90 @@ async def cors_options(req: web.Request) -> web.Response:

return resp

def decryption_authorized(email: str, namespace: str) -> bool:
credentials = grpc.composite_channel_credentials(grpc.ssl_channel_credentials(
), grpc.access_token_call_credentials(os.environ.get("TEMPORAL_API_KEY")))

with grpc.secure_channel(temporal_ops_address, credentials) as channel:
client = service_pb2_grpc.CloudServiceStub(channel)
request = request_response_pb2.GetUsersRequest()

response = client.GetUsers(request, metadata=(
("temporal-cloud-api-version", os.environ.get("TEMPORAL_OPS_API_VERSION")),))

authorized = False
for user in response.users:
if user.spec.email.lower() == email.lower():
if user.spec.access.account_access.role in AUTHORIZED_ACCOUNT_ACCESS_ROLES:
authorized = True
else:
if namespace in user.spec.access.namespace_accesses:
if user.spec.access.namespace_accesses[namespace].permission in AUTHORIZED_NAMESPACE_ACCESS_ROLES:
authorized = True

return authorized
async def decryption_authorized(email: str, namespace: str) -> bool:
    """Return whether the Temporal Cloud user ``email`` may use the codec.

    Connects to the Temporal Cloud Operations API with the ``TEMPORAL_API_KEY``
    environment variable and lists the users for ``namespace``. A user whose
    email matches case-insensitively is authorized when their account role is
    in ``AUTHORIZED_ACCOUNT_ACCESS_ROLES`` or their permission on ``namespace``
    is in ``AUTHORIZED_NAMESPACE_ACCESS_ROLES``.

    Args:
        email: Email address taken from the verified JWT.
        namespace: Namespace the request is scoped to.

    Returns:
        True if a matching, sufficiently privileged user is found, else False.
    """
    client = await CloudOperationsClient.connect(
        api_key=os.environ.get("TEMPORAL_API_KEY"),
        version=TEMPORAL_CLIENT_CLOUD_API_VERSION,
    )

    wanted_email = email.lower()
    page_token = ""

    # GetUsers is paginated: keep following next_page_token, otherwise users
    # that do not land on the first page would always be denied.
    while True:
        response = await client.cloud_service.get_users(
            GetUsersRequest(namespace=namespace, page_token=page_token)
        )

        for user in response.users:
            if user.spec.email.lower() != wanted_email:
                continue

            access = user.spec.access
            if access.account_access.role in AUTHORIZED_ACCOUNT_ACCESS_ROLES:
                return True

            if (
                namespace in access.namespace_accesses
                and access.namespace_accesses[namespace].permission
                in AUTHORIZED_NAMESPACE_ACCESS_ROLES
            ):
                return True

        page_token = response.next_page_token
        if not page_token:
            return False

def make_handler(fn: str):
    """Create an aiohttp handler that applies the codec method named ``fn``.

    Args:
        fn: Name of the ``EncryptionCodec`` coroutine method to run on the
            request payloads (the routes use ``"encode"`` and ``"decode"``).

    Returns:
        A coroutine handler taking a ``web.Request``. The handler verifies the
        bearer JWT against Temporal Cloud's JWKS, checks the caller with
        ``decryption_authorized``, runs the codec and returns the resulting
        payloads as JSON with the CORS headers from ``cors_options``.

    The returned handler raises:
        web.HTTPUnauthorized: The Authorization header is missing/malformed,
            the signing key is unknown, or token verification fails.
        web.HTTPForbidden: The token is valid but the user is not authorized.
        web.HTTPUnsupportedMediaType: The request body is not JSON.
    """
    # Local import: only needed to move the blocking JWKS fetch off the loop.
    import asyncio

    jwks_url = "https://login.tmprl.cloud/.well-known/jwks.json"
    email_claim = "https://saas-api.tmprl.cloud/user/email"
    # Pin the accepted algorithms instead of trusting the "alg" value in the
    # (unverified) token header. Only RSA-based algorithms make sense here
    # because the verification key is built with RSAAlgorithm.from_jwk.
    allowed_algorithms = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]

    async def handler(req: web.Request):
        namespace = req.headers.get("x-namespace")

        # Expect "Bearer <token>"; reject anything else instead of crashing.
        scheme, _, encoded = req.headers.get("Authorization", "").partition(" ")
        encoded = encoded.strip()
        if scheme.lower() != "bearer" or not encoded:
            raise web.HTTPUnauthorized(
                reason="Missing or malformed Authorization header"
            )

        # Extract the key id (kid) from the unverified token header.
        try:
            kid = jwt.get_unverified_header(encoded)["kid"]
        except (jwt.InvalidTokenError, KeyError) as err:
            raise web.HTTPUnauthorized(reason="Invalid token header") from err

        # Fetch Temporal Cloud's JWKS. requests is synchronous, so run it in a
        # worker thread (with a timeout) to avoid stalling the event loop.
        jwks_response = await asyncio.to_thread(requests.get, jwks_url, timeout=10)
        jwks_response.raise_for_status()
        jwks = jwks_response.json()

        # Pick the public key whose kid matches the token's.
        public_key = next(
            (
                RSAAlgorithm.from_jwk(key)
                for key in jwks["keys"]
                if key.get("kid") == kid
            ),
            None,
        )
        if public_key is None:
            raise web.HTTPUnauthorized(reason="Public key not found in JWKS")

        # Decode the JWT, verifying it against Temporal Cloud's public key.
        try:
            decoded = jwt.decode(
                encoded,
                public_key,
                algorithms=allowed_algorithms,
                audience=[
                    "https://saas-api.tmprl.cloud",
                    "https://prod-tmprl.us.auth0.com/userinfo",
                ],
            )
        except jwt.InvalidTokenError as err:
            raise web.HTTPUnauthorized(reason="Token verification failed") from err

        # Use the email to determine if the user may use the codec.
        email = decoded.get(email_claim)
        if not email or not await decryption_authorized(email, namespace):
            raise web.HTTPForbidden(reason="Not authorized for this namespace")

        # Read payloads as JSON. Raise instead of assert: asserts are stripped
        # under -O and would otherwise surface as a 500.
        if req.content_type != "application/json":
            raise web.HTTPUnsupportedMediaType()
        payloads = json_format.Parse(await req.read(), Payloads())

        codec = EncryptionCodec(namespace)
        payloads = Payloads(payloads=await getattr(codec, fn)(payloads.payloads))

        # Apply CORS and return JSON
        resp = await cors_options(req)
        resp.content_type = "application/json"
        resp.text = json_format.MessageToJson(payloads)
        return resp

    return handler

# Build app
Expand All @@ -97,8 +140,8 @@ async def handler(req: web.Request):
logger = logging.getLogger(__name__)
app.add_routes(
[
web.post("/encode", make_handler('encode')),
web.post("/decode", make_handler('decode')),
web.post("/encode", make_handler("encode")),
web.post("/decode", make_handler("decode")),
web.options("/decode", cors_options),
]
)
Expand All @@ -112,8 +155,10 @@ async def handler(req: web.Request):
if os.environ.get("SSL_PEM") and os.environ.get("SSL_KEY"):
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.check_hostname = False
ssl_context.load_cert_chain(os.environ.get(
"SSL_PEM"), os.environ.get("SSL_KEY"))
ssl_context.load_cert_chain(
os.environ.get("SSL_PEM"), os.environ.get("SSL_KEY")
)

web.run_app(build_codec_server(), host="0.0.0.0",
port=8081, ssl_context=ssl_context)
web.run_app(
build_codec_server(), host="0.0.0.0", port=8081, ssl_context=ssl_context
)
78 changes: 43 additions & 35 deletions encryption_jwt/encryptor.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,37 @@
import os
import base64
import logging
from temporalio import workflow
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
import os

from botocore.exceptions import ClientError
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from temporalio import workflow

with workflow.unsafe.imports_passed_through():
import boto3
import aioboto3


class KMSEncryptor:
"""Encrypts and decrypts using keys from AWS KMS."""

def __init__(self, namespace: str):
self._namespace = namespace
self._kms_client = None
self._boto_session = None

@property
def kms_client(self):
def boto_session(self):
"""Get the cached aioboto3 session, creating it on first access."""
if not self._kms_client:
self._kms_client = boto3.client("kms")
if not self._boto_session:
session = aioboto3.Session()
self._boto_session = session

return self._kms_client
return self._boto_session

def encrypt(self, data: bytes) -> tuple[bytes, bytes]:
async def encrypt(self, data: bytes) -> tuple[bytes, bytes]:
"""Encrypt data using a key from KMS."""
# The keys are rotated automatically by KMS, so fetch a new key to encrypt the data.
data_key_encrypted, data_key_plaintext = self.__create_data_key(self._namespace)
data_key_encrypted, data_key_plaintext = await self.__create_data_key(
self._namespace
)

if data_key_encrypted is None:
raise ValueError("No data key!")
Expand All @@ -38,38 +42,42 @@ def encrypt(self, data: bytes) -> tuple[bytes, bytes]:
data_key_encrypted
)

def decrypt(self, data_key_encrypted_base64, data: bytes) -> bytes:
async def decrypt(self, data_key_encrypted_base64, data: bytes) -> bytes:
"""Decrypt data using a key from KMS."""
data_key_encrypted = base64.b64decode(data_key_encrypted_base64)
data_key_plaintext = self.__decrypt_data_key(data_key_encrypted)
data_key_plaintext = await self.__decrypt_data_key(data_key_encrypted)
encryptor = AESGCM(data_key_plaintext)
# The first 12 bytes are passed to AES-GCM as the nonce; the rest is the ciphertext.
return encryptor.decrypt(data[:12], data[12:], None)

def __create_data_key(self, namespace: str):
async def __create_data_key(self, namespace: str):
"""Get a set of keys from AWS KMS that can be used to encrypt data."""

# Create data key
alias_name = 'alias/' + namespace.replace('.', '_')
response = self.kms_client.describe_key(KeyId=alias_name)
cmk_id = response['KeyMetadata']['Arn']
key_spec = "AES_256"
try:
response = self.kms_client.generate_data_key(KeyId=cmk_id, KeySpec=key_spec)
except ClientError as e:
logging.error(e)
return None, None

# Return the encrypted and plaintext data key
return response["CiphertextBlob"], response["Plaintext"]

def __decrypt_data_key(self, data_key_encrypted):
alias_name = "alias/" + namespace.replace(".", "_")
async with self.boto_session.client("kms") as kms_client:
response = await kms_client.describe_key(KeyId=alias_name)
cmk_id = response["KeyMetadata"]["Arn"]
key_spec = "AES_256"
try:
response = await kms_client.generate_data_key(
KeyId=cmk_id, KeySpec=key_spec
)
except ClientError as e:
logging.error(e)
return None, None

# Return the encrypted and plaintext data key
return response["CiphertextBlob"], response["Plaintext"]

async def __decrypt_data_key(self, data_key_encrypted):
"""Use AWS KMS to exchange an encrypted key for its plaintext value."""

# Decrypt the data key
try:
response = self.kms_client.decrypt(CiphertextBlob=data_key_encrypted)
except ClientError as e:
logging.error(e)
return None
async with self.boto_session.client("kms") as kms_client:
# Decrypt the data key
try:
response = await kms_client.decrypt(CiphertextBlob=data_key_encrypted)
except ClientError as e:
logging.error(e)
return None

return response["Plaintext"]
return response["Plaintext"]
Loading

0 comments on commit 53c03b3

Please sign in to comment.