Skip to content

Commit

Permalink
feat: make expunge configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
cofin committed Jul 1, 2023
1 parent 704db82 commit 2258c3d
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 32 deletions.
43 changes: 28 additions & 15 deletions litestar/contrib/sqlalchemy/repository/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,25 @@ class SQLAlchemyAsyncRepository(AbstractAsyncRepository[ModelT], Generic[ModelT]

match_fields: list[str] | str | None = None

def __init__(self, *, statement: Select[tuple[ModelT]] | None = None, session: AsyncSession, **kwargs: Any) -> None:
def __init__(
self,
*,
statement: Select[tuple[ModelT]] | None = None,
session: AsyncSession,
expunge: bool = False,
**kwargs: Any,
) -> None:
"""Repository pattern for SQLAlchemy models.
Args:
statement: To facilitate customization of the underlying select query.
session: Session managing the unit-of-work for the operation.
expunge: Remove object from session before returning.
**kwargs: Additional arguments.
"""
super().__init__(**kwargs)
self.expunge = expunge
self.session = session
self.statement = statement if statement is not None else select(self.model_type)
if not self.session.bind:
Expand All @@ -49,6 +58,10 @@ def __init__(self, *, statement: Select[tuple[ModelT]] | None = None, session: A
raise ValueError("Session improperly configure")
self._dialect = self.session.bind.dialect

def _expunge(self, instance: ModelT) -> None:
if self.expunge:
self.session.expunge(instance)

async def add(self, data: ModelT) -> ModelT:
"""Add `data` to the collection.
Expand All @@ -62,7 +75,7 @@ async def add(self, data: ModelT) -> ModelT:
instance = await self._attach_to_session(data)
await self.session.flush()
await self.session.refresh(instance)
self.session.expunge(instance)
self._expunge(instance)
return instance

async def add_many(self, data: list[ModelT]) -> list[ModelT]:
Expand All @@ -79,7 +92,7 @@ async def add_many(self, data: list[ModelT]) -> list[ModelT]:
self.session.add_all(data)
await self.session.flush()
for datum in data:
self.session.expunge(datum)
self._expunge(datum)
return data

async def delete(self, item_id: Any) -> ModelT:
Expand All @@ -98,7 +111,7 @@ async def delete(self, item_id: Any) -> ModelT:
instance = await self.get(item_id)
await self.session.delete(instance)
await self.session.flush()
self.session.expunge(instance)
self._expunge(instance)
return instance

async def delete_many(self, item_ids: list[Any]) -> list[ModelT]:
Expand Down Expand Up @@ -135,7 +148,7 @@ async def delete_many(self, item_ids: list[Any]) -> list[ModelT]:
)
await self.session.flush()
for instance in instances:
self.session.expunge(instance)
self._expunge(instance)
return instances

async def exists(self, **kwargs: Any) -> bool:
Expand Down Expand Up @@ -169,7 +182,7 @@ async def get(self, item_id: Any, **kwargs: Any) -> ModelT:
statement = self._filter_select_by_kwargs(statement=statement, **{self.id_attribute: item_id})
instance = (await self._execute(statement)).scalar_one_or_none()
instance = self.check_not_found(instance)
self.session.expunge(instance)
self._expunge(instance)
return instance

async def get_one(self, **kwargs: Any) -> ModelT:
Expand All @@ -189,7 +202,7 @@ async def get_one(self, **kwargs: Any) -> ModelT:
statement = self._filter_select_by_kwargs(statement=statement, **kwargs)
instance = (await self._execute(statement)).scalar_one_or_none()
instance = self.check_not_found(instance)
self.session.expunge(instance)
self._expunge(instance)
return instance

async def get_one_or_none(self, **kwargs: Any) -> ModelT | None:
Expand All @@ -206,7 +219,7 @@ async def get_one_or_none(self, **kwargs: Any) -> ModelT | None:
statement = self._filter_select_by_kwargs(statement=statement, **kwargs)
instance = (await self._execute(statement)).scalar_one_or_none()
if instance:
self.session.expunge(instance)
self._expunge(instance)
return instance # type: ignore

async def get_or_create(
Expand Down Expand Up @@ -244,7 +257,7 @@ async def get_or_create(
existing = await self._attach_to_session(existing, strategy="merge")
await self.session.flush()
await self.session.refresh(existing)
self.session.expunge(existing)
self._expunge(existing)
return existing, False

async def count(self, *filters: FilterTypes, **kwargs: Any) -> int:
Expand Down Expand Up @@ -288,7 +301,7 @@ async def update(self, data: ModelT) -> ModelT:
instance = await self._attach_to_session(data, strategy="merge")
await self.session.flush()
await self.session.refresh(instance)
self.session.expunge(instance)
self._expunge(instance)
return instance

async def update_many(self, data: list[ModelT]) -> list[ModelT]:
Expand Down Expand Up @@ -321,7 +334,7 @@ async def update_many(self, data: list[ModelT]) -> list[ModelT]:
)
await self.session.flush()
for instance in instances:
self.session.expunge(instance)
self._expunge(instance)
return instances
await self.session.execute(
update(self.model_type),
Expand Down Expand Up @@ -371,7 +384,7 @@ async def _list_and_count_window(
count: int = 0
instances: list[ModelT] = []
for i, (instance, count_value) in enumerate(result):
self.session.expunge(instance)
self._expunge(instance)
instances.append(instance)
if i == 0:
count = count_value
Expand Down Expand Up @@ -404,7 +417,7 @@ async def _list_and_count_basic(
result = await self._execute(statement)
instances: list[ModelT] = []
for (instance,) in result:
self.session.expunge(instance)
self._expunge(instance)
instances.append(instance)
return instances, count

Expand All @@ -426,7 +439,7 @@ async def list(self, *filters: FilterTypes, **kwargs: Any) -> list[ModelT]:
result = await self._execute(statement)
instances = list(result.scalars())
for instance in instances:
self.session.expunge(instance)
self._expunge(instance)
return instances

async def upsert(self, data: ModelT) -> ModelT:
Expand All @@ -450,7 +463,7 @@ async def upsert(self, data: ModelT) -> ModelT:
instance = await self._attach_to_session(data, strategy="merge")
await self.session.flush()
await self.session.refresh(instance)
self.session.expunge(instance)
self._expunge(instance)
return instance

def filter_collection_by_kwargs( # type:ignore[override]
Expand Down
43 changes: 28 additions & 15 deletions litestar/contrib/sqlalchemy/repository/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,25 @@ class SQLAlchemySyncRepository(AbstractSyncRepository[ModelT], Generic[ModelT]):

match_fields: list[str] | str | None = None

def __init__(self, *, statement: Select[tuple[ModelT]] | None = None, session: Session, **kwargs: Any) -> None:
def __init__(
self,
*,
statement: Select[tuple[ModelT]] | None = None,
session: Session,
expunge: bool = False,
**kwargs: Any,
) -> None:
"""Repository pattern for SQLAlchemy models.
Args:
statement: To facilitate customization of the underlying select query.
session: Session managing the unit-of-work for the operation.
expunge: Remove object from session before returning.
**kwargs: Additional arguments.
"""
super().__init__(**kwargs)
self.expunge = expunge
self.session = session
self.statement = statement if statement is not None else select(self.model_type)
if not self.session.bind:
Expand All @@ -51,6 +60,10 @@ def __init__(self, *, statement: Select[tuple[ModelT]] | None = None, session: S
raise ValueError("Session improperly configure")
self._dialect = self.session.bind.dialect

def _expunge(self, instance: ModelT) -> None:
if self.expunge:
self.session.expunge(instance)

def add(self, data: ModelT) -> ModelT:
"""Add `data` to the collection.
Expand All @@ -64,7 +77,7 @@ def add(self, data: ModelT) -> ModelT:
instance = self._attach_to_session(data)
self.session.flush()
self.session.refresh(instance)
self.session.expunge(instance)
self._expunge(instance)
return instance

def add_many(self, data: list[ModelT]) -> list[ModelT]:
Expand All @@ -81,7 +94,7 @@ def add_many(self, data: list[ModelT]) -> list[ModelT]:
self.session.add_all(data)
self.session.flush()
for datum in data:
self.session.expunge(datum)
self._expunge(datum)
return data

def delete(self, item_id: Any) -> ModelT:
Expand All @@ -100,7 +113,7 @@ def delete(self, item_id: Any) -> ModelT:
instance = self.get(item_id)
self.session.delete(instance)
self.session.flush()
self.session.expunge(instance)
self._expunge(instance)
return instance

def delete_many(self, item_ids: list[Any]) -> list[ModelT]:
Expand Down Expand Up @@ -137,7 +150,7 @@ def delete_many(self, item_ids: list[Any]) -> list[ModelT]:
)
self.session.flush()
for instance in instances:
self.session.expunge(instance)
self._expunge(instance)
return instances

def exists(self, **kwargs: Any) -> bool:
Expand Down Expand Up @@ -171,7 +184,7 @@ def get(self, item_id: Any, **kwargs: Any) -> ModelT:
statement = self._filter_select_by_kwargs(statement=statement, **{self.id_attribute: item_id})
instance = (self._execute(statement)).scalar_one_or_none()
instance = self.check_not_found(instance)
self.session.expunge(instance)
self._expunge(instance)
return instance

def get_one(self, **kwargs: Any) -> ModelT:
Expand All @@ -191,7 +204,7 @@ def get_one(self, **kwargs: Any) -> ModelT:
statement = self._filter_select_by_kwargs(statement=statement, **kwargs)
instance = (self._execute(statement)).scalar_one_or_none()
instance = self.check_not_found(instance)
self.session.expunge(instance)
self._expunge(instance)
return instance

def get_one_or_none(self, **kwargs: Any) -> ModelT | None:
Expand All @@ -208,7 +221,7 @@ def get_one_or_none(self, **kwargs: Any) -> ModelT | None:
statement = self._filter_select_by_kwargs(statement=statement, **kwargs)
instance = (self._execute(statement)).scalar_one_or_none()
if instance:
self.session.expunge(instance)
self._expunge(instance)
return instance # type: ignore

def get_or_create(
Expand Down Expand Up @@ -246,7 +259,7 @@ def get_or_create(
existing = self._attach_to_session(existing, strategy="merge")
self.session.flush()
self.session.refresh(existing)
self.session.expunge(existing)
self._expunge(existing)
return existing, False

def count(self, *filters: FilterTypes, **kwargs: Any) -> int:
Expand Down Expand Up @@ -290,7 +303,7 @@ def update(self, data: ModelT) -> ModelT:
instance = self._attach_to_session(data, strategy="merge")
self.session.flush()
self.session.refresh(instance)
self.session.expunge(instance)
self._expunge(instance)
return instance

def update_many(self, data: list[ModelT]) -> list[ModelT]:
Expand Down Expand Up @@ -323,7 +336,7 @@ def update_many(self, data: list[ModelT]) -> list[ModelT]:
)
self.session.flush()
for instance in instances:
self.session.expunge(instance)
self._expunge(instance)
return instances
self.session.execute(
update(self.model_type),
Expand Down Expand Up @@ -373,7 +386,7 @@ def _list_and_count_window(
count: int = 0
instances: list[ModelT] = []
for i, (instance, count_value) in enumerate(result):
self.session.expunge(instance)
self._expunge(instance)
instances.append(instance)
if i == 0:
count = count_value
Expand Down Expand Up @@ -406,7 +419,7 @@ def _list_and_count_basic(
result = self._execute(statement)
instances: list[ModelT] = []
for (instance,) in result:
self.session.expunge(instance)
self._expunge(instance)
instances.append(instance)
return instances, count

Expand All @@ -428,7 +441,7 @@ def list(self, *filters: FilterTypes, **kwargs: Any) -> list[ModelT]:
result = self._execute(statement)
instances = list(result.scalars())
for instance in instances:
self.session.expunge(instance)
self._expunge(instance)
return instances

def upsert(self, data: ModelT) -> ModelT:
Expand All @@ -452,7 +465,7 @@ def upsert(self, data: ModelT) -> ModelT:
instance = self._attach_to_session(data, strategy="merge")
self.session.flush()
self.session.refresh(instance)
self.session.expunge(instance)
self._expunge(instance)
return instance

def filter_collection_by_kwargs( # type:ignore[override]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ async def test_lazy_load(item_repo: ItemAsyncRepository, tag_repo: TagAsyncRepos
"tag_names": ["A new tag"],
"id": first_item_id,
}
tags_to_add = await maybe_async(tag_repo.list(CollectionFilter("name", update_data.pop("tag_names", []))))
tags_to_add = await maybe_async(tag_repo.list(CollectionFilter("name", update_data.pop("tag_names", [])))) # type: ignore
assert len(tags_to_add) > 0
assert tags_to_add[0].id is not None
update_data["tags"] = tags_to_add
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ async def test_lazy_load(item_repo: ItemAsyncRepository, tag_repo: TagAsyncRepos
"tag_names": ["A new tag"],
"id": first_item_id,
}
tags_to_add = await maybe_async(tag_repo.list(CollectionFilter("name", update_data.pop("tag_names", []))))
tags_to_add = await maybe_async(tag_repo.list(CollectionFilter("name", update_data.pop("tag_names", [])))) # type: ignore
assert len(tags_to_add) > 0
assert tags_to_add[0].id is not None
update_data["tags"] = tags_to_add
Expand Down

0 comments on commit 2258c3d

Please sign in to comment.