From 300652966e5cf568f8ebc15a15fcf350bb387ad2 Mon Sep 17 00:00:00 2001 From: Richie Cahill Date: Sun, 11 Aug 2024 12:46:54 +0000 Subject: [PATCH] created BaseLookupTable --- chorus_thing/database.py | 70 ++++++++++++++++++++++++++++++++++++---- pyproject.toml | 3 ++ 2 files changed, 67 insertions(+), 6 deletions(-) diff --git a/chorus_thing/database.py b/chorus_thing/database.py index 218de27..fd159ff 100644 --- a/chorus_thing/database.py +++ b/chorus_thing/database.py @@ -1,26 +1,84 @@ """tests.""" +import logging from datetime import UTC, datetime +from typing import Self, cast -from sqlalchemy import MetaData +from sqlalchemy import MetaData, select +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.declarative import AbstractConcreteBase -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, object_session -class ChorusThing(DeclarativeBase): +def safe_object_session(class_: object) -> Session: + """Get a session for a class.""" + if session := object_session(class_): + return session + error = f"No session found for {class_}" + raise RuntimeError(error) + + +class Chorusitem(DeclarativeBase): """A base class for contacts.""" - metadata = MetaData(schema="chorus_thing") + metadata = MetaData(schema="chorus_item") -class BaseTable(AbstractConcreteBase, ChorusThing): - """A base class for tables.""" +class BassColumns: + """Base columns.""" id: Mapped[int] = mapped_column(primary_key=True) created: Mapped[datetime] = mapped_column(default=datetime.now(tz=UTC)) modified: Mapped[datetime] = mapped_column(default=datetime.now(tz=UTC), onupdate=datetime.now(tz=UTC)) +class BaseTable(BassColumns, AbstractConcreteBase, Chorusitem): + """A base class for tables.""" + + +class BaseLookupTable(BassColumns, AbstractConcreteBase, Chorusitem): + """A lookup table.""" + + name: Mapped[str] = mapped_column(primary_key=True) + + @classmethod + def get(class_, name: str) -> Self | None: + """Get a lookup table by name. + + Args: + name (str): The name of the lookup table. + + Returns: + BaseLookupTable | None: The lookup table, or None if not found. + """ + return cast(Session, object_session(class_)).scalars(select(class_).where(class_.name == name)).one_or_none() + + @classmethod + def add(class_, name: str) -> Self: + """Add a lookup table. + + Args: + name (str): The name of the lookup table. + + Returns: + BaseLookupTable: The lookup table. + """ + if item := class_.get(name): + return item + session = safe_object_session(class_) + try: + item = class_(name=name) + session.add(item) + session.commit() + except IntegrityError: + if item := class_.get(name): + msg = f"Duplicate item in lookup table {name}" + logging.info(msg) + return item + raise + return item + + class Event(BaseTable): """Table of Choral Events.""" diff --git a/pyproject.toml b/pyproject.toml index 560d898..98f8b53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,9 @@ lint.ignore = [ "ISC001", # (TEMP) conflicts when used with the formatter ] +[tool.ruff.lint.pep8-naming] +extend-ignore-names = ["class_"] + [tool.ruff.lint.per-file-ignores] "tests/**" = [