Skip to content

Commit

Permalink
allow JSON values in Secret block (#14980)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Aug 26, 2024
1 parent b884702 commit 4c3e2c3
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 29 deletions.
6 changes: 5 additions & 1 deletion src/prefect/blocks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,11 @@ def _collect_secret_fields(
_collect_secret_fields(f"{name}.{field_name}", field.annotation, secrets)
return

if type_ in (SecretStr, SecretBytes):
if type_ in (SecretStr, SecretBytes) or (
isinstance(type_, type)
and getattr(type_, "__module__", None) == "pydantic.types"
and getattr(type_, "__name__", None) == "Secret"
):
secrets.append(name)
elif type_ == SecretDict:
# Append .* to field name to signify that all values under this
Expand Down
60 changes: 48 additions & 12 deletions src/prefect/blocks/system.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
from typing import Any

from pydantic import Field, SecretStr
from pydantic_extra_types.pendulum_dt import DateTime
import json
from typing import Annotated, Any, Generic, TypeVar, Union

from pydantic import (
Field,
JsonValue,
SecretStr,
StrictStr,
field_validator,
)
from pydantic import Secret as PydanticSecret
from pydantic_extra_types.pendulum_dt import DateTime as PydanticDateTime

from prefect._internal.compatibility.deprecated import deprecated_class
from prefect.blocks.core import Block

_SecretValueType = Union[
Annotated[StrictStr, Field(title="string")],
Annotated[JsonValue, Field(title="JSON")],
]

T = TypeVar("T", bound=_SecretValueType)


@deprecated_class(
start_date="Jun 2024",
Expand Down Expand Up @@ -86,24 +101,26 @@ class DateTime(Block):
_logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/8b3da9a6621e92108b8e6a75b82e15374e170ff7-48x48.png"
_documentation_url = "https://docs.prefect.io/api-ref/prefect/blocks/system/#prefect.blocks.system.DateTime"

value: DateTime = Field(
value: PydanticDateTime = Field(
default=...,
description="An ISO 8601-compatible datetime value.",
)


class Secret(Block):
class Secret(Block, Generic[T]):
"""
A block that represents a secret value. The value stored in this block will be obfuscated when
this block is logged or shown in the UI.
this block is viewed or edited in the UI.
Attributes:
value: A string value that should be kept secret.
value: A value that should be kept secret.
Example:
```python
from prefect.blocks.system import Secret
Secret(value="sk-1234567890").save("BLOCK_NAME", overwrite=True)
secret_block = Secret.load("BLOCK_NAME")
# Access the stored secret
Expand All @@ -114,9 +131,28 @@ class Secret(Block):
_logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/c6f20e556dd16effda9df16551feecfb5822092b-48x48.png"
_documentation_url = "https://docs.prefect.io/api-ref/prefect/blocks/system/#prefect.blocks.system.Secret"

value: SecretStr = Field(
default=..., description="A string value that should be kept secret."
value: Union[SecretStr, PydanticSecret[T]] = Field(
default=...,
description="A value that should be kept secret.",
examples=["sk-1234567890", {"username": "johndoe", "password": "s3cr3t"}],
json_schema_extra={
"writeOnly": True,
"format": "password",
},
)

def get(self):
return self.value.get_secret_value()
@field_validator("value", mode="before")
def validate_value(
cls, value: Union[T, SecretStr, PydanticSecret[T]]
) -> Union[SecretStr, PydanticSecret[T]]:
if isinstance(value, (PydanticSecret, SecretStr)):
return value
else:
return PydanticSecret[type(value)](value)

def get(self) -> T:
try:
value = self.value.get_secret_value()
return json.loads(value)
except (TypeError, json.JSONDecodeError):
return value
14 changes: 8 additions & 6 deletions src/prefect/client/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,15 +1324,17 @@ async def create_block_document(
`SecretBytes` fields. Note Blocks may not work as expected if
this is set to `False`.
"""
block_document_data = block_document.model_dump(
mode="json",
exclude_unset=True,
exclude={"id", "block_schema", "block_type"},
context={"include_secrets": include_secrets},
serialize_as_any=True,
)
try:
response = await self._client.post(
"/block_documents/",
json=block_document.model_dump(
mode="json",
exclude_unset=True,
exclude={"id", "block_schema", "block_type"},
context={"include_secrets": include_secrets},
),
json=block_document_data,
)
except httpx.HTTPStatusError as e:
if e.response.status_code == status.HTTP_409_CONFLICT:
Expand Down
5 changes: 4 additions & 1 deletion src/prefect/client/schemas/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
model_serializer,
model_validator,
)
from pydantic.functional_validators import ModelWrapValidatorHandler
from pydantic_extra_types.pendulum_dt import DateTime
from typing_extensions import Literal, Self, TypeVar

Expand Down Expand Up @@ -938,7 +939,9 @@ def validate_name_is_present_if_not_anonymous(cls, values):
return validate_name_present_on_nonanonymous_blocks(values)

@model_serializer(mode="wrap")
def serialize_data(self, handler, info: SerializationInfo):
def serialize_data(
self, handler: ModelWrapValidatorHandler, info: SerializationInfo
):
self.data = visit_collection(
self.data,
visit_fn=partial(handle_secret_render, context=info.context or {}),
Expand Down
43 changes: 34 additions & 9 deletions tests/blocks/test_system.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,43 @@
import pendulum
import pytest
from pydantic import Secret as PydanticSecret
from pydantic import SecretStr
from pydantic_extra_types.pendulum_dt import DateTime as PydanticDateTime

from prefect.blocks import system
from prefect.blocks.system import DateTime, Secret


async def test_datetime():
await system.DateTime(value=pendulum.datetime(2022, 1, 1)).save(name="test")
api_block = await system.DateTime.load("test")
def test_datetime():
DateTime(value=PydanticDateTime(2022, 1, 1)).save(name="test")
api_block = DateTime.load("test")
assert api_block.value == pendulum.datetime(2022, 1, 1)


async def test_secret_block():
await system.Secret(value="test").save(name="test")
api_block = await system.Secret.load("test")
assert isinstance(api_block.value, SecretStr)
@pytest.mark.parametrize(
"value",
["test", {"key": "value"}, ["test"]],
ids=["string", "dict", "list"],
)
def test_secret_block(value):
Secret(value=value).save(name="test")
api_block = Secret.load("test")
assert isinstance(api_block.value, PydanticSecret)

assert api_block.get() == "test"
assert api_block.get() == value


@pytest.mark.parametrize(
"value",
[
SecretStr("test"),
PydanticSecret[dict]({"key": "value"}),
PydanticSecret[list](["test"]),
],
ids=["secret_string", "secret_dict", "secret_list"],
)
def test_secret_block_with_pydantic_secret(value):
Secret(value=value).save(name="test")
api_block = Secret.load("test")
assert isinstance(api_block.value, PydanticSecret)

assert api_block.get() == value.get_secret_value()

0 comments on commit 4c3e2c3

Please sign in to comment.