Skip to content

Commit

Permalink
Merge branch 'main' into dto-codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
provinzkraut committed Oct 3, 2023
2 parents 7ca5bb7 + df33c4f commit 62a5ff8
Show file tree
Hide file tree
Showing 19 changed files with 316 additions and 63 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
(PY_CLASS, "NoneType"),
(PY_CLASS, "litestar._openapi.schema_generation.schema.SchemaCreator"),
(PY_CLASS, "litestar._signature.model.SignatureModel"),
(PY_CLASS, "litestar.contrib.sqlalchemy.plugins.init.config.compat._CreateEngineMixin"),
(PY_CLASS, "litestar.utils.signature.ParsedSignature"),
(PY_CLASS, "litestar.utils.sync.AsyncCallable"),
# types in changelog that no longer exist
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ Renaming the dependencies
#########################

You can change the name that the engine and session are bound to by setting the
:attr:`engine_dependency_key` and :attr:`session_dependency_key`
:attr:`engine_dependency_key <advanced_alchemy.extensions.litestar.plugins.init.config.asyncio.SQLAlchemyAsyncConfig.engine_dependency_key>`
and :attr:`session_dependency_key <advanced_alchemy.extensions.litestar.plugins.init.config.asyncio.SQLAlchemyAsyncConfig.session_dependency_key>`
attributes on the plugin configuration.

Configuring the before send handler
Expand All @@ -49,9 +50,9 @@ The plugin configures a ``before_send`` handler that is called before sending a
session and removes it from the connection scope.

You can change the handler by setting the
:attr:`before_send_handler` attribute
on the configuration object. For example, an alternate handler is available that will also commit the session on success
and rollback upon failure.
:attr:`before_send_handler <advanced_alchemy.extensions.litestar.plugins.init.config.asyncio.SQLAlchemyAsyncConfig.before_send_handler>`
attribute on the configuration object. For example, an alternate handler is available that will also commit the session
on success and rollback upon failure.

.. tab-set::

Expand Down
6 changes: 6 additions & 0 deletions litestar/_asgi/routing_trie/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ def build_route_middleware_stack(
from litestar.middleware.allowed_hosts import AllowedHostsMiddleware
from litestar.middleware.compression import CompressionMiddleware
from litestar.middleware.csrf import CSRFMiddleware
from litestar.middleware.response_cache import ResponseCacheMiddleware
from litestar.routes import HTTPRoute

# we wrap the route.handle method in the ExceptionHandlerMiddleware
asgi_handler = wrap_in_exception_handler(
Expand All @@ -197,6 +199,10 @@ def build_route_middleware_stack(

if app.compression_config:
asgi_handler = CompressionMiddleware(app=asgi_handler, config=app.compression_config)

if isinstance(route, HTTPRoute) and any(r.cache for r in route.route_handlers):
asgi_handler = ResponseCacheMiddleware(app=asgi_handler, config=app.response_cache_config)

if app.allowed_hosts:
asgi_handler = AllowedHostsMiddleware(app=asgi_handler, config=app.allowed_hosts)

Expand Down
24 changes: 17 additions & 7 deletions litestar/_parsers.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
from __future__ import annotations

from collections import defaultdict
from functools import lru_cache
from http.cookies import _unquote as unquote_cookie
from typing import Any, Iterable
from typing import Iterable
from urllib.parse import unquote

from fast_query_parsers import parse_query_string as fast_parse_query_string
from fast_query_parsers import parse_url_encoded_dict
try:
from fast_query_parsers import parse_query_string as parse_qsl
except ImportError:
from urllib.parse import parse_qsl as _parse_qsl

def parse_qsl(qs: bytes, separator: str) -> list[tuple[str, str]]:
return _parse_qsl(qs.decode("latin-1"), keep_blank_values=True, separator=separator)


__all__ = ("parse_cookie_string", "parse_headers", "parse_query_string", "parse_url_encoded_form_data")


@lru_cache(1024)
def parse_url_encoded_form_data(encoded_data: bytes) -> dict[str, Any]:
def parse_url_encoded_form_data(encoded_data: bytes) -> dict[str, str | list[str]]:
"""Parse an url encoded form data dict.
Args:
Expand All @@ -21,11 +28,14 @@ def parse_url_encoded_form_data(encoded_data: bytes) -> dict[str, Any]:
Returns:
A parsed dict.
"""
return parse_url_encoded_dict(qs=encoded_data, parse_numbers=False)
decoded_dict: defaultdict[str, list[str]] = defaultdict(list)
for k, v in parse_qsl(encoded_data, separator="&"):
decoded_dict[k].append(v)
return {k: v if len(v) > 1 else v[0] for k, v in decoded_dict.items()}


@lru_cache(1024)
def parse_query_string(query_string: bytes) -> tuple[tuple[str, Any], ...]:
def parse_query_string(query_string: bytes) -> tuple[tuple[str, str], ...]:
"""Parse a query string into a tuple of key value pairs.
Args:
Expand All @@ -34,7 +44,7 @@ def parse_query_string(query_string: bytes) -> tuple[tuple[str, Any], ...]:
Returns:
A tuple of key value pairs.
"""
return tuple(fast_parse_query_string(query_string, "&"))
return tuple(parse_qsl(query_string, separator="&"))


@lru_cache(1024)
Expand Down
1 change: 1 addition & 0 deletions litestar/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
SCOPE_STATE_DEPENDENCY_CACHE: Final = "dependency_cache"
SCOPE_STATE_NAMESPACE: Final = "__litestar__"
SCOPE_STATE_RESPONSE_COMPRESSED: Final = "response_compressed"
SCOPE_STATE_IS_CACHED: Final = "is_cached"
SKIP_VALIDATION_NAMES: Final = {"request", "socket", "scope", "receive", "send"}
UNDEFINED_SENTINELS: Final = {Signature.empty, Empty, Ellipsis, MISSING, UnsetType}
WEBSOCKET_CLOSE: Final = "websocket.close"
Expand Down
11 changes: 10 additions & 1 deletion litestar/contrib/sqlalchemy/plugins/init/config/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@

from advanced_alchemy.config.asyncio import AlembicAsyncConfig, AsyncSessionConfig
from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import (
SQLAlchemyAsyncConfig,
SQLAlchemyAsyncConfig as _SQLAlchemyAsyncConfig,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import (
autocommit_before_send_handler,
default_before_send_handler,
)
from sqlalchemy.ext.asyncio import AsyncEngine

from litestar.contrib.sqlalchemy.plugins.init.config.compat import _CreateEngineMixin

__all__ = (
"SQLAlchemyAsyncConfig",
Expand All @@ -14,3 +19,7 @@
"default_before_send_handler",
"autocommit_before_send_handler",
)


class SQLAlchemyAsyncConfig(_SQLAlchemyAsyncConfig, _CreateEngineMixin[AsyncEngine]):
...
23 changes: 23 additions & 0 deletions litestar/contrib/sqlalchemy/plugins/init/config/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Generic, Protocol, TypeVar

from litestar.utils.deprecation import deprecated

if TYPE_CHECKING:
from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import AsyncEngine


EngineT_co = TypeVar("EngineT_co", bound="Engine | AsyncEngine", covariant=True)


class HasGetEngine(Protocol[EngineT_co]):
def get_engine(self) -> EngineT_co:
...


class _CreateEngineMixin(Generic[EngineT_co]):
@deprecated(version="2.1.1", removal_in="3.0.0", alternative="get_engine()")
def create_engine(self: HasGetEngine[EngineT_co]) -> EngineT_co:
return self.get_engine()
11 changes: 10 additions & 1 deletion litestar/contrib/sqlalchemy/plugins/init/config/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@

from advanced_alchemy.config.sync import AlembicSyncConfig, SyncSessionConfig
from advanced_alchemy.extensions.litestar.plugins.init.config.sync import (
SQLAlchemySyncConfig,
SQLAlchemySyncConfig as _SQLAlchemySyncConfig,
)
from advanced_alchemy.extensions.litestar.plugins.init.config.sync import (
autocommit_before_send_handler,
default_before_send_handler,
)
from sqlalchemy import Engine

from litestar.contrib.sqlalchemy.plugins.init.config.compat import _CreateEngineMixin

__all__ = (
"SQLAlchemySyncConfig",
Expand All @@ -14,3 +19,7 @@
"default_before_send_handler",
"autocommit_before_send_handler",
)


class SQLAlchemySyncConfig(_SQLAlchemySyncConfig, _CreateEngineMixin[Engine]):
...
11 changes: 9 additions & 2 deletions litestar/middleware/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from io import BytesIO
from typing import TYPE_CHECKING, Any, Literal, Optional

from litestar.constants import SCOPE_STATE_RESPONSE_COMPRESSED
from litestar.constants import SCOPE_STATE_IS_CACHED, SCOPE_STATE_RESPONSE_COMPRESSED
from litestar.datastructures import Headers, MutableScopeHeaders
from litestar.enums import CompressionEncoding, ScopeType
from litestar.exceptions import MissingDependencyException
from litestar.middleware.base import AbstractMiddleware
from litestar.utils import Ref, set_litestar_scope_state
from litestar.utils import Ref, get_litestar_scope_state, set_litestar_scope_state

__all__ = ("CompressionFacade", "CompressionMiddleware")

Expand Down Expand Up @@ -176,6 +176,8 @@ def create_compression_send_wrapper(
initial_message = Ref[Optional["HTTPResponseStartEvent"]](None)
started = Ref[bool](False)

_own_encoding = compression_encoding.encode("latin-1")

async def send_wrapper(message: Message) -> None:
"""Handle and compresses the HTTP Message with brotli.
Expand All @@ -187,6 +189,11 @@ async def send_wrapper(message: Message) -> None:
initial_message.value = message
return

if initial_message.value and get_litestar_scope_state(scope, SCOPE_STATE_IS_CACHED):
await send(initial_message.value)
await send(message)
return

if initial_message.value and message["type"] == "http.response.body":
body = message["body"]
more_body = message.get("more_body")
Expand Down
48 changes: 48 additions & 0 deletions litestar/middleware/response_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from __future__ import annotations

from msgspec.msgpack import encode as encode_msgpack

from litestar.enums import ScopeType
from litestar.utils import get_litestar_scope_state

from .base import AbstractMiddleware

__all__ = ["ResponseCacheMiddleware"]

from typing import TYPE_CHECKING, cast

from litestar import Request
from litestar.constants import SCOPE_STATE_IS_CACHED

if TYPE_CHECKING:
from litestar.config.response_cache import ResponseCacheConfig
from litestar.handlers import HTTPRouteHandler
from litestar.types import ASGIApp, Message, Receive, Scope, Send


class ResponseCacheMiddleware(AbstractMiddleware):
def __init__(self, app: ASGIApp, config: ResponseCacheConfig) -> None:
self.config = config
super().__init__(app=app, scopes={ScopeType.HTTP})

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
route_handler = cast("HTTPRouteHandler", scope["route_handler"])
store = self.config.get_store_from_app(scope["app"])

expires_in: int | None = None
if route_handler.cache is True:
expires_in = self.config.default_expiration
elif route_handler.cache is not False and isinstance(route_handler.cache, int):
expires_in = route_handler.cache

messages = []

async def wrapped_send(message: Message) -> None:
if not get_litestar_scope_state(scope, SCOPE_STATE_IS_CACHED):
messages.append(message)
if message["type"] == "http.response.body" and not message["more_body"]:
key = (route_handler.cache_key_builder or self.config.key_builder)(Request(scope))
await store.set(key, encode_msgpack(messages), expires_in=expires_in)
await send(message)

await self.app(scope, receive, wrapped_send)
48 changes: 15 additions & 33 deletions litestar/routes/http.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import pickle
from itertools import chain
from typing import TYPE_CHECKING, Any, cast

from litestar.constants import DEFAULT_ALLOWED_CORS_HEADERS
from msgspec.msgpack import decode as _decode_msgpack_plain

from litestar.constants import DEFAULT_ALLOWED_CORS_HEADERS, SCOPE_STATE_IS_CACHED
from litestar.datastructures.headers import Headers
from litestar.datastructures.upload_file import UploadFile
from litestar.enums import HttpMethod, MediaType, ScopeType
Expand All @@ -13,6 +14,7 @@
from litestar.response import Response
from litestar.routes.base import BaseRoute
from litestar.status_codes import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
from litestar.utils import set_litestar_scope_state

if TYPE_CHECKING:
from litestar._kwargs import KwargsModel
Expand Down Expand Up @@ -128,19 +130,10 @@ async def _get_response_for_request(
):
return response

response = await self._call_handler_function(
return await self._call_handler_function(
scope=scope, request=request, parameter_model=parameter_model, route_handler=route_handler
)

if route_handler.cache:
await self._set_cached_response(
response=response,
request=request,
route_handler=route_handler,
)

return response

async def _call_handler_function(
self, scope: Scope, request: Request, parameter_model: KwargsModel, route_handler: HTTPRouteHandler
) -> ASGIApp:
Expand Down Expand Up @@ -225,30 +218,19 @@ async def _get_cached_response(request: Request, route_handler: HTTPRouteHandler
cache_key = (route_handler.cache_key_builder or cache_config.key_builder)(request)
store = cache_config.get_store_from_app(request.app)

cached_response = await store.get(key=cache_key)

if cached_response:
return cast("ASGIApp", pickle.loads(cached_response)) # noqa: S301
if not (cached_response_data := await store.get(key=cache_key)):
return None

return None
# we use the regular msgspec.msgpack.decode here since we don't need any of
# the added decoders
messages = _decode_msgpack_plain(cached_response_data)

@staticmethod
async def _set_cached_response(
response: Response | ASGIApp, request: Request, route_handler: HTTPRouteHandler
) -> None:
"""Pickles and caches a response object."""
cache_config = request.app.response_cache_config
cache_key = (route_handler.cache_key_builder or cache_config.key_builder)(request)

expires_in: int | None = None
if route_handler.cache is True:
expires_in = cache_config.default_expiration
elif route_handler.cache is not False and isinstance(route_handler.cache, int):
expires_in = route_handler.cache

store = cache_config.get_store_from_app(request.app)
async def cached_response(scope: Scope, receive: Receive, send: Send) -> None:
set_litestar_scope_state(scope, SCOPE_STATE_IS_CACHED, True)
for message in messages:
await send(message)

await store.set(key=cache_key, value=pickle.dumps(response, pickle.HIGHEST_PROTOCOL), expires_in=expires_in)
return cached_response

def create_options_handler(self, path: str) -> HTTPRouteHandler:
"""Args:
Expand Down
Loading

0 comments on commit 62a5ff8

Please sign in to comment.