diff --git a/docs/conf.py b/docs/conf.py index fd3ea919cb..e6f1b9617c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -115,6 +115,7 @@ (PY_CLASS, "NoneType"), (PY_CLASS, "litestar._openapi.schema_generation.schema.SchemaCreator"), (PY_CLASS, "litestar._signature.model.SignatureModel"), + (PY_CLASS, "litestar.contrib.sqlalchemy.plugins.init.config.compat._CreateEngineMixin"), (PY_CLASS, "litestar.utils.signature.ParsedSignature"), (PY_CLASS, "litestar.utils.sync.AsyncCallable"), # types in changelog that no longer exist diff --git a/docs/usage/databases/sqlalchemy/plugins/sqlalchemy_init_plugin.rst b/docs/usage/databases/sqlalchemy/plugins/sqlalchemy_init_plugin.rst index 158a147265..2ebd069f6a 100644 --- a/docs/usage/databases/sqlalchemy/plugins/sqlalchemy_init_plugin.rst +++ b/docs/usage/databases/sqlalchemy/plugins/sqlalchemy_init_plugin.rst @@ -39,7 +39,8 @@ Renaming the dependencies ######################### You can change the name that the engine and session are bound to by setting the -:attr:`engine_dependency_key` and :attr:`session_dependency_key` +:attr:`engine_dependency_key ` +and :attr:`session_dependency_key ` attributes on the plugin configuration. Configuring the before send handler @@ -49,9 +50,9 @@ The plugin configures a ``before_send`` handler that is called before sending a session and removes it from the connection scope. You can change the handler by setting the -:attr:`before_send_handler` attribute -on the configuration object. For example, an alternate handler is available that will also commit the session on success -and rollback upon failure. +:attr:`before_send_handler ` +attribute on the configuration object. For example, an alternate handler is available that will also commit the session +on success and rollback upon failure. .. tab-set:: diff --git a/litestar/contrib/sqlalchemy/plugins/init/config/asyncio.py b/litestar/contrib/sqlalchemy/plugins/init/config/asyncio.py index 528b0e3c91..4f50e2bb71 100644 --- a/litestar/contrib/sqlalchemy/plugins/init/config/asyncio.py +++ b/litestar/contrib/sqlalchemy/plugins/init/config/asyncio.py @@ -2,10 +2,15 @@ from advanced_alchemy.config.asyncio import AlembicAsyncConfig, AsyncSessionConfig from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import ( - SQLAlchemyAsyncConfig, + SQLAlchemyAsyncConfig as _SQLAlchemyAsyncConfig, +) +from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import ( autocommit_before_send_handler, default_before_send_handler, ) +from sqlalchemy.ext.asyncio import AsyncEngine + +from litestar.contrib.sqlalchemy.plugins.init.config.compat import _CreateEngineMixin __all__ = ( "SQLAlchemyAsyncConfig", @@ -14,3 +19,7 @@ "default_before_send_handler", "autocommit_before_send_handler", ) + + +class SQLAlchemyAsyncConfig(_SQLAlchemyAsyncConfig, _CreateEngineMixin[AsyncEngine]): + ... diff --git a/litestar/contrib/sqlalchemy/plugins/init/config/compat.py b/litestar/contrib/sqlalchemy/plugins/init/config/compat.py new file mode 100644 index 0000000000..af6c92bd2f --- /dev/null +++ b/litestar/contrib/sqlalchemy/plugins/init/config/compat.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar + +from litestar.utils.deprecation import deprecated + +if TYPE_CHECKING: + from sqlalchemy import Engine + from sqlalchemy.ext.asyncio import AsyncEngine + + +EngineT_co = TypeVar("EngineT_co", bound="Engine | AsyncEngine", covariant=True) + + +class HasGetEngine(Protocol[EngineT_co]): + def get_engine(self) -> EngineT_co: + ... + + +class _CreateEngineMixin(Generic[EngineT_co]): + @deprecated(version="2.1.1", removal_in="3.0.0", alternative="get_engine()") + def create_engine(self: HasGetEngine[EngineT_co]) -> EngineT_co: + return self.get_engine() diff --git a/litestar/contrib/sqlalchemy/plugins/init/config/sync.py b/litestar/contrib/sqlalchemy/plugins/init/config/sync.py index f033638817..a7839fb62c 100644 --- a/litestar/contrib/sqlalchemy/plugins/init/config/sync.py +++ b/litestar/contrib/sqlalchemy/plugins/init/config/sync.py @@ -2,10 +2,15 @@ from advanced_alchemy.config.sync import AlembicSyncConfig, SyncSessionConfig from advanced_alchemy.extensions.litestar.plugins.init.config.sync import ( - SQLAlchemySyncConfig, + SQLAlchemySyncConfig as _SQLAlchemySyncConfig, +) +from advanced_alchemy.extensions.litestar.plugins.init.config.sync import ( autocommit_before_send_handler, default_before_send_handler, ) +from sqlalchemy import Engine + +from litestar.contrib.sqlalchemy.plugins.init.config.compat import _CreateEngineMixin __all__ = ( "SQLAlchemySyncConfig", @@ -14,3 +19,7 @@ "default_before_send_handler", "autocommit_before_send_handler", ) + + +class SQLAlchemySyncConfig(_SQLAlchemySyncConfig, _CreateEngineMixin[Engine]): + ... diff --git a/tests/unit/test_contrib/test_sqlalchemy/test_init_plugin/test_config/__init__.py b/tests/unit/test_contrib/test_sqlalchemy/test_init_plugin/test_config/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/test_contrib/test_sqlalchemy/test_init_plugin/test_config/test_asyncio.py b/tests/unit/test_contrib/test_sqlalchemy/test_init_plugin/test_config/test_asyncio.py new file mode 100644 index 0000000000..e02879d25c --- /dev/null +++ b/tests/unit/test_contrib/test_sqlalchemy/test_init_plugin/test_config/test_asyncio.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +import pytest +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine + +from litestar.contrib.sqlalchemy.plugins.init.config.asyncio import SQLAlchemyAsyncConfig + + +def test_create_engine_with_engine_instance() -> None: + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + config = SQLAlchemyAsyncConfig(engine_instance=engine) + with pytest.deprecated_call(): + assert engine is config.create_engine() + + +def test_create_engine_with_connection_string() -> None: + config = SQLAlchemyAsyncConfig(connection_string="sqlite+aiosqlite:///:memory:") + with pytest.deprecated_call(): + engine = config.create_engine() + assert isinstance(engine, AsyncEngine) diff --git a/tests/unit/test_contrib/test_sqlalchemy/test_init_plugin/test_config/test_sync.py b/tests/unit/test_contrib/test_sqlalchemy/test_init_plugin/test_config/test_sync.py new file mode 100644 index 0000000000..6d58cb5240 --- /dev/null +++ b/tests/unit/test_contrib/test_sqlalchemy/test_init_plugin/test_config/test_sync.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +import pytest +from sqlalchemy import Engine, create_engine + +from litestar.contrib.sqlalchemy.plugins.init.config.sync import SQLAlchemySyncConfig + + +def test_create_engine_with_engine_instance() -> None: + engine = create_engine("sqlite:///:memory:") + config = SQLAlchemySyncConfig(engine_instance=engine) + with pytest.deprecated_call(): + assert engine is config.create_engine() + + +def test_create_engine_with_connection_string() -> None: + config = SQLAlchemySyncConfig(connection_string="sqlite:///:memory:") + with pytest.deprecated_call(): + engine = config.create_engine() + assert isinstance(engine, Engine)