Skip to content

Commit

Permalink
database = use private type variables (with leading underscore)
Browse files Browse the repository at this point in the history
  • Loading branch information
MatteoCampinoti94 committed Oct 25, 2024
1 parent 42db98e commit e26d597
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 56 deletions.
24 changes: 13 additions & 11 deletions acacore/database/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,35 @@
from .column import ColumnSpec
from .column import SQLValue

M = TypeVar("M", bound=BaseModel)
_M = TypeVar("_M", bound=BaseModel)


class Cursor(Generic[M]):
def __init__(self, cursor: SQLiteCursor, model: Type[M], columns: list[ColumnSpec]) -> None:
class Cursor(Generic[_M]):
def __init__(self, cursor: SQLiteCursor, model: Type[_M], columns: list[ColumnSpec]) -> None:
self.cursor: SQLiteCursor[Row] = cursor
self.cursor.row_factory = Row
self.model: Type[M] = model
self.model: Type[_M] = model
self.columns: list[ColumnSpec] = columns
self._cols: dict[str, Callable[[SQLValue], Any]] = {c.name: c.from_sql for c in columns}
self._row: Callable[[Row], M] = lambda r: self.model.model_validate({k: f(r[k]) for k, f in self._cols.items()})
self._row: Callable[[Row], _M] = lambda r: self.model.model_validate(
{k: f(r[k]) for k, f in self._cols.items()}
)

@property
def rows(self) -> Generator[M, None, None]:
def rows(self) -> Generator[_M, None, None]:
return (self._row(row) for row in self.cursor)

def __iter__(self) -> Generator[M, None, None]:
def __iter__(self) -> Generator[_M, None, None]:
yield from self.rows

def __next__(self) -> M:
def __next__(self) -> _M:
return next(self.rows)

def fetchone(self) -> M | None:
def fetchone(self) -> _M | None:
return next(self.rows, None)

def fetchmany(self, size: int) -> list[M]:
def fetchmany(self, size: int) -> list[_M]:
return list(islice(self.rows, size))

def fetchall(self) -> list[M]:
def fetchall(self) -> list[_M]:
return list(self.rows)
8 changes: 4 additions & 4 deletions acacore/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .table_view import View

_M = TypeVar("_M", bound=BaseModel)
P: TypeAlias = Sequence[SQLValue] | Mapping[str, SQLValue]
_P: TypeAlias = Sequence[SQLValue] | Mapping[str, SQLValue]


class Database:
Expand Down Expand Up @@ -56,12 +56,12 @@ def __exit__(self, _exc_type: Type[BaseException], _exc_val: BaseException, _exc
def execute(self, sql: str, /) -> SQLiteCursor: ...

@overload
def execute(self, sql: str, parameters: P, /) -> SQLiteCursor: ...
def execute(self, sql: str, parameters: _P, /) -> SQLiteCursor: ...

def execute(self, sql: str, parameters: P | None = None, /) -> SQLiteCursor:
def execute(self, sql: str, parameters: _P | None = None, /) -> SQLiteCursor:
return self.connection.execute(sql, parameters or [])

def executemany(self, sql: str, parameters: Iterable[P], /) -> SQLiteCursor:
def executemany(self, sql: str, parameters: Iterable[_P], /) -> SQLiteCursor:
return self.connection.executemany(sql, parameters)

def commit(self):
Expand Down
34 changes: 17 additions & 17 deletions acacore/database/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from .column import SQLValue
from .cursor import Cursor

M = TypeVar("M", bound=BaseModel)
_Where: TypeAlias = str | dict[str, SQLValue | list[SQLValue]]
_M = TypeVar("_M", bound=BaseModel)
_W: TypeAlias = str | dict[str, SQLValue | list[SQLValue]]


def _where_dict_to_sql(where: dict[str, SQLValue | list[SQLValue]]) -> tuple[str, list[SQLValue]]:
Expand All @@ -43,7 +43,7 @@ def _where_dict_to_sql(where: dict[str, SQLValue | list[SQLValue]]) -> tuple[str


def _where_to_sql(
where: _Where | BaseModel,
where: _W | BaseModel,
params: list[SQLValue] | None,
primary_keys: list[ColumnSpec],
) -> tuple[str, list[SQLValue]]:
Expand All @@ -63,18 +63,18 @@ def _where_to_sql(
return where.strip(), params if where else []


class Table(Generic[M]):
class Table(Generic[_M]):
def __init__(
self,
database: Connection,
model: Type[M],
model: Type[_M],
name: str,
primary_keys: list[str] | None = None,
indices: dict[str, list[str]] | None = None,
ignore: list[str] | None = None,
) -> None:
self.database: Connection = database
self.model: Type[M] = model
self.model: Type[_M] = model
self.name: str = name
self.columns: dict[str, ColumnSpec] = {c.name: c for c in ColumnSpec.from_model(self.model, ignore)}

Expand All @@ -97,25 +97,25 @@ def __init__(
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.name!r}, {self.model.__name__})"

def __iter__(self) -> Generator[M, None, None]:
def __iter__(self) -> Generator[_M, None, None]:
yield from self.select()

def __len__(self) -> int:
return self.database.execute(f"select count(*) from {self.name}").fetchone()[0]

def __getitem__(self, where: _Where | M) -> M | None:
def __getitem__(self, where: _W | _M) -> _M | None:
return self.select(where, limit=1).fetchone()

def __setitem__(self, where: _Where | M | slice, row: M) -> None:
def __setitem__(self, where: _W | _M | slice, row: _M) -> None:
if isinstance(where, slice):
self.insert(row)
else:
self.update(row, where)

def __delitem__(self, where: _Where | M) -> None:
def __delitem__(self, where: _W | _M) -> None:
self.delete(where)

def __contains__(self, where: M) -> bool:
def __contains__(self, where: _M) -> bool:
return self.select(where, limit=1).cursor.fetchone() is not None

def create_sql(self, *, exist_ok: bool = False) -> str:
Expand Down Expand Up @@ -148,12 +148,12 @@ def create(self, *, exist_ok: bool = False) -> Self:

def select(
self,
where: _Where | M | None = None,
where: _W | _M | None = None,
params: list[SQLValue] | None = None,
order_by: list[tuple[str, str]] | None = None,
limit: int | None = None,
offset: int | None = None,
) -> Cursor[M]:
) -> Cursor[_M]:
where, params = _where_to_sql(where, params, self.primary_keys)

sql: list[str] = [f"select * from {self.name}"]
Expand All @@ -171,7 +171,7 @@ def select(

return Cursor(self.database.execute(" ".join(sql), params), self.model, list(self.columns.values()))

def insert(self, *rows: M, on_exists: Literal["ignore", "replace", "error"] = "error") -> int:
def insert(self, *rows: _M, on_exists: Literal["ignore", "replace", "error"] = "error") -> int:
cols: list[ColumnSpec] = list(self.columns.values())
sql: list[str] = ["insert"]

Expand All @@ -187,10 +187,10 @@ def insert(self, *rows: M, on_exists: Literal["ignore", "replace", "error"] = "e
(tuple(c.to_sql(getattr(row, c.name)) for c in cols) for row in rows),
).rowcount

def upsert(self, *rows: M) -> int:
def upsert(self, *rows: _M) -> int:
return self.insert(*rows, on_exists="replace")

def update(self, row: M, where: _Where | M = None, params: list[SQLValue] | None = None) -> int:
def update(self, row: _M, where: _W | _M = None, params: list[SQLValue] | None = None) -> int:
where, params = _where_to_sql(where or row, params, self.primary_keys)

if not where:
Expand All @@ -203,7 +203,7 @@ def update(self, row: M, where: _Where | M = None, params: list[SQLValue] | None
[*[c.to_sql(getattr(row, c.name)) for c in cols], *params],
).rowcount

def delete(self, where: _Where | M) -> int:
def delete(self, where: _W | _M) -> int:
where, params = _where_to_sql(where, [], self.primary_keys)

if not where:
Expand Down
16 changes: 7 additions & 9 deletions acacore/database/table_keyvalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,22 @@
from typing import Generic
from typing import overload
from typing import Type
from typing import TypeVar

from pydantic import BaseModel

from .table import _M
from .table import Table

M = TypeVar("M", bound=BaseModel)


class KeysTableModel(BaseModel):
key: str
value: object | None


class KeysTable(Generic[M]):
def __init__(self, database: Connection, model: Type[M], name: str) -> None:
class KeysTable(Generic[_M]):
def __init__(self, database: Connection, model: Type[_M], name: str) -> None:
self.table: Table[KeysTableModel] = Table(database, KeysTableModel, name)
self.model: Type[M] = model
self.model: Type[_M] = model

@property
def name(self) -> str:
Expand Down Expand Up @@ -53,16 +51,16 @@ def create(self, *, exist_ok: bool = False):
self.table.create(exist_ok=exist_ok)
return self

def set(self, obj: M):
def set(self, obj: _M):
self.table.insert(*(KeysTableModel(key=k, value=o) for k, o in obj.model_dump().items()), on_exists="replace")

@overload
def get(self) -> M | None: ...
def get(self) -> _M | None: ...

@overload
def get(self, key: str) -> Any | None: ... # noqa: ANN401

def get(self, key: str | None = None) -> M | Any | None:
def get(self, key: str | None = None) -> _M | Any | None:
if key is not None and key not in self.model.model_fields:
raise AttributeError(f"{self.model.__name__!r} object has no attribute {key!r}")

Expand Down
26 changes: 11 additions & 15 deletions acacore/database/table_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,46 +3,42 @@
from typing import Generic
from typing import Self
from typing import Type
from typing import TypeVar

from pydantic import BaseModel

from .column import SQLValue
from .cursor import Cursor
from .table import _Where
from .table import _M
from .table import _W
from .table import Table

M = TypeVar("M", bound=BaseModel)


class View(Generic[M]):
class View(Generic[_M]):
def __init__(
self,
database: Connection,
model: Type[M],
model: Type[_M],
name: str,
select: str,
ignore: list[str] | None = None,
) -> None:
self.database: Connection = database
self.model: Type[M] = model
self.model: Type[_M] = model
self.name: str = name
self.select_stmt: str = select
self._table: Table[M] = Table(self.database, self.model, self.name, ignore=ignore)
self._table: Table[_M] = Table(self.database, self.model, self.name, ignore=ignore)

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.name!r}, {self.model.__name__})"

def __iter__(self) -> Generator[M, None, None]:
def __iter__(self) -> Generator[_M, None, None]:
yield from self.select()

def __len__(self) -> int:
return len(self._table)

def __getitem__(self, where: _Where | M) -> M | None:
def __getitem__(self, where: _W | _M) -> _M | None:
return self._table.select(where, limit=1).fetchone()

def __contains__(self, where: M) -> bool:
def __contains__(self, where: _M) -> bool:
return self._table.select(where, limit=1).cursor.fetchone() is not None

def create_sql(self, *, exist_ok: bool = False) -> str:
Expand All @@ -54,10 +50,10 @@ def create(self, *, exist_ok: bool = False) -> Self:

def select(
self,
where: _Where | None = None,
where: _W | None = None,
params: list[SQLValue] | None = None,
order_by: list[tuple[str, str]] | None = None,
limit: int | None = None,
offset: int | None = None,
) -> Cursor[M]:
) -> Cursor[_M]:
return self._table.select(where, params, order_by, limit, offset)

0 comments on commit e26d597

Please sign in to comment.