Skip to content

Commit

Permalink
feat: exposes refresh attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
cofin committed Jul 1, 2023
1 parent 2258c3d commit d960ab7
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 36 deletions.
66 changes: 49 additions & 17 deletions litestar/contrib/sqlalchemy/repository/_async.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Generic, Literal, cast
from typing import TYPE_CHECKING, Any, Generic, Iterable, Literal, cast

from sqlalchemy import Result, Select, delete, over, select, text, update
from sqlalchemy import func as sql_func
Expand Down Expand Up @@ -36,20 +36,23 @@ def __init__(
*,
statement: Select[tuple[ModelT]] | None = None,
session: AsyncSession,
expunge: bool = False,
auto_expunge: bool = False,
auto_refresh: bool = False,
**kwargs: Any,
) -> None:
"""Repository pattern for SQLAlchemy models.
Args:
statement: To facilitate customization of the underlying select query.
session: Session managing the unit-of-work for the operation.
expunge: Remove object from session before returning.
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
**kwargs: Additional arguments.
"""
super().__init__(**kwargs)
self.expunge = expunge
self.auto_expunge = auto_expunge
self.auto_refresh = auto_refresh
self.session = session
self.statement = statement if statement is not None else select(self.model_type)
if not self.session.bind:
Expand All @@ -58,10 +61,6 @@ def __init__(
raise ValueError("Session improperly configure")
self._dialect = self.session.bind.dialect

def _expunge(self, instance: ModelT) -> None:
if self.expunge:
self.session.expunge(instance)

async def add(self, data: ModelT) -> ModelT:
"""Add `data` to the collection.
Expand All @@ -74,7 +73,7 @@ async def add(self, data: ModelT) -> ModelT:
with wrap_sqlalchemy_exception():
instance = await self._attach_to_session(data)
await self.session.flush()
await self.session.refresh(instance)
await self._refresh(instance)
self._expunge(instance)
return instance

Expand Down Expand Up @@ -223,13 +222,20 @@ async def get_one_or_none(self, **kwargs: Any) -> ModelT | None:
return instance # type: ignore

async def get_or_create(
self, match_fields: list[str] | str | None = None, upsert: bool = True, **kwargs: Any
self,
match_fields: list[str] | str | None = None,
upsert: bool = True,
attribute_names: Iterable[str] | None = None,
with_for_update: bool | None = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Get instance identified by ``kwargs`` or create if it doesn't exist.
Args:
match_fields: a list of keys to use to match the existing model. When empty, all fields are matched.
upsert: When using match_fields and actual model values differ from `kwargs`, perform an update operation on the model.
attribute_names: an iterable of attribute names to pass into the ``update`` method.
with_for_update: indicating FOR UPDATE should be used, or may be a dictionary containing flags to indicate a more specific set of FOR UPDATE flags for the SELECT
**kwargs: Identifier of the instance to be retrieved.
Returns:
Expand All @@ -256,7 +262,7 @@ async def get_or_create(
setattr(existing, field_name, new_field_value)
existing = await self._attach_to_session(existing, strategy="merge")
await self.session.flush()
await self.session.refresh(existing)
await self._refresh(existing, attribute_names=attribute_names, with_for_update=with_for_update)
self._expunge(existing)
return existing, False

Expand All @@ -280,27 +286,34 @@ async def count(self, *filters: FilterTypes, **kwargs: Any) -> int:
results = await self._execute(statement)
return results.scalar_one() # type: ignore

async def update(self, data: ModelT) -> ModelT:
async def update(
self,
data: ModelT,
attribute_names: Iterable[str] | None = None,
with_for_update: bool | None = None,
) -> ModelT:
"""Update instance with the attribute values present on `data`.
Args:
data: An instance that should have a value for `self.id_attribute` that exists in the
collection.
attribute_names: an iterable of attribute names to pass into the ``update`` method.
with_for_update: indicating FOR UPDATE should be used, or may be a dictionary containing flags to indicate a more specific set of FOR UPDATE flags for the SELECT
Returns:
The updated instance.
Raises:
NotFoundError: If no instance found with same identifier as `data`.
"""

with wrap_sqlalchemy_exception():
item_id = self.get_id_attribute_value(data)
# this will raise for not found, and will put the item in the session
await self.get(item_id)
# this will merge the inbound data to the instance we just put in the session
instance = await self._attach_to_session(data, strategy="merge")
await self.session.flush()
await self.session.refresh(instance)
await self._refresh(instance, attribute_names=attribute_names, with_for_update=with_for_update)
self._expunge(instance)
return instance

Expand Down Expand Up @@ -361,6 +374,19 @@ async def list_and_count(
return await self._list_and_count_basic(*filters, **kwargs)
return await self._list_and_count_window(*filters, **kwargs)

def _expunge(self, instance: ModelT) -> None:
if self.auto_expunge:
self.session.expunge(instance)

async def _refresh(
self,
instance: ModelT,
attribute_names: Iterable[str] | None = None,
with_for_update: bool | None = None,
) -> None:
if self.auto_refresh:
await self.session.refresh(instance, attribute_names=attribute_names, with_for_update=with_for_update)

async def _list_and_count_window(
self,
*filters: FilterTypes,
Expand Down Expand Up @@ -442,7 +468,12 @@ async def list(self, *filters: FilterTypes, **kwargs: Any) -> list[ModelT]:
self._expunge(instance)
return instances

async def upsert(self, data: ModelT) -> ModelT:
async def upsert(
self,
data: ModelT,
attribute_names: Iterable[str] | None = None,
with_for_update: bool | None = None,
) -> ModelT:
"""Update or create instance.
Updates instance with the attribute values present on `data`, or creates a new instance if
Expand All @@ -452,7 +483,8 @@ async def upsert(self, data: ModelT) -> ModelT:
data: Instance to update existing, or be created. Identifier used to determine if an
existing instance exists is the value of an attribute on `data` named as value of
`self.id_attribute`.
attribute_names: an iterable of attribute names to pass into the ``update`` method.
with_for_update: indicating FOR UPDATE should be used, or may be a dictionary containing flags to indicate a more specific set of FOR UPDATE flags for the SELECT
Returns:
The updated or created instance.
Expand All @@ -462,7 +494,7 @@ async def upsert(self, data: ModelT) -> ModelT:
with wrap_sqlalchemy_exception():
instance = await self._attach_to_session(data, strategy="merge")
await self.session.flush()
await self.session.refresh(instance)
await self._refresh(instance, attribute_names=attribute_names, with_for_update=with_for_update)
self._expunge(instance)
return instance

Expand Down
66 changes: 49 additions & 17 deletions litestar/contrib/sqlalchemy/repository/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# litestar/contrib/sqlalchemy/repository/_async.py
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Generic, Literal, cast
from typing import TYPE_CHECKING, Any, Generic, Iterable, Literal, cast

from sqlalchemy import Result, Select, delete, over, select, text, update
from sqlalchemy import func as sql_func
Expand Down Expand Up @@ -38,20 +38,23 @@ def __init__(
*,
statement: Select[tuple[ModelT]] | None = None,
session: Session,
expunge: bool = False,
auto_expunge: bool = False,
auto_refresh: bool = False,
**kwargs: Any,
) -> None:
"""Repository pattern for SQLAlchemy models.
Args:
statement: To facilitate customization of the underlying select query.
session: Session managing the unit-of-work for the operation.
expunge: Remove object from session before returning.
auto_expunge: Remove object from session before returning.
auto_refresh: Refresh object from session before returning.
**kwargs: Additional arguments.
"""
super().__init__(**kwargs)
self.expunge = expunge
self.auto_expunge = auto_expunge
self.auto_refresh = auto_refresh
self.session = session
self.statement = statement if statement is not None else select(self.model_type)
if not self.session.bind:
Expand All @@ -60,10 +63,6 @@ def __init__(
raise ValueError("Session improperly configure")
self._dialect = self.session.bind.dialect

def _expunge(self, instance: ModelT) -> None:
if self.expunge:
self.session.expunge(instance)

def add(self, data: ModelT) -> ModelT:
"""Add `data` to the collection.
Expand All @@ -76,7 +75,7 @@ def add(self, data: ModelT) -> ModelT:
with wrap_sqlalchemy_exception():
instance = self._attach_to_session(data)
self.session.flush()
self.session.refresh(instance)
self._refresh(instance)
self._expunge(instance)
return instance

Expand Down Expand Up @@ -225,13 +224,20 @@ def get_one_or_none(self, **kwargs: Any) -> ModelT | None:
return instance # type: ignore

def get_or_create(
self, match_fields: list[str] | str | None = None, upsert: bool = True, **kwargs: Any
self,
match_fields: list[str] | str | None = None,
upsert: bool = True,
attribute_names: Iterable[str] | None = None,
with_for_update: bool | None = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Get instance identified by ``kwargs`` or create if it doesn't exist.
Args:
match_fields: a list of keys to use to match the existing model. When empty, all fields are matched.
upsert: When using match_fields and actual model values differ from `kwargs`, perform an update operation on the model.
attribute_names: an iterable of attribute names to pass into the ``update`` method.
with_for_update: indicating FOR UPDATE should be used, or may be a dictionary containing flags to indicate a more specific set of FOR UPDATE flags for the SELECT
**kwargs: Identifier of the instance to be retrieved.
Returns:
Expand All @@ -258,7 +264,7 @@ def get_or_create(
setattr(existing, field_name, new_field_value)
existing = self._attach_to_session(existing, strategy="merge")
self.session.flush()
self.session.refresh(existing)
self._refresh(existing, attribute_names=attribute_names, with_for_update=with_for_update)
self._expunge(existing)
return existing, False

Expand All @@ -282,27 +288,34 @@ def count(self, *filters: FilterTypes, **kwargs: Any) -> int:
results = self._execute(statement)
return results.scalar_one() # type: ignore

def update(self, data: ModelT) -> ModelT:
def update(
self,
data: ModelT,
attribute_names: Iterable[str] | None = None,
with_for_update: bool | None = None,
) -> ModelT:
"""Update instance with the attribute values present on `data`.
Args:
data: An instance that should have a value for `self.id_attribute` that exists in the
collection.
attribute_names: an iterable of attribute names to pass into the ``update`` method.
with_for_update: indicating FOR UPDATE should be used, or may be a dictionary containing flags to indicate a more specific set of FOR UPDATE flags for the SELECT
Returns:
The updated instance.
Raises:
NotFoundError: If no instance found with same identifier as `data`.
"""

with wrap_sqlalchemy_exception():
item_id = self.get_id_attribute_value(data)
# this will raise for not found, and will put the item in the session
self.get(item_id)
# this will merge the inbound data to the instance we just put in the session
instance = self._attach_to_session(data, strategy="merge")
self.session.flush()
self.session.refresh(instance)
self._refresh(instance, attribute_names=attribute_names, with_for_update=with_for_update)
self._expunge(instance)
return instance

Expand Down Expand Up @@ -363,6 +376,19 @@ def list_and_count(
return self._list_and_count_basic(*filters, **kwargs)
return self._list_and_count_window(*filters, **kwargs)

def _expunge(self, instance: ModelT) -> None:
if self.auto_expunge:
self.session.expunge(instance)

def _refresh(
self,
instance: ModelT,
attribute_names: Iterable[str] | None = None,
with_for_update: bool | None = None,
) -> None:
if self.auto_refresh:
self.session.refresh(instance, attribute_names=attribute_names, with_for_update=with_for_update)

def _list_and_count_window(
self,
*filters: FilterTypes,
Expand Down Expand Up @@ -444,7 +470,12 @@ def list(self, *filters: FilterTypes, **kwargs: Any) -> list[ModelT]:
self._expunge(instance)
return instances

def upsert(self, data: ModelT) -> ModelT:
def upsert(
self,
data: ModelT,
attribute_names: Iterable[str] | None = None,
with_for_update: bool | None = None,
) -> ModelT:
"""Update or create instance.
Updates instance with the attribute values present on `data`, or creates a new instance if
Expand All @@ -454,7 +485,8 @@ def upsert(self, data: ModelT) -> ModelT:
data: Instance to update existing, or be created. Identifier used to determine if an
existing instance exists is the value of an attribute on `data` named as value of
`self.id_attribute`.
attribute_names: an iterable of attribute names to pass into the ``update`` method.
with_for_update: indicating FOR UPDATE should be used, or may be a dictionary containing flags to indicate a more specific set of FOR UPDATE flags for the SELECT
Returns:
The updated or created instance.
Expand All @@ -464,7 +496,7 @@ def upsert(self, data: ModelT) -> ModelT:
with wrap_sqlalchemy_exception():
instance = self._attach_to_session(data, strategy="merge")
self.session.flush()
self.session.refresh(instance)
self._refresh(instance, attribute_names=attribute_names, with_for_update=with_for_update)
self._expunge(instance)
return instance

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_contrib/test_sqlalchemy/models_uuid.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ class UUIDModelWithFetchedValue(UUIDBase):
class UUIDItem(UUIDBase):
name: Mapped[str] = mapped_column(String(), unique=True)
description: Mapped[str | None]
tags: Mapped[list[UUIDTag]] = relationship(secondary=lambda: uuid_item_tag, back_populates="items")
tags: Mapped[list[UUIDTag]] = relationship(secondary=lambda: uuid_item_tag, back_populates="items", lazy="noload")


class UUIDTag(UUIDAuditBase):
"""The event log domain object."""

name: Mapped[str] = mapped_column(String(50), unique=True)
items: Mapped[list[UUIDItem]] = relationship(secondary=lambda: uuid_item_tag, back_populates="tags")
items: Mapped[list[UUIDItem]] = relationship(secondary=lambda: uuid_item_tag, back_populates="tags", lazy="noload")


class UUIDRule(UUIDAuditBase):
Expand Down

0 comments on commit d960ab7

Please sign in to comment.