From e5855130f3209f15c75f966149db4cfcab3f5e13 Mon Sep 17 00:00:00 2001 From: Matteo Campinoti Date: Thu, 17 Oct 2024 19:28:41 +0200 Subject: [PATCH] database - rewrite database from scratch to be simpler and easier to use --- acacore/database/__init__.py | 7 +- acacore/database/base.py | 994 ----------------------------- acacore/database/column.py | 466 ++------------ acacore/database/cursor.py | 46 ++ acacore/database/database.py | 96 +++ acacore/database/files_db.py | 308 ++------- acacore/database/table.py | 209 ++++++ acacore/database/table_keyvalue.py | 75 +++ acacore/database/table_view.py | 59 ++ acacore/database/upgrade.py | 18 +- 10 files changed, 620 insertions(+), 1658 deletions(-) delete mode 100644 acacore/database/base.py create mode 100644 acacore/database/cursor.py create mode 100644 acacore/database/database.py create mode 100644 acacore/database/table.py create mode 100644 acacore/database/table_keyvalue.py create mode 100644 acacore/database/table_view.py diff --git a/acacore/database/__init__.py b/acacore/database/__init__.py index 79429356..cc94c01a 100644 --- a/acacore/database/__init__.py +++ b/acacore/database/__init__.py @@ -1,6 +1 @@ -from . import upgrade -from .base import Column -from .base import model_to_columns -from .base import SelectColumn -from .base import Table -from .files_db import FileDB +from .files_db import FilesDB diff --git a/acacore/database/base.py b/acacore/database/base.py deleted file mode 100644 index d7749df2..00000000 --- a/acacore/database/base.py +++ /dev/null @@ -1,994 +0,0 @@ -from json import dumps -from json import loads -from os import PathLike -from pathlib import Path -from sqlite3 import Connection -from sqlite3 import Cursor as SQLiteCursor -from sqlite3 import DatabaseError -from sqlite3 import OperationalError -from types import TracebackType -from typing import Any -from typing import Generator -from typing import Generic -from typing import Iterator -from typing import Optional -from typing import overload -from typing import Sequence -from typing import Type -from typing import TypeVar - -from pydantic.main import BaseModel - -from acacore.utils.functions import or_none - -from .column import Column -from .column import dump_object -from .column import Index -from .column import model_to_columns -from .column import model_to_indices -from .column import SelectColumn -from .column import SQLValue - -M = TypeVar("M", bound=BaseModel) - - -class Cursor: - def __init__( - self, - cursor: SQLiteCursor, - columns: list[Column | SelectColumn], - table: Optional["Table"] = None, - ) -> None: - """ - A wrapper class for an SQLite cursor that returns its results as dicts (or objects). - - :param cursor: An SQLite cursor from a select transaction. - :param columns: A list of columns to use to convert the tuples returned by the cursor. - :param table: Optionally, the Table from which on which the select transaction was executed, defaults to - None. - """ - self.cursor: SQLiteCursor = cursor - self.columns: list[Column | SelectColumn] = columns - self.table: Table | None = table - - def __iter__(self) -> Generator[dict[str, Any], None, None]: - return self.fetchall() - - def __next__(self) -> dict[str, Any] | None: - return self.fetchone() - - def fetchalltuples(self) -> Generator[tuple, None, None]: - """ - Fetch all the results from the cursor as tuples and convert the data using the given columns. - - :return: A generator for the tuples in the cursor. - """ - return (tuple(c.from_entry(v) for c, v in zip(self.columns, vs)) for vs in self.cursor.fetchall()) - - def fetchonetuple(self) -> tuple | None: - """ - Fetch one result from the cursor as tuples and convert the data using the given columns. - - :return: A single tuple from the cursor. - """ - vs: tuple = self.cursor.fetchone() - - return tuple(c.from_entry(v) for c, v in zip(self.columns, vs)) if vs else None - - @overload - def fetchall(self) -> Generator[dict[str, Any], None, None]: ... - - @overload - def fetchall(self, model: Type[M]) -> Generator[M, None, None]: ... - - def fetchall(self, model: Type[M] | None = None) -> Generator[dict[str, Any] | M, None, None]: - """ - Fetch all results from the cursor and return them as dicts, with the columns' names/aliases used as keys. - - :param model: Optionally, a pydantic.BaseModel class to use instead of a dict, defaults to None. - :return: A generator for converted dicts (or models). - """ - select_columns: list[SelectColumn] = [SelectColumn.from_column(c) for c in self.columns] - - if model: - return ( - model.model_validate( - {c.alias or c.name: c.from_entry(v) for c, v in zip(select_columns, vs)}, - ) - for vs in self.cursor - ) - - return ({c.alias or c.name: c.from_entry(v) for c, v in zip(select_columns, vs)} for vs in self.cursor) - - @overload - def fetchmany(self, size: int) -> Generator[dict[str, Any], None, None]: ... - - @overload - def fetchmany(self, size: int, model: Type[M]) -> Generator[M, None, None]: ... - - def fetchmany(self, size: int, model: Type[M] | None = None) -> Generator[dict[str, Any] | M, None, None]: - """ - Fetch `size` results from the cursor and return them as dicts, with the columns' names/aliases used as keys. - - :param size: The amount of results to fetch. - :param model: Optionally, a pydantic.BaseModel class to use instead of a dict, defaults to None. - :return: A generator for converted dicts (or models). - """ - select_columns: list[SelectColumn] = [SelectColumn.from_column(c) for c in self.columns] - - if model: - return ( - model.model_validate( - {c.alias or c.name: c.from_entry(v) for c, v in zip(select_columns, vs)}, - ) - for vs in self.cursor.fetchmany(size) - ) - - return ( - {c.alias or c.name: c.from_entry(v) for c, v in zip(select_columns, vs)} - for vs in self.cursor.fetchmany(size) - ) - - @overload - def fetchone(self) -> dict[str, Any] | None: ... - - @overload - def fetchone(self, model: Type[M]) -> M | None: ... - - def fetchone(self, model: Type[M] | None = None) -> dict[str, Any] | M | None: - """ - Fetch one result from the cursor and return it as a dict, with the columns' names/aliases used as keys. - - :param model: Optionally, a pydantic.BaseModel class to use instead of a dict, defaults to None. - :return: A single dict (or model) if the cursor is not exhausted, otherwise None. - """ - select_columns: list[SelectColumn] = [SelectColumn.from_column(c) for c in self.columns] - vs: tuple = self.cursor.fetchone() - - if vs is None: - return None - - entry: dict[str, Any] = {c.name: c.from_entry(v) for c, v in zip(select_columns, vs)} - - return model.model_validate(entry) if model else entry - - -class ModelCursor(Cursor, Generic[M]): - def __init__( - self, - cursor: SQLiteCursor, - model: Type[M], - table: Optional["Table"] = None, - ) -> None: - """ - A wrapper class for an SQLite cursor that returns its results as model objects. - - :param cursor: An SQLite cursor from a select transaction. - :param model: A model representing the objects in the cursor. - :param table: Optionally, the Table from which on which the select transaction was executed, defaults to - None. - """ - super().__init__(cursor, model_to_columns(model), table) - self.model: Type[M] = model - - def __iter__(self) -> Generator[M, None, None]: - return self.fetchall() - - def __next__(self) -> M | None: - return self.fetchone() - - def fetchall(self, model: Type[M] | None = None) -> Generator[M, None, None]: - """ - Fetch all results from the cursor and return them as model objects. - - :param model: Optionally, a different pydantic.BaseModel class to use instead of the one in the ModelCursor, - defaults to None. - :return: A generator for converted objects. - """ - return super().fetchall(model or self.model) - - def fetchmany(self, size: int, model: Type[M] | None = None) -> Generator[dict[str, Any] | M, None, None]: - """ - Fetch `size` results from the cursor and return them as model objects. - - :param size: The amount of results to fetch. - :param model: Optionally, a different pydantic.BaseModel class to use instead of the one in the ModelCursor, - defaults to None. - :return: A generator for converted objects. - """ - return super().fetchmany(size, model or self.model) - - def fetchone(self, model: Type[M] | None = None) -> M | None: - """ - Fetch one result from the cursor and return it as model object. - - :param model: Optionally, a different pydantic.BaseModel class to use instead of the one in the ModelCursor, - defaults to None. - :return: A single object if the cursor is not exhausted, otherwise None. - """ - return super().fetchone(model or self.model) - - -# noinspection SqlNoDataSourceInspection -class Table: - def __init__( - self, - connection: "FileDBBase", - name: str, - columns: list[Column], - indices: list[Index] | None = None, - ) -> None: - """ - A class that holds information about a table. - - :param connection: A FileDBBase object connected to the database the table belongs to. - :param name: The name of the table. - :param columns: The columns of the table. - :param indices: The indices to create in the table, defaults to None. - """ - self.connection: FileDBBase = connection - self.name: str = name - self.columns: list[Column] = columns - self.indices: list[Index] = indices or [] - - def __repr__(self) -> str: - return f'{self.__class__.__name__}("{self.name}")' - - def __len__(self) -> int: - return self.connection.execute(f"select count(*) from {self.name}").fetchone()[0] - - def __iter__(self) -> Generator[dict[str, Any], None, None]: - return self.select().fetchall() - - @property - def keys(self) -> list[Column]: - """ - The list of PRIMARY KEY columns in the table. - - :return: A list of Column objects whose `primary_key` field is set to True. - """ - return [c for c in self.columns if c.primary_key] - - def create_statement(self, exist_ok: bool = True) -> str: - """ - Generate the expression that creates the table. - - :param exist_ok: True if existing tables with the same name should be ignored, defaults to True. - :return: A CREATE TABLE expression. - """ - elements: list[str] = ["create table"] - - if exist_ok: - elements.append("if not exists") - - elements.append(self.name) - - if self.columns: - columns_elements: list[str] = [] - for column in self.columns: - columns_elements.append(column.create_statement()) - if self.keys: - columns_elements.append(f"primary key ({','.join(c.name for c in self.keys)})") - elements.append(f"({','.join(columns_elements)})") - - return " ".join(elements) - - def create(self, exist_ok: bool = True): - self.connection.execute(self.create_statement(exist_ok)) - for index in self.indices: - self.connection.execute(index.create_statement(self.name, exist_ok)) - - def select( - self, - columns: list[Column | SelectColumn] | None = None, - where: str | None = None, - order_by: list[tuple[str | Column, str]] | None = None, - offset: int | None = None, - limit: int | None = None, - parameters: list[Any | None] | None = None, - ) -> Cursor: - """ - Select entries from the table. - - :param columns: A list of columns to be selected, defaults to None. - :param where: A WHERE expression, defaults to None. - :param order_by: A list tuples containing one column (either as Column or string) and a sorting direction - ("ASC", or "DESC"), defaults to None. - :param offset: The number of rows skip, defaults to None. - :param limit: The number of rows to limit the results to, defaults to None. - :param parameters: Values to substitute in the SELECT expression, both in the `where` and SelectColumn - statements, defaults to None. - :return: A Cursor object wrapping the SQLite cursor returned by the SELECT transaction. - """ - columns = columns or self.columns - parameters = parameters or [] - - assert columns, "Columns cannot be empty" - - select_columns: list[SelectColumn] = [SelectColumn.from_column(c) for c in columns] - - select_names = [f"{c.name} as {c.alias}" if c.alias else c.name for c in select_columns] - - statement: str = f"SELECT {','.join(select_names)} FROM {self.name}" - - if where: - statement += f" WHERE {where}" - - if order_by: - order_statements = [f"{c.name if isinstance(c, Column) else c} {s}" for c, s in order_by] - statement += f" ORDER BY {','.join(order_statements)}" - - if limit is not None: - statement += f" LIMIT {limit}" - - if offset is not None and offset > 0: - statement += f" OFFSET {offset}" - - return Cursor(self.connection.execute(statement, parameters), columns, self) - - def insert(self, entry: dict[str, Any], exist_ok: bool = False, replace: bool = False): - """ - Insert a row in the table. Existing rows with matching keys can be ignored or replaced. - - :param entry: The row to be inserted as a dict with keys matching the names of the columns. - The values need not be converted beforehand. - :param exist_ok: True if existing rows with the same keys should be ignored, False otherwise, defaults to - False. - :param replace: True if existing rows with the same keys should be replaced, False otherwise, defaults to - False. - """ - values: list[SQLValue] = [ - c.to_entry(entry[c.name]) if c.name in entry else c.default_value() for c in self.columns - ] - - elements: list[str] = ["INSERT"] - - if replace: - elements.append("OR REPLACE") - elif exist_ok: - elements.append("OR IGNORE") - - elements.append(f"INTO {self.name}") - - elements.append(f"({','.join(c.name for c in self.columns)})") - elements.append(f"VALUES ({','.join('?' * len(values))})") - - self.connection.execute(" ".join(elements), values) - - def insert_many( - self, - entries: Sequence[dict[str, Any]] | Iterator[dict[str, Any]], - exist_ok: bool = False, - replace: bool = False, - ): - """ - Insert multiple rows in the table. Existing rows with matching keys can be ignored or replaced. - - :param entries: The rows to be inserted as a list (or iterator) of dicts with keys matching the names of the - columns. The values need not be converted beforehand. - :param exist_ok: True if existing rows with the same keys should be ignored, False otherwise, defaults to - False. - :param replace: True if existing rows with the same keys should be replaced, False otherwise, defaults to - False. - """ - for entry in entries: - self.insert(entry, exist_ok, replace) - - def update(self, entry: dict[str, Any], where: dict[str, Any] | None = None): - """ - Update a row. - - If ``where`` is provided, then the WHERE clause is computed with those, otherwise the table's keys and values in - ``entry`` are used. - - :param entry: The values of the row to be updated as a dict with keys matching the names of the columns. - The values need not be converted beforehand. - :param where: Optionally, the columns and values to use in the WHERE statement. The values need not be - converted beforehand, defaults to None. - :raises OperationalError: If ``where`` is not provided and the table has no keys. - :raises KeyError: If ``where`` is not provided and one of the table's keys is missing from ``entry``. - """ - values: list[tuple[str, SQLValue]] = [ - (c.name, c.to_entry(entry[c.name])) for c in self.columns if c.name in entry - ] - elements: list[str] = [f"UPDATE {self.name} SET", ", ".join(f"{c} = ?" for c, _ in values)] - if where: - where_entry: dict[str, SQLValue] = { - c.name: c.to_entry(where[c.name]) for c in self.columns if c.name in where - } - elif self.keys: - where_entry: dict[str, SQLValue] = {c.name: c.to_entry(entry[c.name]) for c in self.keys} - else: - raise OperationalError("Table has no keys.") - elements.append("WHERE") - elements.append(" AND ".join(f"{c} = ?" for c in where_entry)) - values.extend(where_entry.items()) - - self.connection.execute(" ".join(elements), [v for _, v in values]) - - -class ModelTable(Table, Generic[M]): - def __init__( - self, - connection: "FileDBBase", - name: str, - model: Type[M], - indices: list[Index] | None = None, - ) -> None: - """ - A class that holds information about a table using a model. - - :param connection: A FileDBBase object connected to the database the table belongs to. - :param name: The name of the table. - :param model: The model representing the table. - :param indices: The indices to create in the table, defaults to None. - """ - super().__init__(connection, name, model_to_columns(model), indices) - self.model: Type[M] = model - - def __repr__(self) -> str: - return f'{self.__class__.__name__}[{self.model.__name__}]("{self.name}")' - - def __iter__(self) -> Generator[M, None, None]: - return self.select().fetchall() - - def select( - self, - model: Type[M] | None = None, - where: str | None = None, - order_by: list[tuple[str | Column, str]] | None = None, - offset: int | None = None, - limit: int | None = None, - parameters: list[Any] | None = None, - ) -> ModelCursor[M]: - """ - Select entries from the table. - - :param model: A model with the fields to be selected, defaults to None. - :param where: A WHERE expression, defaults to None. - :param order_by: A list tuples containing one column (either as Column or string) and a sorting direction - ("ASC", or "DESC"), defaults to None. - :param offset: The number of rows skip, defaults to None. - :param limit: The number of rows to limit the results to, defaults to None. - :param parameters: Values to substitute in the SELECT expression, both in the `where` and SelectColumn - statements, defaults to None. - :return: A Cursor object wrapping the SQLite cursor returned by the SELECT transaction. - """ - return ModelCursor[M]( - super().select(model_to_columns(model or self.model), where, order_by, offset, limit, parameters).cursor, - model or self.model, - self, - ) - - def insert(self, entry: M, exist_ok: bool = False, replace: bool = False): - """ - Insert a row in the table. Existing rows with matching keys can be ignored or replaced. - - :param entry: The row to be inserted as a model object with attributes matching the names of the columns. - :param exist_ok: True if existing rows with the same keys should be ignored, False otherwise, defaults to - False. - :param replace: True if existing rows with the same keys should be replaced, False otherwise, defaults to - False. - """ - super().insert(entry.model_dump(), exist_ok, replace) - - def insert_many( - self, - entries: Sequence[M] | Iterator[M], - exist_ok: bool = False, - replace: bool = False, - ): - """ - Insert multiple rows in the table. Existing rows with matching keys can be ignored or replaced. - - :param entries: The rows to be inserted as a list (or iterator) of model objects with attributes matching - the names of the columns. - :param exist_ok: True if existing rows with the same keys should be ignored, False otherwise, defaults to - False. - :param replace: True if existing rows with the same keys should be replaced, False otherwise, defaults to - False. - """ - for entry in entries: - self.insert(entry, exist_ok, replace) - - def update(self, entry: M | dict[str, Any], where: dict[str, Any] | None = None): - """ - Update a row. - - If ``where`` is provided, then the WHERE clause is computed with those, otherwise the table's keys and values in - ``entry`` are used. - - :param entry: The row to be inserted as a model object with attributes matching the names of the columns. - Alternatively, a dict with keys matching the names of the columns. - :param where: Optionally, the columns and values to use in the WHERE statement. The values need not be - converted beforehand, defaults to None. - :raises OperationalError: If ``where`` is not provided and the table has no keys. - :raises KeyError: If ``where`` is not provided and one of the table's keys is missing from ``entry``. - """ - super().update(entry if isinstance(entry, dict) else entry.model_dump(), where) - - -# noinspection SqlResolve -class KeysTable: - def __init__(self, connection: "FileDBBase", name: str, keys: list[Column]) -> None: - """ - A class that holds information about a key-value pairs table. - - :param connection: A FileDBBase object connected to the database the table belongs to. - :param name: The name of the table. - :param keys: The keys of the table. - """ - self.keys: list[Column] = keys - self.connection: FileDBBase = connection - self.name: str = name - self.columns: list[Column] = [ - Column("KEY", "text", str, str, True, True), - Column("VALUE", "text", or_none(lambda o: dumps(dump_object(o))), or_none(loads)), - ] - - def __repr__(self) -> str: - return f'{self.__class__.__name__}("{self.name}")' - - def __len__(self) -> int: - return len(self.keys) - - def __iter__(self) -> Generator[tuple[str, Any], None, None]: - return ((k, v) for k, v in self.select().items()) - - def create_statement(self, exist_ok: bool = True) -> str: - """ - Generate the expression that creates the table. - - :param exist_ok: True if existing tables with the same name should be ignored, defaults to True. - :return: A CREATE TABLE expression. - """ - return Table(self.connection, self.name, self.columns).create_statement(exist_ok) - - def create(self, exist_ok: bool = True): - self.connection.execute(self.create_statement(exist_ok)) - - def select(self) -> dict[str, Any] | None: - data = dict(self.connection.execute(f"select KEY, VALUE from {self.name}").fetchall()) - return {c.name: c.from_entry(data[c.name]) for c in self.keys} if data else None - - def update(self, entry: dict[str, Any]): - """ - Update the table with new data. - - Existing key-value pairs are replaced if the new entry contains an existing key. - - :param entry: A dictionary with string keys. - """ - entry = {k.lower(): v for k, v in entry.items()} - entry = {c.name: c.to_entry(entry[c.name.lower()]) if c.name in entry else c.default_value() for c in self.keys} - - for key, value in entry.items(): - self.connection.execute(f"insert or replace into {self.name} (KEY, VALUE) values (?, ?)", [key, value]) - - -class ModelKeysTable(KeysTable, Generic[M]): - def __init__(self, connection: "FileDBBase", name: str, model: Type[M]) -> None: - """ - A class that holds information about a key-value pairs table using a BaseModel for validation and parsing. - - :param connection: A FileDBBase object connected to the database the table belongs to. - :param name: The name of the table. - :param model: The model of the table. - """ - self.model: Type[M] = model - super().__init__(connection, name, model_to_columns(model)) - - def __repr__(self) -> str: - return f'{self.__class__.__name__}[{self.model.__name__}]("{self.name}")' - - def select(self) -> M | None: - data = super().select() - return self.model.model_validate(data) if data else None - - def update(self, entry: M): - """ - Update the table with new data. - - Existing key-value pairs are replaced if the new entry contains an existing key. - - :param entry: A BaseModel object. - """ - assert issubclass(type(entry), self.model), f"{type(entry).__name__} is not a subclass of {self.model.__name__}" - super().update(entry.model_dump()) - - -# noinspection SqlNoDataSourceInspection -class View(Table): - def __init__( - self, - connection: "FileDBBase", - name: str, - on: Table | str, - columns: list[Column | SelectColumn], - where: str | None = None, - group_by: list[Column | SelectColumn] | None = None, - order_by: list[tuple[str | Column, str]] | None = None, - limit: int | None = None, - joins: list[str] | None = None, - ) -> None: - """ - A subclass of Table to handle views. - - :param connection: A FileDBBase object connected to the database the view belongs to. - :param name: The name of the table. - :param on: The table the view is based on. - :param columns: The columns of the view. - :param where: A WHERE expression for the view, defaults to None. - :param group_by: A GROUP BY expression for the view, defaults to None. - :param order_by: A list tuples containing one column (either as Column or string) and a sorting direction - ("ASC", or "DESC"), defaults to None. - :param limit: The number of rows to limit the results to, defaults to None. - :param joins: Join operations to apply to the view, defaults to None. - """ - assert columns, "Views must have columns" - super().__init__(connection, name, columns) - self.on: Table | str = on - self.where: str = where - self.group_by: list[Column | SelectColumn] = group_by or [] - self.order_by: list[tuple[str | Column, str]] | None = order_by or [] - self.limit: int | None = limit - self.joins: list[str] = joins or [] - - def __repr__(self) -> str: - return f'{self.__class__.__name__}("{self.name}", on={self.on!r})' - - def create_statement(self, exist_ok: bool = True) -> str: - """ - Generate the expression that creates the view. - - :param exist_ok: True if existing views with the same name should be ignored, defaults to True. - :return: A CREATE VIEW expression. - """ - on_table: str = self.on.name if isinstance(self.on, Table) else self.on - elements: list[str] = ["CREATE VIEW"] - - if exist_ok: - elements.append("IF NOT EXISTS") - - elements.append(self.name) - - elements.append("AS") - - if not any(isinstance(c, SelectColumn) for c in self.columns) and [c.name for c in self.columns] == [ - c.name for c in self.on.columns - ]: - select_names = ["*"] - else: - select_names = [ - f"{c.name} as {c.alias}" if c.alias else f"{on_table}.{c.name}" - for c in [SelectColumn.from_column(c) for c in self.columns] - ] - - elements.append( - f"SELECT {','.join(select_names)} " f"FROM {on_table}", - ) - - if self.joins: - elements.extend(self.joins) - - if self.where: - elements.append(f"WHERE {self.where}") - - if self.group_by: - elements.append("GROUP BY") - elements.append( - ",".join( - [c.alias or c.name for c in [SelectColumn.from_column(c) for c in self.group_by]], - ), - ) - - if self.order_by: - order_statements = [ - f"{(SelectColumn.from_column(c).name or c.name) if isinstance(c, Column) else c} {s}" - for c, s in self.order_by - ] - elements.append(f"ORDER BY {','.join(order_statements)}") - - if self.limit is not None: - elements.append(f"LIMIT {self.limit}") - - return " ".join(elements) - - def select( - self, - columns: list[Column | SelectColumn] | None = None, - where: str | None = None, - order_by: list[tuple[str | Column, str]] | None = None, - offset: int | None = None, - limit: int | None = None, - parameters: list[Any] | None = None, - ) -> Cursor: - """ - Select entries from the view. - - :param columns: A list of columns to be selected, defaults to None. - :param where: A WHERE expression, defaults to None. - :param order_by: A list tuples containing one column (either as Column or string) and a sorting direction - ("ASC", or "DESC"), defaults to None. - :param offset: The number of rows skip, defaults to None. - :param limit: The number of rows to limit the results to, defaults to None. - :param parameters: Values to substitute in the SELECT expression, both in the `where` and SelectColumn - statements, defaults to None. - :return: A Cursor object wrapping the SQLite cursor returned by the SELECT transaction. - """ - columns = columns or [ - Column( - c.alias or c.name, - c.sql_type, - c.to_entry, - c.from_entry, - c.unique, - c.primary_key, - c.not_null, - c.check, - c.default, - ) - for c in map(SelectColumn.from_column, self.columns) - ] - return super().select(columns, where, order_by, offset, limit, parameters) - - def insert(self, *_args, **_kwargs): - """ - Insert function. - - :raises OperationalError: Insert transactions are not allowed on views. - """ - raise OperationalError("Cannot insert into view") - - def insert_many(self, *_args, **_kwargs): - """ - Insert many. - - :raises OperationalError: Insert transactions are not allowed on views. - """ - raise OperationalError("Cannot insert into view") - - -class ModelView(View, Generic[M]): - def __init__( - self, - connection: "FileDBBase", - name: str, - on: Table | str, - model: Type[M], - columns: list[Column | SelectColumn] | None = None, - where: str | None = None, - group_by: list[Column | SelectColumn] | None = None, - order_by: list[tuple[str | Column, str]] | None = None, - limit: int | None = None, - joins: list[str] | None = None, - ) -> None: - """ - A subclass of Table to handle views with models. - - :param connection: A FileDBBase object connected to the database the view belongs to. - :param name: The name of the table. - :param on: The table the view is based on. - :param model: A BaseModel subclass. - :param columns: Optionally, the columns of the view if the model is too limited, defaults to None. - :param where: A WHERE expression for the view, defaults to None. - :param group_by: A GROUP BY expression for the view, defaults to None. - :param order_by: A list tuples containing one column (either as Column or string) and a sorting direction - ("ASC", or "DESC"), defaults to None. - :param limit: The number of rows to limit the results to, defaults to None. - :param joins: Join operations to apply to the view, defaults to None. - """ - super().__init__( - connection, - name, - on, - columns or model_to_columns(model), - where, - group_by, - order_by, - limit, - joins, - ) - self.model: Type[M] = model - - def __repr__(self) -> str: - return f'{self.__class__.__name__}[{self.model.__name__}]("{self.name}", on={self.on!r})' - - def select( - self, - model: Type[M] | None = None, - where: str | None = None, - order_by: list[tuple[str | Column, str]] | None = None, - offset: int | None = None, - limit: int | None = None, - parameters: list[Any] | None = None, - ) -> ModelCursor[M]: - return ModelCursor[M]( - super().select(model_to_columns(model or self.model), where, order_by, offset, limit, parameters).cursor, - model or self.model, - self, - ) - - -class FileDBBase(Connection): - def __init__( - self, - database: str | bytes | PathLike[str] | PathLike[bytes], - *, - timeout: float = 5.0, - detect_types: int = 0, - isolation_level: str | None = "DEFERRED", - check_same_thread: bool = True, - factory: Type[Connection] | None = Connection, - cached_statements: int = 100, - uri: bool = False, - ) -> None: - """ - A wrapper class for an SQLite connection. - - :param database: The path or URI to the database. - :param timeout: How many seconds the connection should wait before raising an OperationalError when a table - is locked, defaults to 5.0. - :param detect_types: Control whether and how data types not natively supported by SQLite are looked up to be - converted to Python types, defaults to 0. - :param isolation_level: The isolation_level of the connection, controlling whether and how transactions are - implicitly opened, defaults to "DEFERRED". - :param check_same_thread: If True (default), ProgrammingError will be raised if the database connection is - used by a thread other than the one that created it, defaults to True. - :param factory: A custom subclass of Connection to create the connection with, if not the default Connection - class, defaults to Connection. - :param cached_statements: The number of statements that sqlite3 should internally cache for this connection, - to avoid parsing overhead, defaults to 100. - :param uri: If set to True, database is interpreted as a URI with a file path and an optional query string, - defaults to False. - """ - self.committed_changes: int = 0 - super().__init__( - database, - timeout, - detect_types, - isolation_level, - check_same_thread, - factory, - cached_statements, - uri, - ) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.path})" - - def __enter__(self) -> "FileDBBase": - return self - - def __exit__(self, _exc_type: Type[BaseException], _exc_val: BaseException, _exc_tb: TracebackType) -> None: - self.close() - - @property - def path(self) -> Path | None: - for _, name, filename in self.execute("PRAGMA database_list"): - if name == "main" and filename: - return Path(filename) - - return None - - @property - def is_open(self) -> bool: - try: - self.execute("SELECT * FROM sqlite_master") - return True - except DatabaseError: - return False - - def commit(self): - """Commit any pending transaction to the database.""" - super().commit() - self.committed_changes = self.total_changes - - @overload - def create_table(self, name: str, columns: Type[M], indices: list[Index] | None = None) -> ModelTable[M]: ... - - @overload - def create_table(self, name: str, columns: list[Column], indices: list[Index] | None = None) -> Table: ... - - def create_table( - self, - name: str, - columns: Type[M] | list[Column], - indices: list[Index] | None = None, - ) -> Table | ModelTable[M]: - """ - Create a table in the database. - - When the `columns` argument is a subclass of BadeModel, a ModelTable object is returned. - - :param name: The name of the table. - :param columns: A BaseModel subclass or the columns of the table. - :param indices: The indices to create in the table, defaults to None. - """ - if issubclass(columns, BaseModel): - return ModelTable[M](self, name, columns, model_to_indices(columns) if indices is None else indices) - else: - return Table(self, name, columns, indices) - - @overload - def create_keys_table(self, name: str, columns: Type[M]) -> ModelKeysTable[M]: ... - - @overload - def create_keys_table(self, name: str, columns: list[Column]) -> KeysTable: ... - - def create_keys_table(self, name: str, columns: Type[M] | list[Column]) -> KeysTable | ModelKeysTable[M]: - """ - Create a key-value pairs table in the database. - - When the `columns` argument is a subclass of BaseModel, a ModelTable object is returned. - - :param name: The name of the table. - :param columns: A BaseModel subclass or the columns of the table. - """ - if issubclass(columns, BaseModel): - return ModelKeysTable[M](self, name, columns) - else: - return KeysTable(self, name, columns) - - @overload - def create_view( - self, - name: str, - on: Table | str, - columns: Type[M], - where: str | None = None, - group_by: list[Column | SelectColumn] | None = None, - order_by: list[tuple[str | Column, str]] | None = None, - limit: int | None = None, - joins: list[str] | None = None, - *, - select_columns: list[Column | SelectColumn] | None = None, - ) -> ModelView[M]: ... - - @overload - def create_view( - self, - name: str, - on: Table | str, - columns: list[Column | SelectColumn], - where: str | None = None, - group_by: list[Column | SelectColumn] | None = None, - order_by: list[tuple[str | Column, str]] | None = None, - limit: int | None = None, - joins: list[str] | None = None, - ) -> View: ... - - def create_view( - self, - name: str, - on: Table | str, - columns: list[Column | SelectColumn] | Type[M], - where: str | None = None, - group_by: list[Column | SelectColumn] | None = None, - order_by: list[tuple[str | Column, str]] | None = None, - limit: int | None = None, - joins: list[str] | None = None, - *, - select_columns: list[Column | SelectColumn] | None = None, - ) -> View | ModelView[M]: - """ - Create a view in the database. - - When the `columns` argument is a subclass of BadeModel, a ModelView object is returned. - - :param name: The name of the table. - :param on: The table the view is based on. - :param columns: A BaseModel subclass or the columns of the view. - :param where: A WHERE expression for the view, defaults to None. - :param group_by: A GROUP BY expression for the view, defaults to None. - :param order_by: A list tuples containing one column (either as Column or string) and a sorting direction - ("ASC", or "DESC"), defaults to None. - :param limit: The number of rows to limit the results to, defaults to None. - :param select_columns: Optionally, the columns of the view if a model is given and is too limited, defaults - to None. - :param joins: Join operations to apply to the view, defaults to None. - """ - if issubclass(columns, BaseModel): - return ModelView[M](self, name, on, columns, select_columns, where, group_by, order_by, limit, joins) - else: - return View(self, name, on, columns, where, group_by, order_by, limit, joins) diff --git a/acacore/database/column.py b/acacore/database/column.py index 41d4eb01..208a509d 100644 --- a/acacore/database/column.py +++ b/acacore/database/column.py @@ -1,178 +1,19 @@ +from dataclasses import dataclass from datetime import datetime -from functools import reduce -from json import dumps -from json import loads from pathlib import Path -from re import Pattern from typing import Any from typing import Callable -from typing import Generic -from typing import Literal -from typing import Optional -from typing import Sequence +from typing import Self from typing import Type -from typing import TypeVar from uuid import UUID -from pydantic import AliasChoices -from pydantic import AliasPath +from orjson import dumps +from orjson import loads from pydantic import BaseModel -from pydantic import Discriminator -from pydantic import Field -# noinspection PyProtectedMember -from pydantic.config import JsonDict - -# noinspection PyProtectedMember -from pydantic.fields import FieldInfo -from pydantic_core import PydanticUndefined +from acacore.utils.functions import or_none SQLValue = str | bytes | int | float | bool | datetime | None -T = TypeVar("T") -V = TypeVar("V", str, bytes, int, float, bool, datetime, None) - - -# noinspection PyPep8Naming -def DBField( - default: Any = PydanticUndefined, # noqa: ANN401 - *, - default_factory: Callable[[], Any] | None = PydanticUndefined, - alias: str | None = PydanticUndefined, - alias_priority: int | None = PydanticUndefined, - validation_alias: str | AliasPath | AliasChoices | None = PydanticUndefined, - serialization_alias: str | None = PydanticUndefined, - title: str | None = PydanticUndefined, - description: str | None = PydanticUndefined, - examples: list[Any] | None = PydanticUndefined, - exclude: bool | None = PydanticUndefined, - discriminator: str | Discriminator | None = PydanticUndefined, - deprecated: str | bool | None = PydanticUndefined, - json_schema_extra: JsonDict | Callable[[JsonDict], None] | None = PydanticUndefined, - frozen: bool | None = PydanticUndefined, - validate_default: bool | None = PydanticUndefined, - in_repr: bool = PydanticUndefined, - init: bool | None = PydanticUndefined, - init_var: bool | None = PydanticUndefined, - kw_only: bool | None = PydanticUndefined, - pattern: str | Pattern[str] | None = PydanticUndefined, - strict: bool | None = PydanticUndefined, - coerce_numbers_to_str: bool | None = PydanticUndefined, - gt: float | None = PydanticUndefined, - ge: float | None = PydanticUndefined, - lt: float | None = PydanticUndefined, - le: float | None = PydanticUndefined, - multiple_of: float | None = PydanticUndefined, - allow_inf_nan: bool | None = PydanticUndefined, - max_digits: int | None = PydanticUndefined, - decimal_places: int | None = PydanticUndefined, - min_length: int | None = PydanticUndefined, - max_length: int | None = PydanticUndefined, - union_mode: Literal["smart", "left_to_right"] = PydanticUndefined, - primary_key: bool | None = PydanticUndefined, - index: list[str] | None = PydanticUndefined, - ignore: bool | None = PydanticUndefined, -) -> FieldInfo: - """ - A wrapper around ``pydantic.Field`` with added parameters for database specs. - - :param primary_key: Whether the field is a primary key. - :param index: A list of indices the field belongs to. - :param ignore: Whether the field should be ignored when creating the table spec. - :param default: Default value if the field is not set. - :param default_factory: A callable to generate the default value, such as ``~datetime.utcnow``. - :param alias: The name to use for the attribute when validating or serializing by alias. - This is often used for things like converting between snake and camel case. - :param alias_priority: Priority of the alias. This affects whether an alias generator is used. - :param validation_alias: Like ``alias``, but only affects validation, not serialization. - :param serialization_alias: Like `alias`, but only affects serialization, not validation. - :param title: Human-readable title. - :param description: Human-readable description. - :param examples: Example values for this field. - :param exclude: Whether to exclude the field from the model serialization. - :param discriminator: Field name or Discriminator for discriminating the type in a tagged union. - :param deprecated: A deprecation message, an instance of ``warnings.deprecated`` or the - ``typing_extensions.deprecated`` backport, or a boolean. If ``True``, a default deprecation message will be - emitted when accessing the field. - :param json_schema_extra: A dict or callable to provide extra JSON schema properties. - :param frozen: Whether the field is frozen. If true, attempts to change the value on an instance will raise an - error. - :param validate_default: If ``True``, apply validation to the default value every time you create an instance. - Otherwise, for performance reasons, the default value of the field is trusted and not validated. - :param in_repr: A boolean indicating whether to include the field in the ``__repr__`` output. - :param init: Whether the field should be included in the constructor of the dataclass. - (Only applies to dataclasses.) - :param init_var: Whether the field should _only_ be included in the constructor of the dataclass. - (Only applies to dataclasses.) - :param kw_only: Whether the field should be a keyword-only argument in the constructor of the dataclass. - (Only applies to dataclasses.) - :param coerce_numbers_to_str: Whether to enable coercion of any ``Number`` type to ``str`` (not applicable in - ``strict`` mode). - :param strict: If ``True``, strict validation is applied to the field. - :param gt: Greater than. If set, value must be greater than this. Only applicable to numbers. - :param ge: Greater than or equal. If set, value must be greater than or equal to this. Only applicable to numbers. - :param lt: Less than. If set, value must be less than this. Only applicable to numbers. - :param le: Less than or equal. If set, value must be less than or equal to this. Only applicable to numbers. - :param multiple_of: Value must be a multiple of this. Only applicable to numbers. - :param min_length: Minimum length for iterables. - :param max_length: Maximum length for iterables. - :param pattern: Pattern for strings (a regular expression). - :param allow_inf_nan: Allow ``inf``, ``-inf``, ``nan``. Only applicable to numbers. - :param max_digits: Maximum number of allow digits for strings. - :param decimal_places: Maximum number of decimal places allowed for numbers. - :param union_mode: The strategy to apply when validating a union. Can be ``smart`` (the default), or - ``left_to_right``. - :return: A FieldInfo object - """ - extra: dict = ( - json_schema_extra - if isinstance(json_schema_extra, dict) - else json_schema_extra() - if callable(json_schema_extra) - else {} - ) - if primary_key is not PydanticUndefined: - extra["primary_key"] = primary_key - if index is not PydanticUndefined: - extra["index"] = index - if ignore is not PydanticUndefined: - extra["ignore"] = ignore - - return Field( - default=default, - default_factory=default_factory, - alias=alias, - alias_priority=alias_priority, - validation_alias=validation_alias, - serialization_alias=serialization_alias, - title=title, - description=description, - examples=examples, - exclude=exclude, - discriminator=discriminator, - deprecated=deprecated, - frozen=frozen, - validate_default=validate_default, - repr=in_repr, - init=init, - init_var=init_var, - kw_only=kw_only, - pattern=pattern, - strict=strict, - coerce_numbers_to_str=coerce_numbers_to_str, - gt=gt, - ge=ge, - lt=lt, - le=le, - multiple_of=multiple_of, - allow_inf_nan=allow_inf_nan, - max_digits=max_digits, - decimal_places=decimal_places, - min_length=min_length, - max_length=max_length, - union_mode=union_mode, - json_schema_extra=extra, - ) - _sql_schema_types: dict[str, str] = { "string": "text", @@ -214,260 +55,77 @@ def _value_to_sql(value: SQLValue) -> str: return str(value) -def dump_object(obj: list | tuple | dict | BaseModel) -> list | dict: +def _dump_object(obj: list | tuple | dict | BaseModel) -> list | dict: if isinstance(obj, dict): - return obj + return {k: _dump_object(v) for k, v in obj.items()} elif issubclass(type(obj), BaseModel): return obj.model_dump(mode="json") elif isinstance(obj, (list, tuple)): - return list(map(dump_object, obj)) + return list(map(_dump_object, obj)) else: return obj -def _schema_to_column(name: str, schema: dict, defs: dict[str, dict] | None = None) -> Optional["Column"]: - if schema.get("ignore"): - return None - - defs = defs or {} - if schema.get("$ref"): - schema.update(defs[schema.get("$ref", "").removeprefix("#/$defs/")]) - schema_type: str | None = schema.get("type") - schema_any_of: list[dict] = schema.get("anyOf", []) +@dataclass +class ColumnSpec: + name: str + type: str + to_sql: Callable[[Any | None], SQLValue] + from_sql: Callable[[SQLValue], Any | None] + nullable: bool - sql_type: str - to_entry: Callable[[T | None], V] - from_entry: Callable[[V], T | None] - not_null: bool = (schema_any_of or [{}])[-1].get("type", None) != "null" + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, nullable={self.nullable})" - if schema_type: - sql_type = _sql_schema_types.get(schema_type) - type_name: str = schema.get("format", schema_type) + def spec_sql(self) -> str: + return f"{self.name} {self.type} {'not null' if not self.nullable else ''}".strip() - if schema.get("enum") is not None: - to_entry, from_entry = schema["enum"][0].__class__ if schema["enum"] else str, str - elif schema_type in ("object", "array"): - sql_type = "text" - to_entry, from_entry = ( - lambda o: None if o is None else dumps(dump_object(o), default=str), - lambda o: None if o is None else loads(o), - ) - elif type_name in _sql_schema_type_converters: - to_entry, from_entry = _sql_schema_type_converters[type_name] + @classmethod + def from_schema(cls, name: str, schema: dict, defs: dict[str, dict] | None = None) -> Self: + defs = defs or {} + if schema.get("$ref"): + schema.update(defs[schema.get("$ref", "").removeprefix("#/$defs/")]) + schema_type: str | None = schema.get("type") + schema_any_of: list[dict] = schema.get("anyOf", []) + nullable: bool = any(s.get("type") == "null" for s in schema_any_of) + + sql_type: str + to_sql: Callable[[Any | None], SQLValue] + from_sql: Callable[[SQLValue], Any | None] + + if schema_type: + sql_type = _sql_schema_types.get(schema_type) + type_name: str = schema.get("format", schema_type) + + if schema_type in ("object", "array"): + sql_type, to_sql, from_sql = ( + "text", + lambda x: None if x is None else dumps(_dump_object(x), default=str).decode("utf-8"), + lambda x: None if x is None else loads(x), + ) + elif type_name in _sql_schema_type_converters: + to_sql, from_sql = _sql_schema_type_converters[type_name] + to_sql, from_sql = or_none(to_sql), or_none(from_sql) + else: + raise TypeError(f"Cannot recognize type from schema {schema!r}") + elif schema_any_of: + if not schema_any_of[0] or len(schema_any_of) > 2: + sql_type, to_sql, from_sql = ( + "text", + lambda x: None if x is None else dumps(_dump_object(x), default=str).decode("utf-8"), + lambda x: None if x is None else loads(x), + ) + else: + spec = cls.from_schema(name, {**schema_any_of[0], **schema}, defs) + sql_type, to_sql, from_sql = spec.type, spec.to_sql, spec.from_sql else: raise TypeError(f"Cannot recognize type from schema {schema!r}") - elif schema_any_of: - if not schema_any_of[0] or len(schema_any_of) > 2: - sql_type, to_entry, from_entry = ( - "text", - lambda x: None if x is None else dumps(dump_object(x), default=str), - lambda x: None if x is None else loads(x), - ) - else: - return _schema_to_column(name, {**schema_any_of[0], **schema}, defs) - else: - raise TypeError(f"Cannot recognize type from schema {schema!r}") - - return Column( - name, - sql_type, - lambda x: None if x is None else to_entry(x), - lambda x: None if x is None else from_entry(x), - unique=schema.get("default", False), - primary_key=schema.get("primary_key", False), - not_null=not_null, - default=schema.get("default", ...), - ) - - -def model_to_columns(model: Type[BaseModel]) -> list["Column"]: - schema: dict = model.model_json_schema() - columns = [_schema_to_column(p, s, schema.get("$defs")) for p, s in schema["properties"].items()] - return [c for c in columns if c] - - -def model_to_indices(model: Type[BaseModel]) -> list["Index"]: - columns: dict[str, Column] = {c.name: c for c in model_to_columns(model)} - schema: dict = model.model_json_schema() - indices: list[tuple[Column, str]] = [ - (columns[p], idx) for p, s in schema["properties"].items() if (idxs := s.get("index")) for idx in idxs - ] - unique_indices: list[tuple[Column, str]] = [ - (columns[p], idx) for p, s in schema["properties"].items() if (idxs := s.get("unique_index")) for idx in idxs - ] - indices_merged: dict[str, list[Column]] = reduce(lambda i, c: i | {c[1]: [*i.get(c[1], []), c[0]]}, indices, {}) - indices_merged |= reduce(lambda i, c: i | {c[1]: [*i.get(c[1], []), c[0]]}, unique_indices, {}) - return [Index(n, cs) for n, cs in indices_merged.items()] - - -class Column(Generic[T, V]): - def __init__( - self, - name: str, - sql_type: str, - to_entry: Callable[[T], V], - from_entry: Callable[[V], T], - unique: bool = False, - primary_key: bool = False, - not_null: bool = False, - check: str | None = None, - default: T | None = ..., - ) -> None: - """ - A class that stores information regarding a table column. - - :param name: The name of the column. - :param sql_type: The SQL type to use when creating a table. - :param to_entry: A function that returns a type supported by SQLite - (str, bytes, int, float, bool, datetime, or None). - :param from_entry: A function that takes a type returned by SQLite (str, bytes, int, float, or None) and - returns another object. - :param unique: True if the column should be set as UNIQUE, defaults to False. - :param primary_key: True if the column is a PRIMARY KEY, defaults to False. - :param not_null: True if the column is NOT NULL, defaults to False. - :param check: A string containing a CHECK expression, {name} substrings will be substituted with the name of - the column, defaults to None. - :param default: The column's DEFAULT value, which will be converted using `to_entry`. - Note that None is considered a valid default value; to set it to empty use Ellipsis (...), defaults to .... - """ - self.name: str = name - self.sql_type: str = sql_type - self.to_entry: Callable[[T], V] = to_entry - self.from_entry: Callable[[V], T] = from_entry - self.unique: bool = unique - self.primary_key: bool = primary_key - self.not_null: bool = not_null - self._check: str = check or "" - self.default: T | Ellipsis | None = default - - def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}(" - f"{self.name}" - f", {self.sql_type!r}" - f", unique={self.unique}" - f", primary_key={self.primary_key}" - f", not_null={self.not_null}" - f"{f', default={self.default!r}' if self.default is not Ellipsis else ''}" - f")" - ) - - @classmethod - def from_model(cls, model: Type[BaseModel]) -> list["Column"]: - return model_to_columns(model) - - @property - def check(self) -> str: - return self._check.format(name=self.name) if self._check else "" - - @check.setter - def check(self, check: str | None): - self._check = check - - def default_value(self) -> V: - """ - Get the default value of the column formatted as an SQL parameter. - - :raises ValueError: If the column does not have a set default value. - :return: An object of the return type of the column's to_entry function. - """ - if self.default is Ellipsis: - raise ValueError("Column does not have a default value") - return self.to_entry(self.default) - - def create_statement(self) -> str: - """ - Generate the statement that creates the column. - - :return: A column statement for a CREATE TABLE expression. - """ - elements: list[str] = [self.name, self.sql_type] - if self.unique: - elements.append("unique") - if self.not_null: - elements.append("not null") - if self.default is not Ellipsis: - elements.append(f"default {_value_to_sql(self.default_value())}") - if self.check: - elements.append(f"check ({self.check})") - - return " ".join(elements) - -class SelectColumn(Column, Generic[T, V]): - def __init__( - self, - name: str, - from_entry: Callable[[V], T], - alias: str | None = None, - ) -> None: - """ - A subclass of Column for SELECT expressions that need complex statements and/or an alias. - - :param name: The name or select statement for the select expression (e.g., count(*)). - :param from_entry: A function that takes a type returned by SQLite (str, bytes, int, float, or None) and - returns another object. - :param alias: An alternative name for the select statement, it will be used with the AS keyword and as a key - by Cursor, defaults to None. - """ - super().__init__(name, "", lambda x: x, from_entry) - self.alias: str | None = alias + return cls(name=name, type=sql_type, to_sql=to_sql, from_sql=from_sql, nullable=nullable) @classmethod - def from_column(cls, column: Column, alias: str | None = None) -> "SelectColumn": - """ - Take a Column object and create a SelectColumn with the given alias. - - :param column: The Column object to be converted. - :param alias: An alternative name for the select statement, it will be used with the AS keyword and as a key - by Cursor, defaults to None. - :return: A SelectColumn object. - """ - select_column = SelectColumn(column.name, column.from_entry, alias) - select_column.sql_type = column.sql_type - select_column.to_entry = column.to_entry - select_column.unique = column.unique - select_column.primary_key = column.primary_key - select_column.not_null = column.not_null - select_column._check = column._check - - if isinstance(column, SelectColumn): - select_column.alias = alias or column.alias - - return select_column - - -class Index: - def __init__(self, name: str, columns: Sequence[Column], unique: bool = False) -> None: - """ - A class that stores information regarding an index. - - :param name: The name of the index. - :param columns: The list of columns that the index applies to. - :param unique: Whether the index is unique or not, defaults to False. - """ - self.name: str = name - self.columns: list[Column] = list(columns) - self.unique: bool = unique - - def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}(" - f"{self.name}" - f", unique={self.unique}" - f", columns={[c.name for c in self.columns]}" - f")" - ) - - def create_statement(self, table: str, exist_ok: bool = True): - """ - Generate the expression that creates the index. + def from_model(cls, model: Type[BaseModel], ignore: list[str] | None = None) -> list[Self]: + schema: dict = model.model_json_schema() + ignore = ignore or [] - :param table: The name of the table. - :param exist_ok: True if existing tables with the same name should be ignored, defaults to True. - :return: A CREATE TABLE expression. - """ - return ( - f"create {'unique' if self.unique else ''} index {'if not exists' if exist_ok else ''} {self.name}" - f" on {table} ({','.join(c.name for c in self.columns)})" - ) + return [cls.from_schema(p, s, schema.get("$defs")) for p, s in schema["properties"].items() if p not in ignore] diff --git a/acacore/database/cursor.py b/acacore/database/cursor.py new file mode 100644 index 00000000..90b1251c --- /dev/null +++ b/acacore/database/cursor.py @@ -0,0 +1,46 @@ +from itertools import islice +from sqlite3 import Cursor as SQLiteCursor +from sqlite3 import Row +from typing import Any +from typing import Callable +from typing import Generator +from typing import Generic +from typing import Type +from typing import TypeVar + +from pydantic import BaseModel + +from .column import ColumnSpec +from .column import SQLValue + +M = TypeVar("M", bound=BaseModel) + + +class Cursor(Generic[M]): + def __init__(self, cursor: SQLiteCursor, model: Type[M], columns: list[ColumnSpec]) -> None: + self.cursor: SQLiteCursor[Row] = cursor + self.cursor.row_factory = Row + self.model: Type[M] = model + self.columns: list[ColumnSpec] = columns + + _cols: dict[str, Callable[[SQLValue], Any]] = {c.name: c.from_sql for c in columns} + + def _loader(row: Row) -> M: + return self.model.model_validate({k: f(row[k]) for k, f in _cols.items()}) + + self.entries: Generator[M, None, None] = (_loader(row) for row in self.cursor) + + def __iter__(self) -> Generator[M, None, None]: + yield from self.entries + + def __next__(self) -> M: + return next(self.entries) + + def fetchone(self) -> M | None: + return next(self, None) + + def fetchmany(self, size: int) -> list[M]: + return list(islice(self, size)) + + def fetchall(self) -> list[M]: + return list(self) diff --git a/acacore/database/database.py b/acacore/database/database.py new file mode 100644 index 00000000..125621a1 --- /dev/null +++ b/acacore/database/database.py @@ -0,0 +1,96 @@ +from collections.abc import Sequence +from os import PathLike +from pathlib import Path +from sqlite3 import Connection +from sqlite3 import Cursor as SQLiteCursor +from typing import Iterable +from typing import Mapping +from typing import overload +from typing import Type +from typing import TypeAlias +from typing import TypeVar + +from pydantic import BaseModel + +from .column import SQLValue +from .table import Table +from .table_keyvalue import KeysTable +from .table_view import View + +_M = TypeVar("_M", bound=BaseModel) +P: TypeAlias = Sequence[SQLValue] | Mapping[str, SQLValue] + + +class Database: + def __init__( + self, + path: str | PathLike[str], + *, + timeout: float = 5.0, + detect_types: int = 0, + isolation_level: str | None = "DEFERRED", + check_same_thread: bool = True, + cached_statements: int = 100, + ) -> None: + self.path: Path = Path(path) + self.connection: Connection = Connection( + self.path, + timeout=timeout, + detect_types=detect_types, + isolation_level=isolation_level, + check_same_thread=check_same_thread, + cached_statements=cached_statements, + ) + self._committed_changes: int = 0 + + @overload + def execute(self, sql: str, /) -> SQLiteCursor: ... + + @overload + def execute(self, sql: str, parameters: P, /) -> SQLiteCursor: ... + + def execute(self, sql: str, parameters: P | None = None, /) -> SQLiteCursor: + return self.connection.execute(sql, parameters or []) + + def executemany(self, sql: str, parameters: Iterable[P], /) -> SQLiteCursor: + return self.connection.executemany(sql, parameters) + + def commit(self): + self.connection.commit() + self._committed_changes = self.total_changes + + def rollback(self): + self.connection.rollback() + + @property + def total_changes(self): + return self.connection.total_changes + + @property + def committed_changes(self): + return self._committed_changes + + @property + def uncommitted_changes(self): + return self.total_changes - self._committed_changes + + def close(self): + self.connection.close() + + def create_table( + self, + model: Type[_M], + name: str, + primary_keys: list[str] | None = None, + indices: dict[str, list[str]] | None = None, + ignore: list[str] | None = None, + *, + exist_ok: bool = True, + ) -> Table[_M]: + return Table(self.connection, model, name, primary_keys, indices, ignore).create(exist_ok=exist_ok) + + def create_view(self, model: Type[_M], name: str, select: str, *, exist_ok: bool = True) -> View[_M]: + return View(self.connection, model, name, select).create(exist_ok=exist_ok) + + def create_keys_table(self, model: Type[_M], name: str, *, exist_ok: bool = True) -> KeysTable[_M]: + return KeysTable(self.connection, model, name).create(exist_ok=exist_ok) diff --git a/acacore/database/files_db.py b/acacore/database/files_db.py index e0781c71..be739d90 100644 --- a/acacore/database/files_db.py +++ b/acacore/database/files_db.py @@ -1,22 +1,15 @@ -from datetime import datetime from os import PathLike from pathlib import Path -from sqlite3 import Connection -from typing import Type -from uuid import UUID from pydantic import BaseModel -from acacore.models.file import File +from acacore.models.file import ConvertedFile +from acacore.models.file import OriginalFile from acacore.models.history import HistoryEntry from acacore.models.metadata import Metadata from acacore.models.reference_files import TActionType -from acacore.utils.functions import or_none -from .base import Column -from .base import FileDBBase -from .base import SelectColumn -from .column import model_to_columns +from .database import Database class HistoryEntryPath(HistoryEntry): @@ -39,270 +32,91 @@ class ActionCount(BaseModel): count: int -class FileDB(FileDBBase): +class FilesDB(Database): def __init__( self, - database: str | bytes | PathLike[str] | PathLike[bytes], + path: str | PathLike[str], *, timeout: float = 5.0, detect_types: int = 0, isolation_level: str | None = "DEFERRED", check_same_thread: bool = True, - factory: Type[Connection] | None = Connection, cached_statements: int = 100, - uri: bool = False, - check_version: bool = True, ) -> None: - """ - A class that handles the SQLite database used by AArhus City Archives to process data archives. - - :param database: The path or URI to the database. - :param timeout: How many seconds the connection should wait before raising an OperationalError when a table - is locked, defaults to 5.0. - :param detect_types: Control whether and how data types not natively supported by SQLite are looked up to be - converted to Python types, defaults to 0. - :param isolation_level: The isolation_level of the connection, controlling whether and how transactions are - implicitly opened, defaults to "DEFERRED". - :param check_same_thread: If True (default), ProgrammingError will be raised if the database connection is - used by a thread other than the one that created it, defaults to True. - :param factory: A custom subclass of Connection to create the connection with, if not the default Connection - class, defaults to Connection. - :param cached_statements: The number of statements that sqlite3 should internally cache for this connection, - to avoid parsing overhead, defaults to 100. - :param uri: If set to True, database is interpreted as a URI with a file path and an optional query string, - defaults to False. - :param check_version: If set to True, check the database version and ensure it is the latest. - """ super().__init__( - database, + path, timeout=timeout, detect_types=detect_types, isolation_level=isolation_level, check_same_thread=check_same_thread, - factory=factory, cached_statements=cached_statements, - uri=uri, ) - self.files = self.create_table("Files", File) - self.history = self.create_table("History", HistoryEntry) - self.metadata = self.create_keys_table("Metadata", Metadata) + self.original_files = self.create_table( + OriginalFile, + "files_original", + ["relative_path"], + {"uuid": ["uuid"], "checksum": ["checksum"], "action": ["action"]}, + ["root"], + ) + self.master_files = self.create_table( + ConvertedFile, + "files_master", + ["relative_path"], + {"uuid": ["uuid"], "checksum": ["checksum"], "action": ["action"]}, + ["root"], + ) + self.access_files = self.create_table( + ConvertedFile, + "files_master", + ["relative_path"], + {"uuid": ["uuid"], "checksum": ["checksum"], "action": ["action"]}, + ["root"], + ) + self.statutory_files = self.create_table( + ConvertedFile, + "files_statutory", + ["relative_path"], + {"uuid": ["uuid"], "checksum": ["checksum"], "action": ["action"]}, + ["root"], + ) - self.history_paths = self.create_view( - "_HistoryPaths", - self.history, + self.log = self.create_table( + HistoryEntry, + "log", + indices={"uuid": ["uuid"], "time": ["time"], "operation": ["operation"]}, + ) + self.log_paths = self.create_view( HistoryEntryPath, - select_columns=[ - SelectColumn("F.relative_path", str, "relative_path"), - *model_to_columns(HistoryEntry), - ], - joins=[f"left join {self.files.name} F on F.UUID = {self.history.name}.uuid"], + "log_paths", + f"select f.relative_path as relative_path, h.* from {self.log.name} h left join {self.original_files.name} f on f.uuid = h.uuid", ) + self.identification_warnings = self.create_view( - "_IdentificationWarnings", - self.files, - self.files.model, - f'("{self.files.name}".warning is not null or "{self.files.name}".puid is null)' - f' and "{self.files.name}".size != 0', + OriginalFile, + "view_identification_warnings", + f"select * from {self.original_files.name} where (warning is not null or puid is null) and size != 0", ) - self.checksum_count = self.create_view( - "_ChecksumCount", - self.files, - ChecksumCount, - None, - [ - Column("checksum", "varchar", str, str, False, False, False), - ], - [ - (Column("count", "int", str, str), "DESC"), - ], - select_columns=[ - Column( - "checksum", - "varchar", - or_none(str), - or_none(str), - False, - False, - False, - ), - SelectColumn( - f'count("{self.files.name}.checksum")', - int, - "count", - ), - ], - ) - self.signature_count = self.create_view( - "_SignatureCount", - self.files, + + self.signatures_count = self.create_view( SignatureCount, - None, - [ - Column("puid", "varchar", str, str, False, False, False), - ], - [ - (Column("count", "int", str, str), "ASC"), - ], - select_columns=[ - Column( - "puid", - "varchar", - or_none(str), - or_none(str), - False, - False, - False, - ), - Column( - "signature", - "varchar", - or_none(str), - or_none(str), - False, - False, - False, - ), - SelectColumn( - f"count(" - f'CASE WHEN ("{self.files.name}".puid IS NULL) ' - f"THEN 'None' " - f'ELSE "{self.files.name}".puid ' - f"END)", - int, - "count", - ), - ], + "view_signatures_count", + f"select puid, count(*) as count from {self.original_files.name} group by puid, signature order by count desc", ) + self.actions_count = self.create_view( - "_ActionsCount", - self.files, ActionCount, - None, - [ - Column("action", "varchar", str, str, False, False, False), - ], - [ - (Column("count", "int", str, str), "DESC"), - ], - select_columns=[ - Column( - "action", - "varchar", - or_none(str), - or_none(str), - False, - False, - False, - ), - SelectColumn( - f'count("{self.files.name}.action")', - int, - "count", - ), - ], + "view_actions_count", + f"select action, count(*) as count from {self.original_files.name} group by action order by count desc", ) - if self.is_initialised(check_views=False, check_indices=False): - if check_version: - from acacore.database.upgrade import is_latest - - is_latest(self, raise_on_difference=True) - else: - self.init() - - def is_initialised(self, *, check_views: bool = True, check_indices: bool = True) -> bool: - """ - Check if the database is initialised. - - :param check_views: Whether to check if all views are present. Defaults to ``True``. - :param check_indices: Whether to check if all indices are present. Defaults to ``True``. - :return: ``True`` if the database is initialised, ``False`` otherwise. - """ - tables: set[str] = {t.lower() for [t] in self.execute("select name from sqlite_master where type = 'table'")} - if not {self.files.name.lower(), self.history.name.lower(), self.metadata.name.lower()}.issubset(set(tables)): - return False - - if check_views: - views: set[str] = {n.lower() for [n] in self.execute("select name from sqlite_master where type = 'view'")} - expected_views: set[str] = { - self.history_paths.name.lower(), - self.identification_warnings.name.lower(), - self.checksum_count.name.lower(), - self.signature_count.name.lower(), - self.actions_count.name.lower(), - } - if not expected_views.issubset(views): - return False - - if check_indices: - indices: set[str] = { - n.lower() for [n] in self.execute("select name from sqlite_master where type = 'index'") - } - expected_indices: set[str] = { - i.name.lower() - for i in [ - *self.history_paths.indices, - *self.identification_warnings.indices, - *self.checksum_count.indices, - *self.signature_count.indices, - *self.actions_count.indices, - ] - } - if not expected_indices.issubset(indices): - return False - - return True - - def init(self): - """Initialize the database with all the necessary tables and views.""" - self.files.create(True) - self.history.create(True) - self.metadata.create(True) - self.history_paths.create(True) - self.identification_warnings.create(True) - self.checksum_count.create(True) - self.signature_count.create(True) - self.actions_count.create(True) - self.metadata.update(self.metadata.model()) - self.commit() - - def upgrade(self): - """Upgrade the database to the latest version.""" - from acacore.database.upgrade import upgrade - - upgrade(self) - self.init() - - def is_empty(self) -> bool: - """Check if the database contains any files.""" - return not self.files.select(limit=1).fetchone() + self.checksums_count = self.create_view( + ChecksumCount, + "view_checksums_count", + f"select checksum, count(*) as count from {self.original_files.name} group by checksum order by count desc", + ) - def add_history( - self, - uuid: UUID | None, - operation: str, - data: dict | list | str | int | float | bool | datetime | None, - reason: str | None = None, - *, - time: datetime | None = None, - ) -> HistoryEntry: - """ - Add a history entry to the database. + self.metadata = self.create_keys_table(Metadata, "metadata", exist_ok=True) - :param uuid: The UUID of the file the event refers to, if any. - :param operation: The operation that was performed. - :param data: The data attached to the event. - :param reason: The reason for the event. - :param time: The time of the event, defaults to current time. - :return: The ``HistoryEntry`` object representing the event. - """ - entry = self.history.model( - uuid=uuid, - operation=operation, - data=data, - reason=reason, - time=time or datetime.now(), - ) - self.history.insert(entry) - return entry + if not self.metadata.get(): + self.metadata.set(Metadata()) diff --git a/acacore/database/table.py b/acacore/database/table.py new file mode 100644 index 00000000..a2b4f199 --- /dev/null +++ b/acacore/database/table.py @@ -0,0 +1,209 @@ +from re import sub +from sqlite3 import Connection +from sqlite3 import ProgrammingError +from typing import Generator +from typing import Generic +from typing import Literal +from typing import Self +from typing import Type +from typing import TypeAlias +from typing import TypeVar + +from pydantic import BaseModel + +from .column import ColumnSpec +from .column import SQLValue +from .cursor import Cursor + +M = TypeVar("M", bound=BaseModel) +_Where: TypeAlias = str | dict[str, SQLValue | list[SQLValue]] + + +def _where_dict_to_sql(where: dict[str, SQLValue | list[SQLValue]]) -> tuple[str, list[SQLValue]]: + params: list[SQLValue] = [] + sql: list[str] = [] + + for k, vs in where.items(): + vs = vs if isinstance(vs, list) else [vs] + col_sql: list[str] = [] + + for v in vs: + if v is None: + col_sql.append(f"{k} is null") + else: + col_sql.append(f"{k} = ?") + params.append(v) + + if len(col_sql) == 1: + sql.append(col_sql[0]) + elif col_sql: + sql.append(f"({' or '.join(col_sql)})") + + return " and ".join(sql).strip(), params + + +def _where_to_sql( + where: _Where | BaseModel, + params: list[SQLValue] | None, + primary_keys: list[ColumnSpec], +) -> tuple[str, list[SQLValue]]: + params = params or [] + + if where is None: + where = "" + elif isinstance(where, BaseModel): + where, params = _where_dict_to_sql({pk.name: pk.to_sql(getattr(where, pk.name)) for pk in primary_keys}) + elif isinstance(where, str): + where = sub(r"^where\s+", "", where) if where.strip() else "" + elif isinstance(where, dict): + where, params = _where_dict_to_sql(where) + else: + raise TypeError(f"Unsupported type {type(where)}") + + return where.strip(), params if where else [] + + +class Table(Generic[M]): + def __init__( + self, + database: Connection, + model: Type[M], + name: str, + primary_keys: list[str] | None = None, + indices: dict[str, list[str]] | None = None, + ignore: list[str] | None = None, + ) -> None: + self.database: Connection = database + self.model: Type[M] = model + self.name: str = name + self.columns: dict[str, ColumnSpec] = {c.name: c for c in ColumnSpec.from_model(self.model, ignore)} + + _primary_keys: set[str] = set(primary_keys or []) + _indices: dict[str, set[str]] = {_i: set(cs) for i, cs in (indices or {}).items() if (_i := i.strip())} + + if missing_keys := [pk for pk in _primary_keys if pk not in self.columns]: + raise ValueError( + f"Primary keys {', '.join(map(repr, missing_keys))} do not exist in model {self.model.__name__!r}" + ) + + if missing_keys := [c for cs in _indices.values() for c in cs if c not in self.columns]: + raise ValueError( + f"Index keys {', '.join(map(repr, missing_keys))} do not exist in model {self.model.__name__!r}" + ) + + self.primary_keys: list[ColumnSpec] = [self.columns[pk] for pk in _primary_keys] + self.indices: dict[str, list[ColumnSpec]] = {i: [self.columns[c] for c in cs] for i, cs in _indices.items()} + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.name!r}, {self.model.__name__})" + + def __iter__(self) -> Generator[M, None, None]: + yield from self.select() + + def __getitem__(self, where: _Where | M) -> M | None: + return self.select(where, limit=1).fetchone() + + def __setitem__(self, where: _Where | M | slice, row: M) -> None: + if isinstance(where, slice): + self.insert(row) + else: + self.update(row, where) + + def __delitem__(self, where: _Where | M) -> None: + self.delete(where) + + def __contains__(self, where: M) -> bool: + return self.select(where, limit=1).cursor.fetchone() is not None + + def create_sql(self, *, exist_ok: bool = False) -> str: + sql: list[str] = ["create table"] + + if exist_ok: + sql.append("if not exists") + + sql.append(self.name) + + sql_cols = [c.spec_sql() for c in self.columns.values()] + if self.primary_keys: + sql_cols.append(f"primary key ({','.join(pk.name for pk in self.primary_keys)})") + + sql.append(f"({','.join(sql_cols)})") + + return " ".join(sql) + + def indices_sql(self, *, exist_ok: bool = False) -> list[str]: + return [ + f"create index {'if not exists' if exist_ok else ''} idx_{self.name}_{index} on {self.name} ({','.join(c.name for c in cols)})" + for index, cols in self.indices.items() + ] + + def create(self, *, exist_ok: bool = False) -> Self: + self.database.execute(self.create_sql(exist_ok=exist_ok)) + for index_sql in self.indices_sql(exist_ok=exist_ok): + self.database.execute(index_sql) + return self + + def select( + self, + where: _Where | M | None = None, + params: list[SQLValue] | None = None, + order_by: list[tuple[str, str]] | None = None, + limit: int | None = None, + offset: int | None = None, + ) -> Cursor[M]: + where, params = _where_to_sql(where, params, self.primary_keys) + + sql: list[str] = [f"select * from {self.name}"] + + if where: + sql.append(f"where {where}") + + if order_by: + sql.append(f"order by {','.join(o + ' ' + d for o, d in order_by)}") + + if limit is not None: + sql.append(f"limit {limit}") + if offset is not None: + sql.append(f"offset {offset}") + + return Cursor(self.database.execute(" ".join(sql), params), self.model, list(self.columns.values())) + + def insert(self, *rows: M, on_exists: Literal["ignore", "replace", "error"] = "error") -> int: + cols: list[ColumnSpec] = list(self.columns.values()) + sql: list[str] = ["insert"] + + if on_exists in ("ignore", "replace"): + sql.append(f"or {on_exists}") + + sql.append(f"into {self.name}") + + sql.append(f"({','.join(c.name for c in cols)}) values ({','.join('?' * len(cols))})") + + return self.database.executemany( + " ".join(sql), + (tuple(c.to_sql(getattr(row, c.name)) for c in cols) for row in rows), + ).rowcount + + def upsert(self, *rows: M) -> int: + return self.insert(*rows, on_exists="replace") + + def update(self, row: M, where: _Where | M = None, params: list[SQLValue] | None = None) -> int: + where, params = _where_to_sql(where or row, params, self.primary_keys) + + if not where: + raise ProgrammingError("Update without where") + + cols: list[ColumnSpec] = list(self.columns.values()) + + return self.database.execute( + f"update {self.name} set {','.join(f'{c.name} = ?' for c in cols)} where {where}", + [*[c.to_sql(getattr(row, c.name)) for c in cols], *params], + ).rowcount + + def delete(self, where: _Where | M) -> int: + where, params = _where_to_sql(where, [], self.primary_keys) + + if not where: + raise ProgrammingError("Delete without where") + + return self.database.execute(f"delete from {self.name} where {where}", params).rowcount diff --git a/acacore/database/table_keyvalue.py b/acacore/database/table_keyvalue.py new file mode 100644 index 00000000..a5db3dca --- /dev/null +++ b/acacore/database/table_keyvalue.py @@ -0,0 +1,75 @@ +from sqlite3 import Connection +from typing import Any +from typing import Generator +from typing import Generic +from typing import overload +from typing import Type +from typing import TypeVar + +from pydantic import BaseModel + +from .table import Table + +M = TypeVar("M", bound=BaseModel) + + +class KeysTableModel(BaseModel): + key: str + value: object | None + + +class KeysTable(Generic[M]): + def __init__(self, database: Connection, model: Type[M], name: str) -> None: + self.table: Table[KeysTableModel] = Table(database, KeysTableModel, name) + self.model: Type[M] = model + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.table.name!r}, {self.model.__name__})" + + def __iter__(self) -> Generator[tuple[str, object | None], None, None]: + yield from ((kv.key, kv.value) for kv in self.table.select()) + + def __getitem__(self, key: str) -> Any | None: # noqa: ANN401 + return self.get(key) + + def __setitem__(self, key: str, value: object | None) -> None: + if key not in self.model.model_fields: + raise AttributeError(f"{self.model.__name__!r} object has no attribute {key!r}") + + self.table.insert(KeysTableModel(key=key, value=value), on_exists="replace") + + def create_sql(self, *, exist_ok: bool = False) -> str: + return self.table.create_sql(exist_ok=exist_ok) + + def create(self, *, exist_ok: bool = False): + self.table.create(exist_ok=exist_ok) + return self + + def set(self, obj: M): + self.table.insert(*(KeysTableModel(key=k, value=o) for k, o in obj.model_dump().items()), on_exists="replace") + + @overload + def get(self) -> M | None: ... + + @overload + def get(self, key: str) -> Any | None: ... # noqa: ANN401 + + def get(self, key: str | None = None) -> M | Any | None: + if key is not None and key not in self.model.model_fields: + raise AttributeError(f"{self.model.__name__!r} object has no attribute {key!r}") + + items = self.table.select().fetchall() + if not items: + return None + obj = self.model.model_validate({i.key: i.value for i in items}) + if key is not None: + return getattr(obj, key) + else: + return obj + + def update(self, **kwargs: object | None): + if missing_keys := [k for k in kwargs if k not in self.model.model_fields]: + raise AttributeError( + f"Fields {', '.join(map(repr, missing_keys))} do not exist in model {self.model.__name__!r}" + ) + return self.table.insert(*(KeysTableModel(key=k, value=o) for k, o in kwargs.items()), on_exists="replace") diff --git a/acacore/database/table_view.py b/acacore/database/table_view.py new file mode 100644 index 00000000..fb4a581b --- /dev/null +++ b/acacore/database/table_view.py @@ -0,0 +1,59 @@ +from sqlite3 import Connection +from typing import Generator +from typing import Generic +from typing import Self +from typing import Type +from typing import TypeVar + +from pydantic import BaseModel + +from .column import SQLValue +from .cursor import Cursor +from .table import _Where +from .table import Table + +M = TypeVar("M", bound=BaseModel) + + +class View(Generic[M]): + def __init__( + self, + database: Connection, + model: Type[M], + name: str, + select: str, + ) -> None: + self.database: Connection = database + self.model: Type[M] = model + self.name: str = name + self.select_stmt: str = select + self._table: Table[M] = Table(self.database, self.model, self.name) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.name!r}, {self.model.__name__})" + + def __iter__(self) -> Generator[M, None, None]: + yield from self.select() + + def __getitem__(self, where: _Where | M) -> M | None: + return self._table.select(where, limit=1).fetchone() + + def __contains__(self, where: M) -> bool: + return self._table.select(where, limit=1).cursor.fetchone() is not None + + def create_sql(self, *, exist_ok: bool = False) -> str: + return f"create view {'if not exists' if exist_ok else ''} {self.name} as {self.select_stmt}" + + def create(self, *, exist_ok: bool = False) -> Self: + self.database.execute(self.create_sql(exist_ok=exist_ok)) + return self + + def select( + self, + where: _Where | None = None, + params: list[SQLValue] | None = None, + order_by: list[tuple[str, str]] | None = None, + limit: int | None = None, + offset: int | None = None, + ) -> Cursor[M]: + return self._table.select(where, params, order_by, limit, offset) diff --git a/acacore/database/upgrade.py b/acacore/database/upgrade.py index 1d371395..4d1e3191 100644 --- a/acacore/database/upgrade.py +++ b/acacore/database/upgrade.py @@ -16,7 +16,7 @@ ] -from .files_db import FileDB +from .files_db import FilesDB def get_db_version(conn: Connection) -> Version | None: @@ -153,12 +153,14 @@ def upgrade_3to3_0_2(conn: Connection) -> Version: return set_db_version(conn, Version("3.0.2")) +# noinspection SqlResolve def upgrade_3_0_2to3_0_6(conn: Connection) -> Version: conn.execute("update Files set action = 'ignore' where action = 'template'") conn.commit() return set_db_version(conn, Version("3.0.6")) +# noinspection SqlResolve def upgrade_3_0_6to3_0_7(conn: Connection) -> Version: def convert_action_data(data: dict) -> dict | None: if (reidentify := data.get("reidentify")) and reidentify.get("on_fail"): @@ -180,6 +182,7 @@ def convert_action_data(data: dict) -> dict | None: return set_db_version(conn, Version("3.0.7")) +# noinspection SqlResolve def upgrade_3_1to3_2(conn: Connection) -> Version: def convert_action_data(data: dict) -> dict: if not data.get("convert"): @@ -235,7 +238,8 @@ def get_upgrade_function(current_version: Version, latest_version: Version) -> C return lambda _: latest_version -def is_latest(db: FileDB, *, raise_on_difference: bool = False) -> bool: +# noinspection SqlResolve +def is_latest(db: FilesDB, *, raise_on_difference: bool = False) -> bool: """ Check if a database is using the latest version of acacore. @@ -249,7 +253,7 @@ def is_latest(db: FileDB, *, raise_on_difference: bool = False) -> bool: if not db.is_initialised(check_views=False, check_indices=False): raise DatabaseError("Database is not initialised") - current_version: Version | None = get_db_version(db) + current_version: Version | None = get_db_version(db.connection) latest_version: Version = Version(__version__) if not current_version: @@ -262,7 +266,7 @@ def is_latest(db: FileDB, *, raise_on_difference: bool = False) -> bool: return current_version == latest_version -def upgrade(db: FileDB): +def upgrade(db: FilesDB): """ Upgrade a database to the latest version of acacore. @@ -270,15 +274,15 @@ def upgrade(db: FileDB): """ if not db.is_initialised(check_views=False, check_indices=False): raise DatabaseError("Database is not initialised") - if db.committed_changes != db.total_changes: + if db.uncommitted_changes: raise DatabaseError("Database has uncommited transactions") if is_latest(db): return - current_version: Version = get_db_version(db) + current_version: Version = get_db_version(db.connection) latest_version: Version = Version(__version__) while current_version < latest_version: update_function = get_upgrade_function(current_version, latest_version) - current_version = update_function(db) + current_version = update_function(db.connection)