diff --git a/src/prefect/blocks/core.py b/src/prefect/blocks/core.py index d262f9403b7f..cac9e8172bb5 100644 --- a/src/prefect/blocks/core.py +++ b/src/prefect/blocks/core.py @@ -150,7 +150,11 @@ def _collect_secret_fields( _collect_secret_fields(f"{name}.{field_name}", field.annotation, secrets) return - if type_ in (SecretStr, SecretBytes): + if type_ in (SecretStr, SecretBytes) or ( + isinstance(type_, type) + and getattr(type_, "__module__", None) == "pydantic.types" + and getattr(type_, "__name__", None) == "Secret" + ): secrets.append(name) elif type_ == SecretDict: # Append .* to field name to signify that all values under this diff --git a/src/prefect/blocks/system.py b/src/prefect/blocks/system.py index 7b7064587e98..43430af39e73 100644 --- a/src/prefect/blocks/system.py +++ b/src/prefect/blocks/system.py @@ -1,11 +1,26 @@ -from typing import Any - -from pydantic import Field, SecretStr -from pydantic_extra_types.pendulum_dt import DateTime +import json +from typing import Annotated, Any, Generic, TypeVar, Union + +from pydantic import ( + Field, + JsonValue, + SecretStr, + StrictStr, + field_validator, +) +from pydantic import Secret as PydanticSecret +from pydantic_extra_types.pendulum_dt import DateTime as PydanticDateTime from prefect._internal.compatibility.deprecated import deprecated_class from prefect.blocks.core import Block +_SecretValueType = Union[ + Annotated[StrictStr, Field(title="string")], + Annotated[JsonValue, Field(title="JSON")], +] + +T = TypeVar("T", bound=_SecretValueType) + @deprecated_class( start_date="Jun 2024", @@ -86,24 +101,26 @@ class DateTime(Block): _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/8b3da9a6621e92108b8e6a75b82e15374e170ff7-48x48.png" _documentation_url = "https://docs.prefect.io/api-ref/prefect/blocks/system/#prefect.blocks.system.DateTime" - value: DateTime = Field( + value: PydanticDateTime = Field( default=..., description="An ISO 8601-compatible datetime value.", ) -class Secret(Block): +class Secret(Block, Generic[T]): """ A block that represents a secret value. The value stored in this block will be obfuscated when - this block is logged or shown in the UI. + this block is viewed or edited in the UI. Attributes: - value: A string value that should be kept secret. + value: A value that should be kept secret. Example: ```python from prefect.blocks.system import Secret + Secret(value="sk-1234567890").save("BLOCK_NAME", overwrite=True) + secret_block = Secret.load("BLOCK_NAME") # Access the stored secret @@ -114,9 +131,28 @@ class Secret(Block): _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/c6f20e556dd16effda9df16551feecfb5822092b-48x48.png" _documentation_url = "https://docs.prefect.io/api-ref/prefect/blocks/system/#prefect.blocks.system.Secret" - value: SecretStr = Field( - default=..., description="A string value that should be kept secret." + value: Union[SecretStr, PydanticSecret[T]] = Field( + default=..., + description="A value that should be kept secret.", + examples=["sk-1234567890", {"username": "johndoe", "password": "s3cr3t"}], + json_schema_extra={ + "writeOnly": True, + "format": "password", + }, ) - def get(self): - return self.value.get_secret_value() + @field_validator("value", mode="before") + def validate_value( + cls, value: Union[T, SecretStr, PydanticSecret[T]] + ) -> Union[SecretStr, PydanticSecret[T]]: + if isinstance(value, (PydanticSecret, SecretStr)): + return value + else: + return PydanticSecret[type(value)](value) + + def get(self) -> T: + try: + value = self.value.get_secret_value() + return json.loads(value) + except (TypeError, json.JSONDecodeError): + return value diff --git a/src/prefect/client/orchestration.py b/src/prefect/client/orchestration.py index 938de46b6b8a..294075de66a2 100644 --- a/src/prefect/client/orchestration.py +++ b/src/prefect/client/orchestration.py @@ -1324,15 +1324,17 @@ async def create_block_document( `SecretBytes` fields. Note Blocks may not work as expected if this is set to `False`. """ + block_document_data = block_document.model_dump( + mode="json", + exclude_unset=True, + exclude={"id", "block_schema", "block_type"}, + context={"include_secrets": include_secrets}, + serialize_as_any=True, + ) try: response = await self._client.post( "/block_documents/", - json=block_document.model_dump( - mode="json", - exclude_unset=True, - exclude={"id", "block_schema", "block_type"}, - context={"include_secrets": include_secrets}, - ), + json=block_document_data, ) except httpx.HTTPStatusError as e: if e.response.status_code == status.HTTP_409_CONFLICT: diff --git a/src/prefect/client/schemas/objects.py b/src/prefect/client/schemas/objects.py index 1cc0eb9977b1..2d439713bfc9 100644 --- a/src/prefect/client/schemas/objects.py +++ b/src/prefect/client/schemas/objects.py @@ -24,6 +24,7 @@ model_serializer, model_validator, ) +from pydantic.functional_validators import ModelWrapValidatorHandler from pydantic_extra_types.pendulum_dt import DateTime from typing_extensions import Literal, Self, TypeVar @@ -938,7 +939,9 @@ def validate_name_is_present_if_not_anonymous(cls, values): return validate_name_present_on_nonanonymous_blocks(values) @model_serializer(mode="wrap") - def serialize_data(self, handler, info: SerializationInfo): + def serialize_data( + self, handler: ModelWrapValidatorHandler, info: SerializationInfo + ): self.data = visit_collection( self.data, visit_fn=partial(handle_secret_render, context=info.context or {}), diff --git a/tests/blocks/test_system.py b/tests/blocks/test_system.py index d9404532f0ea..548af56b8239 100644 --- a/tests/blocks/test_system.py +++ b/tests/blocks/test_system.py @@ -1,18 +1,43 @@ import pendulum +import pytest +from pydantic import Secret as PydanticSecret from pydantic import SecretStr +from pydantic_extra_types.pendulum_dt import DateTime as PydanticDateTime -from prefect.blocks import system +from prefect.blocks.system import DateTime, Secret -async def test_datetime(): - await system.DateTime(value=pendulum.datetime(2022, 1, 1)).save(name="test") - api_block = await system.DateTime.load("test") +def test_datetime(): + DateTime(value=PydanticDateTime(2022, 1, 1)).save(name="test") + api_block = DateTime.load("test") assert api_block.value == pendulum.datetime(2022, 1, 1) -async def test_secret_block(): - await system.Secret(value="test").save(name="test") - api_block = await system.Secret.load("test") - assert isinstance(api_block.value, SecretStr) +@pytest.mark.parametrize( + "value", + ["test", {"key": "value"}, ["test"]], + ids=["string", "dict", "list"], +) +def test_secret_block(value): + Secret(value=value).save(name="test") + api_block = Secret.load("test") + assert isinstance(api_block.value, PydanticSecret) - assert api_block.get() == "test" + assert api_block.get() == value + + +@pytest.mark.parametrize( + "value", + [ + SecretStr("test"), + PydanticSecret[dict]({"key": "value"}), + PydanticSecret[list](["test"]), + ], + ids=["secret_string", "secret_dict", "secret_list"], +) +def test_secret_block_with_pydantic_secret(value): + Secret(value=value).save(name="test") + api_block = Secret.load("test") + assert isinstance(api_block.value, PydanticSecret) + + assert api_block.get() == value.get_secret_value()