Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Piccolo ORM DTO #1896

Merged
merged 9 commits into from
Jul 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ jobs:
run: echo "PYTHONPATH=$PWD" >> $GITHUB_ENV
- name: Test
if: ${{ !inputs.coverage }}
run: |
source $VENV
pytest
run: poetry run pytest docs/examples tests
- name: Test with coverage
if: inputs.coverage
run: poetry run pytest docs/examples tests --cov=litestar --cov-report=xml
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ def get_handler(state: State, request: Request, dep: Any) -> None:
logger.info("state value in handler from `Request`: %s", request.app.state.value)


app = Litestar(route_handlers=[get_handler], on_startup=[set_state_on_startup], debug=True)
app = Litestar(route_handlers=[get_handler], on_startup=[set_state_on_startup])
2 changes: 2 additions & 0 deletions docs/examples/contrib/jwt/using_jwt_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ async def retrieve_user_handler(token: Token, connection: "ASGIConnection[Any, A
@post("/login")
async def login_handler(data: User) -> Response[User]:
MOCK_DB[str(data.id)] = data
# you can do whatever you want to update the response instance here
# e.g. response.set_cookie(...)
return jwt_auth.login(identifier=str(data.id), response_body=data)


Expand Down
6 changes: 6 additions & 0 deletions docs/examples/contrib/jwt/using_oauth2_password_bearer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,19 @@ async def retrieve_user_handler(token: "Token", connection: "ASGIConnection[Any,
@post("/login")
async def login_handler(request: "Request[Any, Any, Any]", data: "User") -> "Response[OAuth2Login]":
MOCK_DB[str(data.id)] = data
# if we do not define a response body, the login process will return a standard OAuth2 login response. Note the `Response[OAuth2Login]` return type.

# you can do whatever you want to update the response instance here
# e.g. response.set_cookie(...)
return oauth2_auth.login(identifier=str(data.id))


@post("/login_custom")
async def login_custom_response_handler(data: "User") -> "Response[User]":
MOCK_DB[str(data.id)] = data

# you can do whatever you want to update the response instance here
# e.g. response.set_cookie(...)
return oauth2_auth.login(identifier=str(data.id), response_body=data)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ async def handler(db_session: AsyncSession, db_engine: AsyncEngine) -> Tuple[int

config = SQLAlchemyAsyncConfig(connection_string="sqlite+aiosqlite:///async.sqlite")
plugin = SQLAlchemyInitPlugin(config=config)
app = Litestar(route_handlers=[handler], plugins=[plugin], debug=True)
app = Litestar(route_handlers=[handler], plugins=[plugin])
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ def handler(db_session: Session, db_engine: Engine) -> Tuple[int, int]:

config = SQLAlchemySyncConfig(connection_string="sqlite:///sync.sqlite")
plugin = SQLAlchemyInitPlugin(config=config)
app = Litestar(route_handlers=[handler], plugins=[plugin], debug=True)
app = Litestar(route_handlers=[handler], plugins=[plugin])
Original file line number Diff line number Diff line change
Expand Up @@ -210,5 +210,4 @@ async def on_startup() -> None:
on_startup=[on_startup],
plugins=[SQLAlchemyInitPlugin(config=sqlalchemy_config)],
dependencies={"limit_offset": Provide(provide_limit_offset_pagination)},
debug=True,
)
Original file line number Diff line number Diff line change
Expand Up @@ -212,5 +212,4 @@ def on_startup() -> None:
on_startup=[on_startup],
plugins=[SQLAlchemyInitPlugin(config=sqlalchemy_config)],
dependencies={"limit_offset": Provide(provide_limit_offset_pagination)},
debug=True,
)
1 change: 0 additions & 1 deletion docs/examples/routing/mounting_starlette_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ async def index(request: "Request") -> JSONResponse:

starlette_app = asgi(path="/some/sub-path", is_mount=True)(
Starlette(
debug=True,
routes=[
Route("/", index),
Route("/abc/", index),
Expand Down
49 changes: 30 additions & 19 deletions litestar/_openapi/schema_generation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
if TYPE_CHECKING:
from msgspec import Struct

from litestar.dto.types import ForType
from litestar.plugins import OpenAPISchemaPluginProtocol

try:
Expand Down Expand Up @@ -278,16 +279,20 @@
}


def _get_type_schema_name(value: Any) -> str:
def _get_type_schema_name(value: Any, dto_for: ForType | None) -> str:
"""Extract the schema name from a data container.

Args:
value: A data container
dto_for: The type of DTO to create the schema for.

Returns:
A string
"""
return cast("str", getattr(value, "__schema_name__", value.__name__))
name = cast("str", getattr(value, "__schema_name__", value.__name__))
if dto_for == "data":
return f"{name}RequestBody"
return f"{name}ResponseBody" if dto_for == "return" else name


def create_enum_schema(annotation: EnumMeta) -> Schema:
Expand Down Expand Up @@ -358,7 +363,7 @@ def create_schema_for_annotation(annotation: Any) -> Schema:


class SchemaCreator:
__slots__ = ("generate_examples", "plugins", "schemas", "prefer_alias")
__slots__ = ("generate_examples", "plugins", "schemas", "prefer_alias", "dto_for")

def __init__(
self,
Expand Down Expand Up @@ -389,11 +394,12 @@ def not_generating_examples(self) -> SchemaCreator:
new.generate_examples = False
return new

def for_field(self, field: SignatureField) -> Schema | Reference:
def for_field(self, field: SignatureField, dto_for: ForType | None = None) -> Schema | Reference:
"""Create a Schema for a given SignatureField.

Args:
field: A signature field instance.
dto_for: The type of DTO to create the schema for.

Returns:
A schema instance.
Expand All @@ -404,15 +410,15 @@ def for_field(self, field: SignatureField) -> Schema | Reference:
elif field.is_union:
result = self.for_union_field(field)
elif is_pydantic_model_class(field.field_type):
result = self.for_pydantic_model(field.field_type)
result = self.for_pydantic_model(field.field_type, dto_for)
elif is_attrs_class(field.field_type):
result = self.for_attrs_class(field.field_type)
result = self.for_attrs_class(field.field_type, dto_for)
elif is_struct_class(field.field_type):
result = self.for_struct_class(field.field_type)
result = self.for_struct_class(field.field_type, dto_for)
elif is_dataclass_class(field.field_type):
result = self.for_dataclass(field.field_type)
result = self.for_dataclass(field.field_type, dto_for)
elif is_typed_dict(field.field_type):
result = self.for_typed_dict(field.field_type)
result = self.for_typed_dict(field.field_type, dto_for)
elif plugins_for_annotation := [
plugin for plugin in self.plugins if plugin.is_plugin_supported_type(field.field_type)
]:
Expand Down Expand Up @@ -575,11 +581,12 @@ def for_plugin(self, field: SignatureField, plugin: OpenAPISchemaPluginProtocol)
)
return schema # pragma: no cover

def for_pydantic_model(self, field_type: type[BaseModel]) -> Schema:
def for_pydantic_model(self, field_type: type[BaseModel], dto_for: ForType | None) -> Schema:
"""Create a schema object for a given pydantic model class.

Args:
field_type: A pydantic model class.
dto_for: The type of DTO to generate a schema for.

Returns:
A schema instance.
Expand All @@ -604,15 +611,16 @@ def for_pydantic_model(self, field_type: type[BaseModel]) -> Schema:
for f in field_type.__fields__.values()
},
type=OpenAPIType.OBJECT,
title=title or _get_type_schema_name(field_type),
title=title or _get_type_schema_name(field_type, dto_for),
examples=[Example(example)] if example else None,
)

def for_attrs_class(self, field_type: type[AttrsInstance]) -> Schema:
def for_attrs_class(self, field_type: type[AttrsInstance], dto_for: ForType | None) -> Schema:
"""Create a schema object for a given attrs class.

Args:
field_type: An attrs class.
dto_for: The type of DTO to generate a schema for.

Returns:
A schema instance.
Expand All @@ -631,14 +639,15 @@ def for_attrs_class(self, field_type: type[AttrsInstance]) -> Schema:
),
properties={k: self.for_field(SignatureField.create(v, k)) for k, v in field_type_hints.items()},
type=OpenAPIType.OBJECT,
title=_get_type_schema_name(field_type),
title=_get_type_schema_name(field_type, dto_for),
)

def for_struct_class(self, field_type: type[Struct]) -> Schema:
def for_struct_class(self, field_type: type[Struct], dto_for: ForType | None) -> Schema:
"""Create a schema object for a given msgspec.Struct class.

Args:
field_type: A msgspec.Struct class.
dto_for: The type of DTO to generate a schema for.

Returns:
A schema instance.
Expand All @@ -656,14 +665,15 @@ def for_struct_class(self, field_type: type[Struct]) -> Schema:
for field in msgspec_struct_fields(field_type)
},
type=OpenAPIType.OBJECT,
title=_get_type_schema_name(field_type),
title=_get_type_schema_name(field_type, dto_for),
)

def for_dataclass(self, field_type: type[DataclassProtocol]) -> Schema:
def for_dataclass(self, field_type: type[DataclassProtocol], dto_for: ForType | None) -> Schema:
"""Create a schema object for a given dataclass class.

Args:
field_type: A dataclass class.
dto_for: The type of DTO to generate a schema for.

Returns:
A schema instance.
Expand All @@ -683,14 +693,15 @@ def for_dataclass(self, field_type: type[DataclassProtocol]) -> Schema:
),
properties={k: self.for_field(SignatureField.create(v, k)) for k, v in field_type_hints.items()},
type=OpenAPIType.OBJECT,
title=_get_type_schema_name(field_type),
title=_get_type_schema_name(field_type, dto_for),
)

def for_typed_dict(self, field_type: TypedDictClass) -> Schema:
def for_typed_dict(self, field_type: TypedDictClass, dto_for: ForType | None) -> Schema:
"""Create a schema object for a given typed dict.

Args:
field_type: A typed-dict class.
dto_for: The type of DTO to generate a schema for.

Returns:
A schema instance.
Expand All @@ -703,7 +714,7 @@ def for_typed_dict(self, field_type: TypedDictClass) -> Schema:
required=sorted(getattr(field_type, "__required_keys__", [])),
properties={k: self.for_field(SignatureField.create(v, k)) for k, v in annotations.items()},
type=OpenAPIType.OBJECT,
title=_get_type_schema_name(field_type),
title=_get_type_schema_name(field_type, dto_for),
)

def for_constrained_field(self, field: SignatureField) -> Schema:
Expand Down
7 changes: 5 additions & 2 deletions litestar/channels/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,10 @@ def encode_data(self, data: LitestarEncodableType) -> bytes:
if isinstance(data, bytes):
return data

return data.encode() if isinstance(data, str) else self._encode_json(data)
if isinstance(data, str):
return data.encode()

return self._encode_json(data)

def on_app_init(self, app_config: AppConfig) -> AppConfig:
"""Plugin hook. Set up a ``channels`` dependency, add route handlers and register application hooks"""
Expand Down Expand Up @@ -229,7 +232,7 @@ async def unsubscribe(self, subscriber: Subscriber, channels: str | Iterable[str
if not channel_subscribers:
channels_to_unsubscribe.add(channel)

if all(subscriber not in queues for queues in self._channels.values()):
if not any(subscriber in queues for queues in self._channels.values()):
await subscriber.put(None) # this will stop any running task or generator by breaking the inner loop
if subscriber.is_running:
await subscriber.stop()
Expand Down
Empty file.
98 changes: 98 additions & 0 deletions litestar/contrib/piccolo/dto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from __future__ import annotations

# fix: import Decimal from the public ``decimal`` module, not the private
# ``_decimal`` C accelerator module (an implementation detail of CPython).
from decimal import Decimal
from typing import Any, Generator, Generic, Optional, TypeVar

from msgspec import Meta
from typing_extensions import Annotated

from litestar.dto.factory.abc import AbstractDTOFactory
from litestar.dto.factory.data_structures import FieldDefinition
from litestar.dto.factory.field import DTOField, Mark
from litestar.exceptions import MissingDependencyException
from litestar.types import Empty
from litestar.typing import ParsedType
from litestar.utils.helpers import get_fully_qualified_class_name

try:
    import piccolo  # noqa: F401
except ImportError as e:
    raise MissingDependencyException("piccolo orm is not installed") from e

from piccolo.columns import Column, column_types
from piccolo.table import Table

# Type variable bound to piccolo's Table so PiccoloDTO is generic over the model type.
T = TypeVar("T", bound=Table)

__all__ = ("PiccoloDTO",)


def _parse_piccolo_type(column: Column, extra: dict[str, Any]) -> ParsedType:
    """Translate a piccolo ``Column`` into a ``ParsedType`` carrying msgspec metadata.

    Args:
        column: The piccolo column to translate.
        extra: Extra metadata (help text, choices, alias, FK info) to attach to the type.

    Returns:
        A ``ParsedType`` wrapping ``Annotated[<python type>, Meta(...)]``; the type is
        made ``Optional`` when the column is not required.
    """
    if isinstance(column, (column_types.Decimal, column_types.Numeric)):
        column_type: Any = Decimal
        meta = Meta(extra=extra)
    elif isinstance(column, (column_types.Email, column_types.Varchar)):
        column_type = str
        meta = Meta(max_length=column.length, extra=extra)
    elif isinstance(column, column_types.Array):
        # NOTE(review): runtime subscription of builtin ``list`` requires Python >= 3.9;
        # confirm against the project's minimum supported version.
        column_type = list[column.base_column.value_type]  # type: ignore
        meta = Meta(extra=extra)
    elif isinstance(column, (column_types.JSON, column_types.JSONB)):
        column_type = str
        meta = Meta(extra={**extra, "format": "json"})
    elif isinstance(column, column_types.Text):
        column_type = str
        meta = Meta(extra={**extra, "format": "text-area"})
    elif isinstance(column, column_types.Secret):
        column_type = str
        # fix: merge the caller-supplied ``extra`` instead of discarding it —
        # previously help_text/choices/alias/FK metadata were lost for Secret
        # columns, inconsistent with the JSON/Text branches above.
        meta = Meta(extra={**extra, "secret": True})
    else:
        column_type = column.value_type
        meta = Meta(extra=extra)

    if not column._meta.required:
        column_type = Optional[column_type]

    return ParsedType(Annotated[column_type, meta])


def _create_column_extra(column: Column) -> dict[str, Any]:
    """Collect schema/UI metadata for a piccolo column into an ``extra`` dict.

    Args:
        column: The piccolo column to inspect.

    Returns:
        A dict with any of ``help_text``, ``choices``, ``alias`` and, for
        foreign keys, ``foreign_key``/``to``/``target_column`` entries.
    """
    meta = column._meta
    extra: dict[str, Any] = {}

    if meta.help_text:
        extra["help_text"] = meta.help_text

    if choices := meta.get_choices_dict():
        extra["choices"] = choices

    # Expose the real DB column name when it differs from the attribute name.
    if meta.db_column_name != meta.name:
        extra["alias"] = meta.db_column_name

    if isinstance(column, column_types.ForeignKey):
        fk_meta = column._foreign_key_meta
        extra["foreign_key"] = True
        extra["to"] = fk_meta.resolved_references._meta.tablename
        extra["target_column"] = fk_meta.resolved_target_column._meta.name

    return extra


class PiccoloDTO(AbstractDTOFactory[T], Generic[T]):
    """DTO factory for piccolo ORM ``Table`` models."""

    @classmethod
    def generate_field_definitions(cls, model_type: type[Table]) -> Generator[FieldDefinition, None, None]:
        """Yield a ``FieldDefinition`` for each column declared on ``model_type``."""
        model_name = get_fully_qualified_class_name(model_type)

        for column in model_type._meta.columns:
            column_meta = column._meta
            # Primary keys are marked read-only so request data cannot set them.
            # TODO: is there a better way of handling this?
            mark = Mark.READ_ONLY if column_meta.primary_key else None
            yield FieldDefinition(
                name=column_meta.name,
                default=Empty if column_meta.required else None,
                default_factory=Empty,
                dto_field=DTOField(mark=mark),
                dto_for=None,
                parsed_type=_parse_piccolo_type(column, _create_column_extra(column)),
                unique_model_name=model_name,
            )

    @classmethod
    def detect_nested_field(cls, parsed_type: ParsedType) -> bool:
        """A field is nested when its type is itself a piccolo ``Table`` subclass."""
        return parsed_type.is_subclass_of(Table)
6 changes: 3 additions & 3 deletions litestar/contrib/sqlalchemy/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ def _(

return [
FieldDefinition(
name=key,
default=default,
parsed_type=parsed_type,
default_factory=default_factory,
dto_field=elem.info.get(DTO_FIELD_META_KEY, DTOField()),
unique_model_name=model_name,
dto_for=None,
name=key,
parsed_type=parsed_type,
unique_model_name=model_name,
)
]

Expand Down
1 change: 1 addition & 0 deletions litestar/datastructures/url.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def from_scope(cls, scope: Scope) -> URL:
path = scope.get("root_path", "") + scope["path"]
query_string = scope.get("query_string", b"")

# we use iteration here because it's faster, and headers might not yet be cached
host = next(
(
header_value.decode("latin-1")
Expand Down
Loading