Skip to content

Commit

Permalink
Merge pull request #47 from alan-turing-institute/lifecyle
Browse files Browse the repository at this point in the history
Switch to fastapi lifespan for startup/shutdown handling
  • Loading branch information
Iain-S authored Jul 29, 2024
2 parents 332feae + 4f2591e commit 4836a23
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 32 deletions.
51 changes: 26 additions & 25 deletions rctab/main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""The entrypoint of the FastAPI application."""

import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Callable, Dict, Final
from typing import Any, AsyncIterator, Callable, Dict, Final

import fastapimsal
import secure
Expand Down Expand Up @@ -34,13 +35,37 @@

templates = Jinja2Templates(directory=Path("rctab/templates"))


@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
"""Handle setup and teardown."""
await database.connect()
settings = get_settings()
logging.basicConfig(level=settings.log_level)
set_log_handler()
if not settings.ignore_whitelist:
logger = logging.getLogger(__name__)
logger.warning(
"Starting server with subscription whitelist: %s", settings.whitelist
)

yield

logger = logging.getLogger(__name__)
logger.warning("Shutting down server...")

logger.info("Disconnecting from database")
await database.disconnect()


app = FastAPI(
title="RCTab API",
description="API for RCTab",
version="0.1.0",
docs_url=None,
redoc_url=None,
openapi_url=None,
lifespan=lifespan,
)

server = secure.Server().set("Secure")
Expand Down Expand Up @@ -76,30 +101,6 @@ async def set_secure_headers(request: Any, call_next: Callable[[Any], Any]) -> A
)


@app.on_event("startup")
async def startup() -> None:
"""Start the server up."""
await database.connect()
settings = get_settings()
logging.basicConfig(level=settings.log_level)
set_log_handler()
if not settings.ignore_whitelist:
logger = logging.getLogger(__name__)
logger.warning(
"Starting server with subscription whitelist: %s", settings.whitelist
)


@app.on_event("shutdown")
async def shutdown() -> None:
"""Shut the server down."""
logger = logging.getLogger(__name__)
logger.warning("Shutting down server...")

logger.info("Disconnecting from database")
await database.disconnect()


@app.exception_handler(UniqueViolationError)
async def unicorn_exception_handler(
_: Request, exc: UniqueViolationError
Expand Down
10 changes: 3 additions & 7 deletions tests/test_routes/test_routes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pylint: disable=redefined-outer-name,
import random
from datetime import date, timedelta
from typing import Any, AsyncGenerator, Callable, Coroutine, Dict, Optional, Tuple
from typing import Any, AsyncGenerator, Callable, Coroutine, Optional, Tuple
from unittest.mock import AsyncMock
from uuid import UUID

Expand Down Expand Up @@ -145,15 +145,11 @@ async def create_subscription(

def make_async_execute(
connection: Engine,
) -> Callable[
[VarArg(Tuple[Any, ...]), KwArg(Dict[str, Any])], Coroutine[Any, Any, ResultProxy]
]:
) -> Callable[[VarArg(Any), KwArg(Any)], Coroutine[Any, Any, ResultProxy]]:
"""We need an async function to patch database.execute() with
but connection.execute() is synchronous so make a wrapper for it."""

async def async_execute(
*args: Tuple[Any, ...], **kwargs: Dict[str, Any]
) -> ResultProxy:
async def async_execute(*args: Any, **kwargs: Any) -> ResultProxy:
"""An async wrapper around connection.execute()."""
return connection.execute(*args, **kwargs) # type: ignore

Expand Down

0 comments on commit 4836a23

Please sign in to comment.