Skip to content

Commit

Permalink
test create doc
Browse files Browse the repository at this point in the history
fixing some issues with the testing set up along the way
  • Loading branch information
raphaellaude committed Jul 21, 2024
1 parent 1a3d360 commit c4d4b31
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 29 deletions.
5 changes: 0 additions & 5 deletions backend/app/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@


def get_url():
database_url = os.getenv("DATABASE_URL", None)

if database_url:
return database_url

user = os.getenv("POSTGRES_USER", "postgres")
password = os.getenv("POSTGRES_PASSWORD", "")
server = os.getenv("POSTGRES_SERVER", "db")
Expand Down
71 changes: 71 additions & 0 deletions backend/app/alembic/versions/6f0559e497ba_assignments_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""assignments table
Revision ID: 6f0559e497ba
Revises: 966d8d72887e
Create Date: 2024-07-21 16:08:29.504177
"""

from typing import Sequence, Union

from alembic import op
from sqlmodel.sql import sqltypes
import sqlalchemy as sa

from app.models import UUIDType


# revision identifiers, used by Alembic.
revision: str = "6f0559e497ba"
down_revision: Union[str, None] = "966d8d72887e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"document",
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("CURRENT_TIMESTAMP"),
nullable=False,
),
sa.Column(
"updated_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("CURRENT_TIMESTAMP"),
nullable=False,
),
sa.Column("document_id", UUIDType(), nullable=False),
sa.PrimaryKeyConstraint("document_id"),
sa.UniqueConstraint("document_id"),
)
op.create_table(
"assignments",
sa.Column("document_id", sa.UUID(), nullable=False),
sa.Column("geo_id", sqltypes.AutoString(), nullable=False),
sa.Column("zone", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["document_id"],
["document.document_id"],
),
sa.PrimaryKeyConstraint("document_id", "geo_id"),
postgresql_partition_by="LIST (document_id)",
)
op.alter_column("gerrydbtable", "uuid", existing_type=sa.UUID(), nullable=False)
op.drop_column("gerrydbtable", "id")
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"gerrydbtable",
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
)
op.alter_column("gerrydbtable", "uuid", existing_type=sa.UUID(), nullable=True)
op.drop_table("assignments")
op.drop_table("document")
# ### end Alembic commands ###
27 changes: 21 additions & 6 deletions backend/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from sqlalchemy.dialects.postgresql import insert
from typing import List
import logging
from uuid import uuid4

import sentry_sdk
from app.core.db import engine
from app.core.config import settings
from app.models import Assignments, Document
from app.models import Assignments, Document, DocumentPublic

if settings.ENVIRONMENT == "production":
sentry_sdk.init(
Expand Down Expand Up @@ -58,26 +59,40 @@ async def db_is_alive(session: Session = Depends(get_session)):
)


@app.post("/create_document")
@app.post(
"/create_document",
response_model=DocumentPublic,
status_code=status.HTTP_201_CREATED,
)
async def create_document(session: Session = Depends(get_session)):
doc = Document()
# To be created in the database
document_id = str(uuid4().hex).replace("-", "")
print(document_id)
doc = Document.model_validate({"document_id": document_id})
session.add(doc)
session.commit()
session.refresh(doc)
document_id = doc.document_id

if not doc.document_id:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Document creation failed",
)

document_id: str = doc.document_id.replace("-", "")
# Also create the partition in one go.
session.execute(
text(
f"""
CREATE TABLE assignments_{document_id} PARTITION OF assignments
VALUES IN ('{document_id}')
FOR VALUES IN ('{document_id}')
"""
)
)
return doc


@app.post("/update_assignments")
@app.patch("/update_assignments")
async def update_assignments(
assignments: List[Assignments], session: Session = Depends(get_session)
):
Expand Down
21 changes: 13 additions & 8 deletions backend/app/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
from sqlmodel import Field, SQLModel, UUID, TIMESTAMP, text, Column


Expand Down Expand Up @@ -29,22 +30,26 @@ class TimeStampMixin(SQLModel):
)


class GerryDBTableBase(TimeStampMixin, SQLModel):
id: int = Field(default=None, primary_key=True)
class GerryDBTable(TimeStampMixin, SQLModel, table=True):
uuid: str = Field(sa_column=Column(UUIDType, unique=True, primary_key=True))
name: str = Field(nullable=False, unique=True)


class GerryDBTable(GerryDBTableBase, table=True):
uuid: str = Field(sa_column=Column(UUIDType, unique=True))
name: str = Field(nullable=False, unique=True)
class Document(TimeStampMixin, SQLModel, table=True):
document_id: str | None = Field(
sa_column=Column(UUIDType, unique=True, primary_key=True)
)


class Document(TimeStampMixin, SQLModel):
document_id: str | None = Field(sa_column=Column(UUIDType, unique=True))
class DocumentPublic(BaseModel):
document_id: str
created_at: datetime
updated_at: datetime


class Assignments(SQLModel, table=True):
# this is the empty parent table; not a partition itself
document_id: str = Field(foreign_key="document.document_id", primary_key=True)
geo_id: str = Field(primary_key=True)
zone: int
__table_args__ = {"postgres_partition_by": "document_id"}
__table_args__ = {"postgresql_partition_by": "document_id"}
29 changes: 19 additions & 10 deletions backend/app/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sqlalchemy import text
from sqlalchemy.exc import OperationalError, ProgrammingError
import subprocess
import uuid


client = TestClient(app)
Expand Down Expand Up @@ -65,18 +66,17 @@ def engine_fixture(request):
except (OperationalError, ProgrammingError):
pass

if ENVIRONMENT != "test":
subprocess.run(["alembic", "upgrade", "head"], check=True, env=my_env)
subprocess.run(["alembic", "upgrade", "head"], check=True, env=my_env)

def teardown():
close_connections_query = f"""
SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{POSTGRES_TEST_DB}'
AND pid <> pg_backend_pid();
"""
conn.execute(text(close_connections_query))
conn.execute(text(f"DROP DATABASE {POSTGRES_TEST_DB}"))
# close_connections_query = f"""
# SELECT pg_terminate_backend(pg_stat_activity.pid)
# FROM pg_stat_activity
# WHERE pg_stat_activity.datname = '{POSTGRES_TEST_DB}'
# AND pid <> pg_backend_pid();
# """
# conn.execute(text(close_connections_query))
# conn.execute(text(f"DROP DATABASE {POSTGRES_TEST_DB}"))
conn.close()

request.addfinalizer(teardown)
Expand Down Expand Up @@ -109,3 +109,12 @@ def test_db_is_alive(client):
response = client.get("/db_is_alive")
assert response.status_code == 200
assert response.json() == {"message": "DB is alive"}


def test_new_document(client):
print(TEST_SQLALCHEMY_DATABASE_URI)
response = client.post("/create_document")
assert response.status_code == 201
document_id = response.json().get("document_id", None)
assert document_id is not None
assert isinstance(uuid.UUID(document_id), uuid.UUID)

0 comments on commit c4d4b31

Please sign in to comment.