diff --git a/litestar/contrib/sqlalchemy/repository/_async.py b/litestar/contrib/sqlalchemy/repository/_async.py index 49bf21bffc..5de64c2ad7 100644 --- a/litestar/contrib/sqlalchemy/repository/_async.py +++ b/litestar/contrib/sqlalchemy/repository/_async.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Generic, Literal, cast +from typing import TYPE_CHECKING, Any, Generic, Iterable, Literal, cast from sqlalchemy import Result, Select, delete, over, select, text, update from sqlalchemy import func as sql_func @@ -36,7 +36,8 @@ def __init__( *, statement: Select[tuple[ModelT]] | None = None, session: AsyncSession, - expunge: bool = False, + auto_expunge: bool = False, + auto_refresh: bool = False, **kwargs: Any, ) -> None: """Repository pattern for SQLAlchemy models. @@ -44,12 +45,14 @@ def __init__( Args: statement: To facilitate customization of the underlying select query. session: Session managing the unit-of-work for the operation. - expunge: Remove object from session before returning. + auto_expunge: Remove object from session before returning. + auto_refresh: Refresh object from session before returning. **kwargs: Additional arguments. """ super().__init__(**kwargs) - self.expunge = expunge + self.auto_expunge = auto_expunge + self.auto_refresh = auto_refresh self.session = session self.statement = statement if statement is not None else select(self.model_type) if not self.session.bind: @@ -58,10 +61,6 @@ def __init__( raise ValueError("Session improperly configure") self._dialect = self.session.bind.dialect - def _expunge(self, instance: ModelT) -> None: - if self.expunge: - self.session.expunge(instance) - async def add(self, data: ModelT) -> ModelT: """Add `data` to the collection. @@ -74,7 +73,7 @@ async def add(self, data: ModelT) -> ModelT: with wrap_sqlalchemy_exception(): instance = await self._attach_to_session(data) await self.session.flush() - await self.session.refresh(instance) + await self._refresh(instance) self._expunge(instance) return instance @@ -223,13 +222,20 @@ async def get_one_or_none(self, **kwargs: Any) -> ModelT | None: return instance # type: ignore async def get_or_create( - self, match_fields: list[str] | str | None = None, upsert: bool = True, **kwargs: Any + self, + match_fields: list[str] | str | None = None, + upsert: bool = True, + attribute_names: Iterable[str] | None = None, + with_for_update: bool | None = None, + **kwargs: Any, ) -> tuple[ModelT, bool]: """Get instance identified by ``kwargs`` or create if it doesn't exist. Args: match_fields: a list of keys to use to match the existing model. When empty, all fields are matched. upsert: When using match_fields and actual model values differ from `kwargs`, perform an update operation on the model. + attribute_names: an iterable of attribute names to pass into the ``update`` method. + with_for_update: indicating FOR UPDATE should be used, or may be a dictionary containing flags to indicate a more specific set of FOR UPDATE flags for the SELECT **kwargs: Identifier of the instance to be retrieved. Returns: @@ -256,7 +262,7 @@ async def get_or_create( setattr(existing, field_name, new_field_value) existing = await self._attach_to_session(existing, strategy="merge") await self.session.flush() - await self.session.refresh(existing) + await self._refresh(existing, attribute_names=attribute_names, with_for_update=with_for_update) self._expunge(existing) return existing, False @@ -280,19 +286,26 @@ async def count(self, *filters: FilterTypes, **kwargs: Any) -> int: results = await self._execute(statement) return results.scalar_one() # type: ignore - async def update(self, data: ModelT) -> ModelT: + async def update( + self, + data: ModelT, + attribute_names: Iterable[str] | None = None, + with_for_update: bool | None = None, + ) -> ModelT: """Update instance with the attribute values present on `data`. Args: data: An instance that should have a value for `self.id_attribute` that exists in the collection. - + attribute_names: an iterable of attribute names to pass into the ``update`` method. + with_for_update: indicating FOR UPDATE should be used, or may be a dictionary containing flags to indicate a more specific set of FOR UPDATE flags for the SELECT Returns: The updated instance. Raises: NotFoundError: If no instance found with same identifier as `data`. """ + with wrap_sqlalchemy_exception(): item_id = self.get_id_attribute_value(data) # this will raise for not found, and will put the item in the session @@ -300,7 +313,7 @@ async def update(self, data: ModelT) -> ModelT: # this will merge the inbound data to the instance we just put in the session instance = await self._attach_to_session(data, strategy="merge") await self.session.flush() - await self.session.refresh(instance) + await self._refresh(instance, attribute_names=attribute_names, with_for_update=with_for_update) self._expunge(instance) return instance @@ -361,6 +374,19 @@ async def list_and_count( return await self._list_and_count_basic(*filters, **kwargs) return await self._list_and_count_window(*filters, **kwargs) + def _expunge(self, instance: ModelT) -> None: + if self.auto_expunge: + self.session.expunge(instance) + + async def _refresh( + self, + instance: ModelT, + attribute_names: Iterable[str] | None = None, + with_for_update: bool | None = None, + ) -> None: + if self.auto_refresh: + await self.session.refresh(instance, attribute_names=attribute_names, with_for_update=with_for_update) + async def _list_and_count_window( self, *filters: FilterTypes, @@ -442,7 +468,12 @@ async def list(self, *filters: FilterTypes, **kwargs: Any) -> list[ModelT]: self._expunge(instance) return instances - async def upsert(self, data: ModelT) -> ModelT: + async def upsert( + self, + data: ModelT, + attribute_names: Iterable[str] | None = None, + with_for_update: bool | None = None, + ) -> ModelT: """Update or create instance. Updates instance with the attribute values present on `data`, or creates a new instance if @@ -452,7 +483,8 @@ async def upsert(self, data: ModelT) -> ModelT: data: Instance to update existing, or be created. Identifier used to determine if an existing instance exists is the value of an attribute on `data` named as value of `self.id_attribute`. - + attribute_names: an iterable of attribute names to pass into the ``update`` method. + with_for_update: indicating FOR UPDATE should be used, or may be a dictionary containing flags to indicate a more specific set of FOR UPDATE flags for the SELECT Returns: The updated or created instance. @@ -462,7 +494,7 @@ async def upsert(self, data: ModelT) -> ModelT: with wrap_sqlalchemy_exception(): instance = await self._attach_to_session(data, strategy="merge") await self.session.flush() - await self.session.refresh(instance) + await self._refresh(instance, attribute_names=attribute_names, with_for_update=with_for_update) self._expunge(instance) return instance diff --git a/litestar/contrib/sqlalchemy/repository/_sync.py b/litestar/contrib/sqlalchemy/repository/_sync.py index 7df202e743..9382e4413d 100644 --- a/litestar/contrib/sqlalchemy/repository/_sync.py +++ b/litestar/contrib/sqlalchemy/repository/_sync.py @@ -2,7 +2,7 @@ # litestar/contrib/sqlalchemy/repository/_async.py from __future__ import annotations -from typing import TYPE_CHECKING, Any, Generic, Literal, cast +from typing import TYPE_CHECKING, Any, Generic, Iterable, Literal, cast from sqlalchemy import Result, Select, delete, over, select, text, update from sqlalchemy import func as sql_func @@ -38,7 +38,8 @@ def __init__( *, statement: Select[tuple[ModelT]] | None = None, session: Session, - expunge: bool = False, + auto_expunge: bool = False, + auto_refresh: bool = False, **kwargs: Any, ) -> None: """Repository pattern for SQLAlchemy models. @@ -46,12 +47,14 @@ def __init__( Args: statement: To facilitate customization of the underlying select query. session: Session managing the unit-of-work for the operation. - expunge: Remove object from session before returning. + auto_expunge: Remove object from session before returning. + auto_refresh: Refresh object from session before returning. **kwargs: Additional arguments. """ super().__init__(**kwargs) - self.expunge = expunge + self.auto_expunge = auto_expunge + self.auto_refresh = auto_refresh self.session = session self.statement = statement if statement is not None else select(self.model_type) if not self.session.bind: @@ -60,10 +63,6 @@ def __init__( raise ValueError("Session improperly configure") self._dialect = self.session.bind.dialect - def _expunge(self, instance: ModelT) -> None: - if self.expunge: - self.session.expunge(instance) - def add(self, data: ModelT) -> ModelT: """Add `data` to the collection. @@ -76,7 +75,7 @@ def add(self, data: ModelT) -> ModelT: with wrap_sqlalchemy_exception(): instance = self._attach_to_session(data) self.session.flush() - self.session.refresh(instance) + self._refresh(instance) self._expunge(instance) return instance @@ -225,13 +224,20 @@ def get_one_or_none(self, **kwargs: Any) -> ModelT | None: return instance # type: ignore def get_or_create( - self, match_fields: list[str] | str | None = None, upsert: bool = True, **kwargs: Any + self, + match_fields: list[str] | str | None = None, + upsert: bool = True, + attribute_names: Iterable[str] | None = None, + with_for_update: bool | None = None, + **kwargs: Any, ) -> tuple[ModelT, bool]: """Get instance identified by ``kwargs`` or create if it doesn't exist. Args: match_fields: a list of keys to use to match the existing model. When empty, all fields are matched. upsert: When using match_fields and actual model values differ from `kwargs`, perform an update operation on the model. + attribute_names: an iterable of attribute names to pass into the ``update`` method. + with_for_update: indicating FOR UPDATE should be used, or may be a dictionary containing flags to indicate a more specific set of FOR UPDATE flags for the SELECT **kwargs: Identifier of the instance to be retrieved. Returns: @@ -258,7 +264,7 @@ def get_or_create( setattr(existing, field_name, new_field_value) existing = self._attach_to_session(existing, strategy="merge") self.session.flush() - self.session.refresh(existing) + self._refresh(existing, attribute_names=attribute_names, with_for_update=with_for_update) self._expunge(existing) return existing, False @@ -282,19 +288,26 @@ def count(self, *filters: FilterTypes, **kwargs: Any) -> int: results = self._execute(statement) return results.scalar_one() # type: ignore - def update(self, data: ModelT) -> ModelT: + def update( + self, + data: ModelT, + attribute_names: Iterable[str] | None = None, + with_for_update: bool | None = None, + ) -> ModelT: """Update instance with the attribute values present on `data`. Args: data: An instance that should have a value for `self.id_attribute` that exists in the collection. - + attribute_names: an iterable of attribute names to pass into the ``update`` method. + with_for_update: indicating FOR UPDATE should be used, or may be a dictionary containing flags to indicate a more specific set of FOR UPDATE flags for the SELECT Returns: The updated instance. Raises: NotFoundError: If no instance found with same identifier as `data`. """ + with wrap_sqlalchemy_exception(): item_id = self.get_id_attribute_value(data) # this will raise for not found, and will put the item in the session @@ -302,7 +315,7 @@ def update(self, data: ModelT) -> ModelT: # this will merge the inbound data to the instance we just put in the session instance = self._attach_to_session(data, strategy="merge") self.session.flush() - self.session.refresh(instance) + self._refresh(instance, attribute_names=attribute_names, with_for_update=with_for_update) self._expunge(instance) return instance @@ -363,6 +376,19 @@ def list_and_count( return self._list_and_count_basic(*filters, **kwargs) return self._list_and_count_window(*filters, **kwargs) + def _expunge(self, instance: ModelT) -> None: + if self.auto_expunge: + self.session.expunge(instance) + + def _refresh( + self, + instance: ModelT, + attribute_names: Iterable[str] | None = None, + with_for_update: bool | None = None, + ) -> None: + if self.auto_refresh: + self.session.refresh(instance, attribute_names=attribute_names, with_for_update=with_for_update) + def _list_and_count_window( self, *filters: FilterTypes, @@ -444,7 +470,12 @@ def list(self, *filters: FilterTypes, **kwargs: Any) -> list[ModelT]: self._expunge(instance) return instances - def upsert(self, data: ModelT) -> ModelT: + def upsert( + self, + data: ModelT, + attribute_names: Iterable[str] | None = None, + with_for_update: bool | None = None, + ) -> ModelT: """Update or create instance. Updates instance with the attribute values present on `data`, or creates a new instance if @@ -454,7 +485,8 @@ def upsert(self, data: ModelT) -> ModelT: data: Instance to update existing, or be created. Identifier used to determine if an existing instance exists is the value of an attribute on `data` named as value of `self.id_attribute`. - + attribute_names: an iterable of attribute names to pass into the ``update`` method. + with_for_update: indicating FOR UPDATE should be used, or may be a dictionary containing flags to indicate a more specific set of FOR UPDATE flags for the SELECT Returns: The updated or created instance. @@ -464,7 +496,7 @@ def upsert(self, data: ModelT) -> ModelT: with wrap_sqlalchemy_exception(): instance = self._attach_to_session(data, strategy="merge") self.session.flush() - self.session.refresh(instance) + self._refresh(instance, attribute_names=attribute_names, with_for_update=with_for_update) self._expunge(instance) return instance diff --git a/tests/unit/test_contrib/test_sqlalchemy/models_uuid.py b/tests/unit/test_contrib/test_sqlalchemy/models_uuid.py index 89a1e491c4..877094fc4e 100644 --- a/tests/unit/test_contrib/test_sqlalchemy/models_uuid.py +++ b/tests/unit/test_contrib/test_sqlalchemy/models_uuid.py @@ -63,14 +63,14 @@ class UUIDModelWithFetchedValue(UUIDBase): class UUIDItem(UUIDBase): name: Mapped[str] = mapped_column(String(), unique=True) description: Mapped[str | None] - tags: Mapped[list[UUIDTag]] = relationship(secondary=lambda: uuid_item_tag, back_populates="items") + tags: Mapped[list[UUIDTag]] = relationship(secondary=lambda: uuid_item_tag, back_populates="items", lazy="noload") class UUIDTag(UUIDAuditBase): """The event log domain object.""" name: Mapped[str] = mapped_column(String(50), unique=True) - items: Mapped[list[UUIDItem]] = relationship(secondary=lambda: uuid_item_tag, back_populates="tags") + items: Mapped[list[UUIDItem]] = relationship(secondary=lambda: uuid_item_tag, back_populates="tags", lazy="noload") class UUIDRule(UUIDAuditBase):