Skip to content

Commit

Permalink
Merge pull request #169 from chaen/contextvar
Browse files Browse the repository at this point in the history
Use a ContextVar to manage the sql connection instance as an instance…
  • Loading branch information
chrisburr authored Nov 1, 2023
2 parents 117e98e + b726d2a commit 904f45a
Showing 1 changed file with 33 additions and 10 deletions.
43 changes: 33 additions & 10 deletions src/diracx/db/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import logging
import os
from abc import ABCMeta
from contextvars import ContextVar
from datetime import datetime, timedelta, timezone
from functools import partial
from typing import TYPE_CHECKING, AsyncIterator, Self
from typing import TYPE_CHECKING, AsyncIterator, Self, cast

from pydantic import parse_obj_as
from sqlalchemy import Column as RawColumn
Expand Down Expand Up @@ -73,7 +74,12 @@ class BaseSQLDB(metaclass=ABCMeta):
metadata: MetaData

def __init__(self, db_url: str) -> None:
self._conn = None
# We use a ContextVar to make sure that self._conn
# is specific to each context, and avoid parallel
# route executions to overlap
self._conn: ContextVar[AsyncConnection | None] = ContextVar(
"_conn", default=None
)
self._db_url = db_url
self._engine: AsyncEngine | None = None

Expand Down Expand Up @@ -121,16 +127,22 @@ def transaction(cls) -> Self:
def engine(self) -> AsyncEngine:
"""The engine to use for database operations.
It is normally not necessary to use the engine directly,
unless you are doing something special, like writing a
test fixture that gives you a db.
Requires that the engine_context has been entered.
"""
assert self._engine is not None, "engine_context must be entered"
return self._engine

@contextlib.asynccontextmanager
async def engine_context(self) -> AsyncIterator[None]:
"""Context manage to manage the engine lifecycle.
Tables are automatically created upon entering
This is called once at the application startup
(see ``lifetime_functions``)
"""
assert self._engine is None, "engine_context cannot be nested"

Expand All @@ -144,19 +156,30 @@ async def engine_context(self) -> AsyncIterator[None]:

@property
def conn(self) -> AsyncConnection:
if self._conn is None:
if self._conn.get() is None:
raise RuntimeError(f"{self.__class__} was used before entering")
return self._conn
return cast(AsyncConnection, self._conn.get())

async def __aenter__(self):
self._conn = await self.engine.connect().__aenter__()
"""
Create a connection.
This is called by the Dependency mechanism (see ``db_transaction``),
It will create a new connection/transaction for each route call.
"""

self._conn.set(await self.engine.connect().__aenter__())
return self

async def __aexit__(self, exc_type, exc, tb):
"""
This is called when exciting a route.
If there was no exception, the changes in the DB are committed.
Otherwise, they are rollbacked.
"""
if exc_type is None:
await self._conn.commit()
await self._conn.__aexit__(exc_type, exc, tb)
self._conn = None
await self._conn.get().commit()
await self._conn.get().__aexit__(exc_type, exc, tb)
self._conn.set(None)


def apply_search_filters(table, stmt, search):
Expand Down

0 comments on commit 904f45a

Please sign in to comment.