From 7c9ccc91accc1abd23dc019b0fd9426938d2ad7d Mon Sep 17 00:00:00 2001 From: Jonathan Serafini Date: Thu, 17 Aug 2023 11:35:28 -0400 Subject: [PATCH] feat: add support for arbitrary types --- fast_depends/__about__.py | 2 +- fast_depends/_compat.py | 10 ++++++++++ fast_depends/core/build.py | 8 +++++--- tests/async/test_cast.py | 26 ++++++++++++++++++++++++++ tests/sync/test_cast.py | 24 ++++++++++++++++++++++++ 5 files changed, 66 insertions(+), 4 deletions(-) diff --git a/fast_depends/__about__.py b/fast_depends/__about__.py index 5b449fd..bc9eafa 100644 --- a/fast_depends/__about__.py +++ b/fast_depends/__about__.py @@ -1,3 +1,3 @@ """FastDepends - extracted and cleared from HTTP domain FastAPI Dependency Injection System""" -__version__ = "2.1.6" +__version__ = "2.1.7" diff --git a/fast_depends/_compat.py b/fast_depends/_compat.py index 82f5594..0290e81 100644 --- a/fast_depends/_compat.py +++ b/fast_depends/_compat.py @@ -4,17 +4,27 @@ PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") if PYDANTIC_V2: + from pydantic import ConfigDict from pydantic._internal._typing_extra import ( eval_type_lenient as evaluate_forwardref, ) from pydantic.fields import FieldInfo + + class CreateBaseModel(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + else: from pydantic.fields import ModelField as FieldInfo # type: ignore from pydantic.typing import evaluate_forwardref # type: ignore[no-redef] + class CreateBaseModel(BaseModel): + class Config: + arbitrary_types_allowed = True + __all__ = ( "BaseModel", + "CreateBaseModel", "FieldInfo", "create_model", "evaluate_forwardref", diff --git a/fast_depends/core/build.py b/fast_depends/core/build.py index 1458f66..286ca88 100644 --- a/fast_depends/core/build.py +++ b/fast_depends/core/build.py @@ -20,7 +20,7 @@ get_origin, ) -from fast_depends._compat import create_model +from fast_depends._compat import create_model, CreateBaseModel from fast_depends.core.model import CallModel from fast_depends.dependencies import Depends from fast_depends.library import CustomField @@ -155,11 +155,13 @@ def build_call_model( elif param.name not in ("args", "kwargs"): positional_args.append(param.name) - func_model = create_model(name, **class_fields) # type: ignore + func_model = create_model(name, __base__=CreateBaseModel, **class_fields) # type: ignore if cast and return_annotation and return_annotation is not inspect._empty: response_model = create_model( - "ResponseModel", response=(return_annotation, ...) + "ResponseModel", + __base__=CreateBaseModel, + response=(return_annotation, ...), ) else: response_model = None diff --git a/tests/async/test_cast.py b/tests/async/test_cast.py index f677022..8711926 100644 --- a/tests/async/test_cast.py +++ b/tests/async/test_cast.py @@ -26,6 +26,32 @@ async def some_func(a, b: int): assert isinstance(await some_func(1, "2"), int) +@pytest.mark.asyncio +async def test_arbitrary_args(): + class ArbitraryType: + def __init__(self): + self.value = "value" + + @inject + async def some_func(a: ArbitraryType): + return a + + assert isinstance(await some_func(ArbitraryType()), ArbitraryType) + + +@pytest.mark.asyncio +async def test_arbitrary_response(): + class ArbitraryType: + def __init__(self): + self.value = "value" + + @inject + async def some_func(a: ArbitraryType) -> ArbitraryType: + return a + + assert isinstance(await some_func(ArbitraryType()), ArbitraryType) + + @pytest.mark.asyncio async def test_types_casting(): @inject diff --git a/tests/sync/test_cast.py b/tests/sync/test_cast.py index f325e33..61e22ee 100644 --- a/tests/sync/test_cast.py +++ b/tests/sync/test_cast.py @@ -24,6 +24,30 @@ def some_func(a, b: int): assert isinstance(some_func(1, "2"), int) +async def test_arbitrary_args(): + class ArbitraryType: + def __init__(self): + self.value = "value" + + @inject + def some_func(a: ArbitraryType): + return a + + assert isinstance(await some_func(ArbitraryType()), ArbitraryType) + + +async def test_arbitrary_response(): + class ArbitraryType: + def __init__(self): + self.value = "value" + + @inject + def some_func(a: ArbitraryType) -> ArbitraryType: + return a + + assert isinstance(await some_func(ArbitraryType()), ArbitraryType) + + def test_validation_error(): @inject def some_func(a, b: str = Field(..., max_length=1)): # pragma: no cover