diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a69ddae7..b28cd653 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.12.0 + rev: v1.13.0 hooks: - id: mypy entry: bash -c "poetry run mypy ." diff --git a/README.md b/README.md index d5d30dee..1966dc91 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ Copyright (c) 2023-2024 IIASA - Energy, Climate, and Environment Program (ECE) [![license: MIT](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://github.com/iiasa/ixmp4/blob/main/LICENSE) [![python](https://img.shields.io/badge/python-3.10_|_3.11_|_3.12_|_3.13-blue?logo=python&logoColor=white)](https://github.com/iiasa/ixmp4) [![Code style: ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) +[![Checked with mypy](https://www.mypy-lang.org/static/mypy_badge.svg)](https://mypy-lang.org/) ## Overview diff --git a/doc/source/conf.py b/doc/source/conf.py index 7c2aca4f..0dd0f1d2 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -31,7 +31,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = [] +exclude_patterns: list[str] = [] # -- Options for HTML output ------------------------------------------------- diff --git a/ixmp4/cli/__init__.py b/ixmp4/cli/__init__.py index 09262bb3..59163c1b 100644 --- a/ixmp4/cli/__init__.py +++ b/ixmp4/cli/__init__.py @@ -42,7 +42,7 @@ def login( prompt=True, hide_input=True, ), -): +) -> None: try: auth = ManagerAuth(username, password, str(settings.manager_url)) user = auth.get_user() @@ -62,7 +62,7 @@ def login( @app.command() -def logout(): +def logout() -> None: if typer.confirm( "Are you sure you want to log out and delete locally saved credentials?" ): @@ -80,7 +80,7 @@ def test( with_backend: Optional[bool] = False, with_benchmarks: Optional[bool] = False, dry: Optional[bool] = False, - ): + ) -> None: opts = [ "--cov-report", "xml:.coverage.xml", diff --git a/ixmp4/cli/platforms.py b/ixmp4/cli/platforms.py index cc49169f..d037c1e7 100644 --- a/ixmp4/cli/platforms.py +++ b/ixmp4/cli/platforms.py @@ -1,7 +1,8 @@ import re +from collections.abc import Generator, Iterator from itertools import cycle from pathlib import Path -from typing import Generator, Optional +from typing import Any, Optional, TypeVar import typer from rich.progress import Progress, track @@ -21,7 +22,7 @@ app = typer.Typer() -def validate_name(name: str): +def validate_name(name: str) -> str: match = re.match(r"^[\w\-_]*$", name) if match is None: raise typer.BadParameter("Platform name must be slug-like.") @@ -29,7 +30,7 @@ def validate_name(name: str): return name -def validate_dsn(dsn: str | None): +def validate_dsn(dsn: str | None) -> str | None: if dsn is None: return None match = re.match(r"^(sqlite|postgresql\+psycopg|https|http)(\:\/\/)", dsn) @@ -41,7 +42,7 @@ def validate_dsn(dsn: str | None): return dsn -def prompt_sqlite_dsn(name: str): +def prompt_sqlite_dsn(name: str) -> str: path = sqlite.get_database_path(name) dsn = sqlite.get_dsn(path) if path.exists(): @@ -75,7 +76,7 @@ def add( help="Data source name. Can be a http(s) URL or a database connection string.", callback=validate_dsn, ), -): +) -> None: try: settings.toml.get_platform(name) raise typer.BadParameter( @@ -95,11 +96,11 @@ def add( utils.good("\nPlatform added successfully.") -def prompt_sqlite_removal(dsn: str): +def prompt_sqlite_removal(dsn: str) -> None: path = Path(dsn.replace("sqlite://", "")) path_str = typer.style(path, fg=typer.colors.CYAN) if typer.confirm( - "Do you want to remove the associated database file at " f"{path_str} as well?" # type: ignore + "Do you want to remove the associated database file at " f"{path_str} as well?" ): path.unlink() utils.echo("\nDatabase file deleted.") @@ -112,7 +113,7 @@ def remove( name: str = typer.Argument( ..., help="The string identifier of the platform to remove." ), -): +) -> None: try: platform = settings.toml.get_platform(name) except PlatformNotFound: @@ -127,7 +128,7 @@ def remove( settings.toml.remove_platform(name) -def tabulate_toml_platforms(platforms: list[TomlPlatformInfo]): +def tabulate_toml_platforms(platforms: list[TomlPlatformInfo]) -> None: toml_path_str = typer.style(settings.toml.path, fg=typer.colors.CYAN) utils.echo(f"\nPlatforms registered in '{toml_path_str}'") if len(platforms): @@ -140,7 +141,7 @@ def tabulate_toml_platforms(platforms: list[TomlPlatformInfo]): def tabulate_manager_platforms( platforms: list[ManagerPlatformInfo], -): +) -> None: manager_url_str = typer.style(settings.manager.url, fg=typer.colors.CYAN) utils.echo(f"\nPlatforms accessible via '{manager_url_str}'") utils.echo("\nName".ljust(21) + "Access".ljust(10) + "Notice") @@ -154,7 +155,7 @@ def tabulate_manager_platforms( @app.command("list", help="Lists all registered platforms.") -def list_(): +def list_() -> None: tabulate_toml_platforms(settings.toml.list_platforms()) if settings.manager is not None: tabulate_manager_platforms(settings.manager.list_platforms()) @@ -166,7 +167,8 @@ def list_(): "revision." ) ) -def upgrade(): +def upgrade() -> None: + platform_list: list[ManagerPlatformInfo] | list[TomlPlatformInfo] if settings.managed: utils.echo( f"Establishing self-signed admin connection to '{settings.manager_url}'." @@ -232,7 +234,7 @@ def generate( num_datapoints: int = typer.Option( 30_000, "--datapoints", help="Number of mock datapoints to generate." ), -): +) -> None: try: platform = Platform(platform_name) except PlatformNotFound: @@ -270,7 +272,12 @@ def generate( utils.good("Done!") -def create_cycle(generator: Generator, name: str, total: int): +T = TypeVar("T") + + +def create_cycle( + generator: Generator[T, Any, None], name: str, total: int +) -> Iterator[T]: return cycle( [ m @@ -283,7 +290,7 @@ def create_cycle(generator: Generator, name: str, total: int): ) -def generate_data(generator: MockDataGenerator): +def generate_data(generator: MockDataGenerator) -> None: model_names = create_cycle( generator.yield_model_names(), "Model", generator.num_models ) @@ -301,7 +308,7 @@ def generate_data(generator: MockDataGenerator): progress.advance(task, len(df)) -def _shorten(value: str, length: int): +def _shorten(value: str, length: int) -> str: """Shorten and adjust a string to a given length adding `...` if necessary""" if len(value) > length - 4: value = value[: length - 4] + "..." diff --git a/ixmp4/cli/server.py b/ixmp4/cli/server.py index 9d00a1ce..083205a3 100644 --- a/ixmp4/cli/server.py +++ b/ixmp4/cli/server.py @@ -2,7 +2,7 @@ from typing import Optional import typer -import uvicorn # type: ignore[import] +import uvicorn from fastapi.openapi.utils import get_openapi from ixmp4.conf import settings @@ -33,7 +33,9 @@ def start( @app.command() -def dump_schema(output_file: Optional[typer.FileTextWrite] = typer.Option(None, "-o")): +def dump_schema( + output_file: Optional[typer.FileTextWrite] = typer.Option(None, "-o"), +) -> None: schema = get_openapi( title=v1.title, version=v1.version, diff --git a/ixmp4/conf/__init__.py b/ixmp4/conf/__init__.py index 2d765ba3..111d4ad0 100644 --- a/ixmp4/conf/__init__.py +++ b/ixmp4/conf/__init__.py @@ -3,6 +3,4 @@ from ixmp4.conf.settings import Settings load_dotenv() -# strict typechecking fails due to a bug -# https://docs.pydantic.dev/visual_studio_code/#adding-a-default-with-field -settings = Settings() # type: ignore +settings = Settings() diff --git a/ixmp4/conf/auth.py b/ixmp4/conf/auth.py index 356492ff..aa31c64a 100644 --- a/ixmp4/conf/auth.py +++ b/ixmp4/conf/auth.py @@ -1,5 +1,7 @@ import logging +from collections.abc import Generator from datetime import datetime, timedelta +from typing import Any, cast from uuid import uuid4 import httpx @@ -13,10 +15,11 @@ class BaseAuth(object): - def __call__(self, *args, **kwargs): + # This should never be called + def __call__(self, *args: Any, **kwargs: Any) -> httpx.Request: raise NotImplementedError - def auth_flow(self, request): + def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, Any, None]: yield self(request) def get_user(self) -> User: @@ -39,7 +42,7 @@ def __init__(self, secret: str, username: str = "ixmp4"): ) self.token = self.get_local_jwt() - def __call__(self, r): + def __call__(self, r: httpx.Request) -> httpx.Request: try: jwt.decode(self.token, self.secret, algorithms=["HS256"]) except (jwt.InvalidTokenError, jwt.ExpiredSignatureError): @@ -48,7 +51,7 @@ def __call__(self, r): r.headers["Authorization"] = "Bearer " + self.token return r - def get_local_jwt(self): + def get_local_jwt(self) -> str: self.jti = uuid4().hex return jwt.encode( { @@ -63,7 +66,7 @@ def get_local_jwt(self): algorithm="HS256", ) - def get_expiration_timestamp(self): + def get_expiration_timestamp(self) -> int: return int((datetime.now() + timedelta(minutes=15)).timestamp()) def get_user(self) -> User: @@ -72,11 +75,11 @@ def get_user(self) -> User: class AnonymousAuth(BaseAuth, httpx.Auth): - def __init__(self): + def __init__(self) -> None: self.user = anonymous_user logger.info("Connecting to service anonymously and without credentials.") - def __call__(self, r): + def __call__(self, r: httpx.Request) -> httpx.Request: return r def get_user(self) -> User: @@ -97,7 +100,7 @@ def __init__( self.password = password self.obtain_jwt() - def __call__(self, r): + def __call__(self, r: httpx.Request) -> httpx.Request: try: jwt.decode( self.access_token, @@ -109,7 +112,7 @@ def __call__(self, r): r.headers["Authorization"] = "Bearer " + self.access_token return r - def obtain_jwt(self): + def obtain_jwt(self) -> None: res = self.client.post( "/token/obtain/", json={ @@ -133,7 +136,7 @@ def obtain_jwt(self): self.set_user(self.access_token) self.refresh_token = json["refresh"] - def refresh_or_reobtain_jwt(self): + def refresh_or_reobtain_jwt(self) -> None: try: jwt.decode( self.refresh_token, @@ -143,7 +146,7 @@ def refresh_or_reobtain_jwt(self): except jwt.ExpiredSignatureError: self.obtain_jwt() - def refresh_jwt(self): + def refresh_jwt(self) -> None: res = self.client.post( "/token/refresh/", json={ @@ -157,13 +160,16 @@ def refresh_jwt(self): self.access_token = res.json()["access"] self.set_user(self.access_token) - def decode_token(self, token: str): - return jwt.decode( - token, - options={"verify_signature": False, "verify_exp": False}, + def decode_token(self, token: str) -> dict[str, Any]: + return cast( + dict[str, Any], + jwt.decode( + token, + options={"verify_signature": False, "verify_exp": False}, + ), ) - def set_user(self, token: str): + def set_user(self, token: str) -> None: token_dict = self.decode_token(token) user_dict = token_dict["user"] self.user = User(**user_dict, jti=token_dict.get("jti")) diff --git a/ixmp4/conf/credentials.py b/ixmp4/conf/credentials.py index 0eab7d3b..96bac8d3 100644 --- a/ixmp4/conf/credentials.py +++ b/ixmp4/conf/credentials.py @@ -5,16 +5,16 @@ class Credentials(object): - credentials: dict + credentials: dict[str, dict[str, str]] def __init__(self, toml_file: Path) -> None: self.path = toml_file self.load() - def load(self): + def load(self) -> None: self.credentials = toml.load(self.path) - def dump(self): + def dump(self) -> None: f = self.path.open("w+") toml.dump(self.credentials, f) @@ -22,14 +22,14 @@ def get(self, key: str) -> tuple[str, str]: c = self.credentials[key] return (c["username"], c["password"]) - def set(self, key: str, username: str, password: str): + def set(self, key: str, username: str, password: str) -> None: self.credentials[key] = { "username": username, "password": password, } self.dump() - def clear(self, key: str): + def clear(self, key: str) -> None: with suppress(KeyError): del self.credentials[key] self.dump() diff --git a/ixmp4/conf/manager.py b/ixmp4/conf/manager.py index 12b5b2f1..faaaa6aa 100644 --- a/ixmp4/conf/manager.py +++ b/ixmp4/conf/manager.py @@ -3,11 +3,15 @@ import os import re from functools import lru_cache +from typing import Any, cast import httpx import pandas as pd from pydantic import Field +# TODO Import this from typing when dropping support for Python 3.11 +from typing_extensions import TypedDict, Unpack + from ixmp4.core.exceptions import ManagerApiError, PlatformNotFound from .auth import BaseAuth @@ -17,10 +21,10 @@ logger = logging.getLogger(__name__) -class hashabledict(dict): +class hashabledict(dict[str, Any]): """Hashable dict type used for caching.""" - def __hash__(self): + def __hash__(self) -> int: return hash(tuple(sorted(self.items()))) @@ -40,10 +44,14 @@ class Accessibilty(str, enum.Enum): accessibility: Accessibilty +class JtiKwargs(TypedDict, total=False): + jti: str | None + + class ManagerConfig(Config): template_pattern = re.compile(r"(\{env\:(\w+)\})") - def __init__(self, url: str, auth: BaseAuth, remote: bool = False) -> None: + def __init__(self, url: str, auth: BaseAuth | None, remote: bool = False) -> None: # TODO: Find the sweet-spot for `maxsize` # -> a trade-off between memory usage # and load on the management service @@ -74,8 +82,13 @@ def expand_dsn(self, dsn: str) -> str: return dsn def _uncached_request( - self, method: str, path: str, *args, jti: str | None = None, **kwargs - ): + self, + method: str, + path: str, + params: dict[str, int | None] | None = None, + json: dict[str, Any] | list[Any] | tuple[Any] | None = None, + jti: str | None = None, + ) -> dict[str, Any]: del jti # `jti` is only used to affect `@lru_cache` # if the token id changes a new cache entry will be created @@ -84,50 +97,46 @@ def _uncached_request( # NOTE: since this cache is not shared amongst processes, it's efficacy # declines with the scale of the whole infrastructure unless counteracted # with increased cache size / memory usage - res = self.client.request(method, path, *args, **kwargs) + res = self.client.request(method, path, params=params, json=json) if res.status_code != 200: raise ManagerApiError(f"[{str(res.status_code)}] {res.text}") - return res.json() + # NOTE we can assume this type, might get replaced with scse-toolkit + return cast(dict[str, Any], res.json()) def _request( self, method: str, path: str, - *args, - params: dict | None = None, - json: dict | list | tuple | None = None, - **kwargs, - ): + # Seems to be just that based on references + params: dict[str, int | None] | None = None, + # Seems to not be included with any references? + json: dict[str, Any] | list[Any] | tuple[Any] | None = None, + **kwargs: Unpack[JtiKwargs], + ) -> dict[str, Any]: if params is not None: params = hashabledict(params) if json is not None: - if isinstance(json, dict): - json = hashabledict(json) - else: - json = tuple(json) + json = hashabledict(json) if isinstance(json, dict) else tuple(json) logger.debug(f"Trying cache: {method} {path} {params} {json}") - return self._cached_request( - method, path, *args, params=params, json=json, **kwargs - ) + return self._cached_request(method, path, params=params, json=json, **kwargs) - def fetch_platforms(self, **kwargs) -> list[ManagerPlatformInfo]: + def fetch_platforms(self, **kwargs: Unpack[JtiKwargs]) -> list[ManagerPlatformInfo]: json = self._request("GET", "/ixmp4", params={"page_size": -1}, **kwargs) return [ManagerPlatformInfo(**c) for c in json["results"]] - def list_platforms(self, **kwargs) -> list[ManagerPlatformInfo]: + def list_platforms(self, **kwargs: Unpack[JtiKwargs]) -> list[ManagerPlatformInfo]: platforms = self.fetch_platforms(**kwargs) for i, p in enumerate(platforms): - if self.remote: - platforms[i].dsn = p.url - else: - platforms[i].dsn = self.expand_dsn(p.dsn) + platforms[i].dsn = p.url if self.remote else self.expand_dsn(p.dsn) return platforms - def get_platform(self, key: str, **kwargs) -> ManagerPlatformInfo: + def get_platform( + self, key: str, **kwargs: Unpack[JtiKwargs] + ) -> ManagerPlatformInfo: for p in self.list_platforms(**kwargs): if p.name == key: return p @@ -138,7 +147,10 @@ def get_platform(self, key: str, **kwargs) -> ManagerPlatformInfo: ) def fetch_user_permissions( - self, user: User, platform: ManagerPlatformInfo, **kwargs + self, + user: User, + platform: ManagerPlatformInfo, + **kwargs: Unpack[JtiKwargs], ) -> pd.DataFrame: if not user.is_authenticated: return pd.DataFrame( @@ -159,7 +171,10 @@ def fetch_user_permissions( ) def fetch_group_permissions( - self, group_id: int, platform: ManagerPlatformInfo, **kwargs + self, + group_id: int, + platform: ManagerPlatformInfo, + **kwargs: Unpack[JtiKwargs], ) -> pd.DataFrame: json = self._request( "GET", @@ -186,7 +201,10 @@ def fetch_platforms(self) -> list[ManagerPlatformInfo]: return self.platforms def fetch_user_permissions( - self, user: User, platform: ManagerPlatformInfo, **kwargs + self, + user: User, + platform: ManagerPlatformInfo, + **kwargs: Unpack[JtiKwargs], ) -> pd.DataFrame: pdf = self.permissions return pdf.where(pdf["group"].isin(user.groups)).where( @@ -194,7 +212,10 @@ def fetch_user_permissions( ) def fetch_group_permissions( - self, group_id: int, platform: ManagerPlatformInfo, **kwargs + self, + group_id: int, + platform: ManagerPlatformInfo, + **kwargs: Unpack[JtiKwargs], ) -> pd.DataFrame: pdf = self.permissions return pdf.where(pdf["group"] == group_id).where(pdf["instance"] == platform.id) diff --git a/ixmp4/conf/settings.py b/ixmp4/conf/settings.py index 1a217711..d7182e61 100644 --- a/ixmp4/conf/settings.py +++ b/ixmp4/conf/settings.py @@ -2,7 +2,7 @@ import logging import logging.config from pathlib import Path -from typing import Literal +from typing import Any, Literal from httpx import ConnectError from pydantic import Field, HttpUrl, field_validator @@ -25,10 +25,10 @@ class Settings(BaseSettings): mode: Literal["production"] | Literal["development"] | Literal["debug"] = ( "production" ) - storage_directory: Path = Field("~/.local/share/ixmp4/") + storage_directory: Path = Field(Path("~/.local/share/ixmp4/")) secret_hs256: str = "default_secret_hs256" migration_db_uri: str = "sqlite:///./run/db.sqlite" - manager_url: HttpUrl = Field("https://api.manager.ece.iiasa.ac.at/v1") + manager_url: HttpUrl = Field(HttpUrl("https://api.manager.ece.iiasa.ac.at/v1")) managed: bool = True max_page_size: int = 10_000 default_page_size: int = 5_000 @@ -39,7 +39,8 @@ class Settings(BaseSettings): client_timeout: int = Field(30) model_config = SettingsConfigDict(env_prefix="ixmp4_", extra="allow") - def __init__(self, *args, **kwargs) -> None: + # We don't pass any args or kwargs, so allow all to flow through + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.storage_directory.mkdir(parents=True, exist_ok=True) @@ -52,50 +53,52 @@ def __init__(self, *args, **kwargs) -> None: self.configure_logging(self.mode) - self._credentials = None - self._toml = None - self._default_auth = None - self._manager = None + self._credentials: Credentials | None = None + self._toml: TomlConfig | None = None + self._default_auth: ManagerAuth | AnonymousAuth | None = None + self._manager: ManagerConfig | None = None logger.debug(f"Settings loaded: {self}") @property - def credentials(self): + def credentials(self) -> Credentials: if self._credentials is None: self.load_credentials() - return self._credentials + # For this and similar below, mypy doesn't realize that the attribute will not + # be None after the load() call + return self._credentials # type: ignore[return-value] @property - def default_credentials(self): + def default_credentials(self) -> tuple[str, str] | None: try: return self.credentials.get("default") except KeyError: - pass + return None @property - def toml(self): + def toml(self) -> TomlConfig: if self._toml is None: self.load_toml_config() - return self._toml + return self._toml # type: ignore[return-value] @property - def default_auth(self): + def default_auth(self) -> ManagerAuth | AnonymousAuth | None: if self._default_auth is None: self.get_auth() return self._default_auth @property - def manager(self): + def manager(self) -> ManagerConfig: if self._manager is None: self.load_manager_config() - return self._manager + return self._manager # type: ignore[return-value] - def load_credentials(self): + def load_credentials(self) -> None: credentials_config = self.storage_directory / "credentials.toml" credentials_config.touch() self._credentials = Credentials(credentials_config) - def get_auth(self): + def get_auth(self) -> None: if self.default_credentials is not None: try: self._default_auth = ManagerAuth( @@ -112,12 +115,12 @@ def get_auth(self): else: self._default_auth = AnonymousAuth() - def load_manager_config(self): + def load_manager_config(self) -> None: self._manager = ManagerConfig( str(self.manager_url), self.default_auth, remote=True ) - def load_toml_config(self): + def load_toml_config(self) -> None: if self.default_auth is not None: toml_user = self.default_auth.get_user() if not toml_user.is_authenticated: @@ -130,14 +133,14 @@ def load_toml_config(self): self._toml = TomlConfig(toml_config, toml_user) @field_validator("storage_directory") - def expand_user(cls, v): + def expand_user(cls, v: Path) -> Path: # translate ~/asdf into /home/user/asdf return Path.expanduser(v) - def get_server_logconf(self): + def get_server_logconf(self) -> Path: return here / "./logging/server.json" - def configure_logging(self, config: str): + def configure_logging(self, config: str) -> None: self.access_file = str((self.log_dir / "access.log").absolute()) self.debug_file = str((self.log_dir / "debug.log").absolute()) self.error_file = str((self.log_dir / "error.log").absolute()) @@ -147,7 +150,7 @@ def configure_logging(self, config: str): config_dict = json.load(file) logging.config.dictConfig(config_dict) - def check_credentials(self): + def check_credentials(self) -> None: if self.default_credentials is not None: username, password = self.default_credentials ManagerAuth(username, password, str(self.manager_url)) diff --git a/ixmp4/conf/toml.py b/ixmp4/conf/toml.py index ad724497..a53c4875 100644 --- a/ixmp4/conf/toml.py +++ b/ixmp4/conf/toml.py @@ -27,7 +27,7 @@ def load(self) -> None: list_: list[dict[str, Any]] = [{"name": k, **v} for k, v in dict_.items()] self.platforms = {x["name"]: TomlPlatformInfo(**x) for x in list_} - def dump(self): + def dump(self) -> None: obj = {} for c in self.platforms.values(): dict_ = json.loads(c.model_dump_json()) @@ -47,7 +47,7 @@ def get_platform(self, key: str) -> TomlPlatformInfo: except KeyError as e: raise PlatformNotFound(f"Platform '{key}' was not found.") from e - def add_platform(self, name: str, dsn: str): + def add_platform(self, name: str, dsn: str) -> None: try: self.get_platform(name) except PlatformNotFound: @@ -56,7 +56,7 @@ def add_platform(self, name: str, dsn: str): return raise PlatformNotUnique(f"Platform '{name}' already exists, remove it first.") - def remove_platform(self, key: str): + def remove_platform(self, key: str) -> None: try: del self.platforms[key] except KeyError as e: diff --git a/ixmp4/core/__init__.py b/ixmp4/core/__init__.py index a86c4ce2..59aca1e3 100644 --- a/ixmp4/core/__init__.py +++ b/ixmp4/core/__init__.py @@ -1,11 +1,10 @@ -# flake8: noqa from .iamc.variable import Variable as Variable from .model import Model as Model from .optimization.equation import Equation as Equation from .optimization.indexset import IndexSet as IndexSet +from .optimization.parameter import Parameter as Parameter from .optimization.scalar import Scalar as Scalar from .optimization.table import Table as Table -from .optimization.parameter import Parameter as Parameter # TODO Is this really the name we want to use? from .optimization.variable import Variable as OptimizationVariable diff --git a/ixmp4/core/base.py b/ixmp4/core/base.py index 054f3e2e..f06a3264 100644 --- a/ixmp4/core/base.py +++ b/ixmp4/core/base.py @@ -17,8 +17,10 @@ class BaseModelFacade(BaseFacade): backend: Backend _model: BaseModel - def __init__(self, *args, _model: BaseModel | None = None, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__( + self, _model: BaseModel | None = None, **kwargs: Backend | None + ) -> None: + super().__init__(**kwargs) if _model is not None: self._model = _model diff --git a/ixmp4/core/decorators.py b/ixmp4/core/decorators.py index 126dcc07..15fee2e6 100644 --- a/ixmp4/core/decorators.py +++ b/ixmp4/core/decorators.py @@ -1,16 +1,32 @@ import functools +from collections.abc import Callable +from typing import Any, TypeVar import pandera as pa from pandera.errors import SchemaError as PanderaSchemaError +from pandera.typing import DataFrame + +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + +from ixmp4.data.abstract.base import BaseRepository from .exceptions import SchemaError +T = TypeVar("T") + -def check_types(func): +def check_types( + func: Callable[[BaseRepository, DataFrame[T]], None], +) -> Callable[[BaseRepository, DataFrame[T]], None]: checked_func = pa.check_types(func) @functools.wraps(func) - def wrapper(*args, skip_validation: bool = False, **kwargs): + def wrapper( + *args: Unpack[tuple[BaseRepository, DataFrame[T]]], + skip_validation: bool = False, + **kwargs: Any, + ) -> None: if skip_validation: return func(*args, **kwargs) try: diff --git a/ixmp4/core/exceptions.py b/ixmp4/core/exceptions.py index 6c43fbf0..6441f939 100644 --- a/ixmp4/core/exceptions.py +++ b/ixmp4/core/exceptions.py @@ -1,23 +1,31 @@ -from typing import ClassVar, Dict +from typing import Any, ClassVar -registry: dict = dict() +# TODO Import this from typing when dropping support for 3.10 +from typing_extensions import Self + +registry: dict[str, type["IxmpError"]] = dict() class ProgrammingError(Exception): pass -ExcMeta: type = type(Exception) - - -class RemoteExceptionMeta(ExcMeta): - def __new__(cls, name, bases, namespace, **kwargs): +class RemoteExceptionMeta(type): + def __new__( + cls, + name: str, + bases: tuple[type, ...], + namespace: dict[str, Any], + **kwargs: Any, + ) -> type["IxmpError"]: http_error_name = namespace.get("http_error_name", None) if http_error_name is not None: try: return registry[http_error_name] except KeyError: - registry[http_error_name] = super().__new__( + # NOTE Since this is a meta class, super().__new__() won't ever return + # this type, but the IxmpError instead + registry[http_error_name] = super().__new__( # type: ignore[assignment] cls, name, bases, namespace, **kwargs ) return registry[http_error_name] @@ -29,14 +37,14 @@ class IxmpError(Exception, metaclass=RemoteExceptionMeta): _message: str = "" http_status_code: int = 500 http_error_name: ClassVar[str] = "ixmp_error" - kwargs: Dict + kwargs: dict[str, Any] def __init__( self, - *args, + *args: str, message: str | None = None, status_code: int | None = None, - **kwargs, + **kwargs: Any, ) -> None: if len(args) > 0: self._message = args[0] @@ -64,7 +72,7 @@ def message(self) -> str: return message @classmethod - def from_dict(cls, dict_): + def from_dict(cls, dict_: dict[str, Any]) -> Self: return cls(message=dict_["message"], **dict_["kwargs"]) diff --git a/ixmp4/core/iamc/__init__.py b/ixmp4/core/iamc/__init__.py index 34a7eb59..e20e9c63 100644 --- a/ixmp4/core/iamc/__init__.py +++ b/ixmp4/core/iamc/__init__.py @@ -1,3 +1 @@ -# flake8: noqa - -from .data import RunIamcData, PlatformIamcData +from .data import PlatformIamcData, RunIamcData diff --git a/ixmp4/core/iamc/data.py b/ixmp4/core/iamc/data.py index e137123b..18329412 100644 --- a/ixmp4/core/iamc/data.py +++ b/ixmp4/core/iamc/data.py @@ -1,11 +1,17 @@ -from typing import Optional +from collections.abc import Iterable +from typing import Optional, TypeVar import pandas as pd import pandera as pa +from pandera.engines import pandas_engine from pandera.typing import Series +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.data.abstract import DataPoint as DataPointModel from ixmp4.data.abstract import Run +from ixmp4.data.abstract.iamc.datapoint import EnumerateKwargs from ixmp4.data.backend import Backend from ..base import BaseFacade @@ -16,7 +22,9 @@ class RemoveDataPointFrameSchema(pa.DataFrameModel): type: Optional[Series[pa.String]] = pa.Field(isin=[t for t in DataPointModel.Type]) step_year: Optional[Series[pa.Int]] = pa.Field(coerce=True, nullable=True) - step_datetime: Optional[Series[pa.DateTime]] = pa.Field(coerce=True, nullable=True) + step_datetime: Optional[Series[pandas_engine.DateTime]] = pa.Field( + coerce=True, nullable=True + ) step_category: Optional[Series[pa.String]] = pa.Field(nullable=True) region: Optional[Series[pa.String]] = pa.Field(coerce=True) @@ -42,9 +50,10 @@ def convert_to_std_format(df: pd.DataFrame, join_runs: bool) -> pd.DataFrame: df.rename(columns={"step_year": "year"}, inplace=True) time_col = "year" else: + T = TypeVar("T", bool, float, int, str) - def map_step_column(df: pd.Series): - df["time"] = df[MAP_STEP_COLUMN[df.type]] + def map_step_column(df: "pd.Series[T]") -> "pd.Series[T]": + df["time"] = df[MAP_STEP_COLUMN[str(df.type)]] return df df = df.apply(map_step_column, axis=1) @@ -78,8 +87,8 @@ class RunIamcData(BaseFacade): run: Run - def __init__(self, *args, run: Run, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, run: Run, **kwargs: Backend | None) -> None: + super().__init__(**kwargs) self.run = run def _get_or_create_ts(self, df: pd.DataFrame) -> pd.DataFrame: @@ -108,8 +117,8 @@ def add( self, df: pd.DataFrame, type: Optional[DataPointModel.Type] = None, - ): - df = AddDataPointFrameSchema.validate(df) # type:ignore + ) -> None: + df = AddDataPointFrameSchema.validate(df) # type: ignore[assignment] df["run__id"] = self.run.id df = self._get_or_create_ts(df) substitute_type(df, type) @@ -119,8 +128,8 @@ def remove( self, df: pd.DataFrame, type: Optional[DataPointModel.Type] = None, - ): - df = RemoveDataPointFrameSchema.validate(df) # type:ignore + ) -> None: + df = RemoveDataPointFrameSchema.validate(df) # type: ignore[assignment] df["run__id"] = self.run.id df = self._get_or_create_ts(df) substitute_type(df, type) @@ -130,9 +139,9 @@ def remove( def tabulate( self, *, - variable: dict | None = None, - region: dict | None = None, - unit: dict | None = None, + variable: dict[str, str | Iterable[str]] | None = None, + region: dict[str, str | Iterable[str]] | None = None, + unit: dict[str, str | Iterable[str]] | None = None, raw: bool = False, ) -> pd.DataFrame: df = self.backend.iamc.datapoints.tabulate( @@ -154,7 +163,13 @@ def __init__(self, _backend: Backend | None = None) -> None: self.variables = VariableRepository(_backend=_backend) super().__init__(_backend=_backend) - def tabulate(self, *, join_runs: bool = True, raw: bool = False, **kwargs): + def tabulate( + self, + *, + join_runs: bool = True, + raw: bool = False, + **kwargs: Unpack[EnumerateKwargs], + ) -> pd.DataFrame: df = self.backend.iamc.datapoints.tabulate( join_parameters=True, join_runs=join_runs, **kwargs ).dropna(how="all", axis="columns") diff --git a/ixmp4/core/iamc/variable.py b/ixmp4/core/iamc/variable.py index ff304341..437994b1 100644 --- a/ixmp4/core/iamc/variable.py +++ b/ixmp4/core/iamc/variable.py @@ -30,21 +30,21 @@ def created_by(self) -> str | None: return self._model.created_by @property - def docs(self): + def docs(self) -> str | None: try: return self.backend.iamc.variables.docs.get(self.id).description except DocsModel.NotFound: return None @docs.setter - def docs(self, description): + def docs(self, description: str | None) -> None: if description is None: self.backend.iamc.variables.docs.delete(self.id) else: self.backend.iamc.variables.docs.set(self.id, description) @docs.deleter - def docs(self): + def docs(self) -> None: try: self.backend.iamc.variables.docs.delete(self.id) # TODO: silently failing @@ -56,10 +56,7 @@ def __str__(self) -> str: class VariableRepository(BaseFacade): - def create( - self, - name: str, - ) -> Variable: + def create(self, name: str) -> Variable: model = self.backend.iamc.variables.create(name) return Variable(_backend=self.backend, _model=model) @@ -75,9 +72,8 @@ def tabulate(self, name: str | None = None) -> pd.DataFrame: return self.backend.iamc.variables.tabulate(name=name) def _get_variable_id(self, variable: str) -> int | None: - if variable is None: - return None - elif isinstance(variable, str): + # NOTE leaving this check for users without mypy + if isinstance(variable, str): obj = self.backend.iamc.variables.get(variable) return obj.id else: diff --git a/ixmp4/core/meta.py b/ixmp4/core/meta.py index 2bcaa0ce..4eb9e36d 100644 --- a/ixmp4/core/meta.py +++ b/ixmp4/core/meta.py @@ -1,12 +1,17 @@ import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + +from ixmp4.data.abstract.meta import EnumerateKwargs + from .base import BaseFacade class MetaRepository(BaseFacade): - def tabulate(self, **kwargs) -> pd.DataFrame: + def tabulate(self, **kwargs: Unpack[EnumerateKwargs]) -> pd.DataFrame: # TODO: accept list of `Run` instances as arg # TODO: expand run-id to model-scenario-version-id columns return self.backend.meta.tabulate(join_run_index=True, **kwargs).drop( - columns=["id", "type"] + columns=["id", "dtype"] ) diff --git a/ixmp4/core/model.py b/ixmp4/core/model.py index 55c3eb92..9db8c382 100644 --- a/ixmp4/core/model.py +++ b/ixmp4/core/model.py @@ -30,21 +30,21 @@ def created_by(self) -> str | None: return self._model.created_by @property - def docs(self): + def docs(self) -> str | None: try: return self.backend.models.docs.get(self.id).description except DocsModel.NotFound: return None @docs.setter - def docs(self, description): + def docs(self, description: str | None) -> None: if description is None: self.backend.models.docs.delete(self.id) else: self.backend.models.docs.set(self.id, description) @docs.deleter - def docs(self): + def docs(self) -> None: try: self.backend.models.docs.delete(self.id) # TODO: silently failing @@ -75,9 +75,8 @@ def tabulate(self, name: str | None = None) -> pd.DataFrame: return self.backend.models.tabulate(name=name) def _get_model_id(self, model: str) -> int | None: - if model is None: - return None - elif isinstance(model, str): + # NOTE leaving this check for users without mypy + if isinstance(model, str): obj = self.backend.models.get(model) return obj.id else: diff --git a/ixmp4/core/optimization/__init__.py b/ixmp4/core/optimization/__init__.py index d13d3fde..fb1ce0ee 100644 --- a/ixmp4/core/optimization/__init__.py +++ b/ixmp4/core/optimization/__init__.py @@ -1 +1,10 @@ +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import TypedDict + +from ixmp4.data.backend import Backend + from .data import OptimizationData + + +class InitKwargs(TypedDict): + _backend: Backend | None diff --git a/ixmp4/core/optimization/data.py b/ixmp4/core/optimization/data.py index 69c18ee0..c604dbe1 100644 --- a/ixmp4/core/optimization/data.py +++ b/ixmp4/core/optimization/data.py @@ -1,4 +1,5 @@ from ixmp4.data.abstract import Run +from ixmp4.data.backend import Backend from ..base import BaseFacade from .equation import EquationRepository @@ -20,8 +21,8 @@ class OptimizationData(BaseFacade): tables: TableRepository variables: VariableRepository - def __init__(self, *args, run: Run, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, run: Run, **kwargs: Backend) -> None: + super().__init__(**kwargs) self.equations = EquationRepository(_backend=self.backend, _run=run) self.indexsets = IndexSetRepository(_backend=self.backend, _run=run) self.parameters = ParameterRepository(_backend=self.backend, _run=run) diff --git a/ixmp4/core/optimization/equation.py b/ixmp4/core/optimization/equation.py index 638bf7f0..1cb3ed30 100644 --- a/ixmp4/core/optimization/equation.py +++ b/ixmp4/core/optimization/equation.py @@ -1,8 +1,15 @@ +from collections.abc import Iterable from datetime import datetime -from typing import Any, ClassVar, Iterable +from typing import TYPE_CHECKING, Any, ClassVar + +if TYPE_CHECKING: + from . import InitKwargs import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.core.base import BaseFacade, BaseModelFacade from ixmp4.data.abstract import Docs as DocsModel from ixmp4.data.abstract import Equation as EquationModel @@ -48,12 +55,14 @@ def remove_data(self) -> None: ).data @property - def levels(self) -> list: - return self._model.data.get("levels", []) + def levels(self) -> list[float]: + levels: list[float] = self._model.data.get("levels", []) + return levels @property - def marginals(self) -> list: - return self._model.data.get("marginals", []) + def marginals(self) -> list[float]: + marginals: list[float] = self._model.data.get("marginals", []) + return marginals @property def constrained_to_indexsets(self) -> list[str]: @@ -72,21 +81,21 @@ def created_by(self) -> str | None: return self._model.created_by @property - def docs(self): + def docs(self) -> str | None: try: return self.backend.optimization.equations.docs.get(self.id).description except DocsModel.NotFound: return None @docs.setter - def docs(self, description): + def docs(self, description: str | None) -> None: if description is None: self.backend.optimization.equations.docs.delete(self.id) else: self.backend.optimization.equations.docs.set(self.id, description) @docs.deleter - def docs(self): + def docs(self) -> None: try: self.backend.optimization.equations.docs.delete(self.id) # TODO: silently failing @@ -100,8 +109,8 @@ def __str__(self) -> str: class EquationRepository(BaseFacade): _run: Run - def __init__(self, _run: Run, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, _run: Run, **kwargs: Unpack["InitKwargs"]) -> None: + super().__init__(**kwargs) self._run = _run def create( diff --git a/ixmp4/core/optimization/indexset.py b/ixmp4/core/optimization/indexset.py index d9772b3b..cd901150 100644 --- a/ixmp4/core/optimization/indexset.py +++ b/ixmp4/core/optimization/indexset.py @@ -1,8 +1,14 @@ from datetime import datetime -from typing import ClassVar +from typing import TYPE_CHECKING, ClassVar + +if TYPE_CHECKING: + from . import InitKwargs import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.core.base import BaseFacade, BaseModelFacade from ixmp4.data.abstract import Docs as DocsModel from ixmp4.data.abstract import IndexSet as IndexSetModel @@ -26,7 +32,9 @@ def name(self) -> str: def data(self) -> list[float | int | str]: return self._model.data - def add(self, data: float | int | list[float | int | str] | str) -> None: + def add( + self, data: float | int | str | list[float] | list[int] | list[str] + ) -> None: """Adds data to an existing IndexSet.""" self.backend.optimization.indexsets.add_data( indexset_id=self._model.id, data=data @@ -76,8 +84,8 @@ def __str__(self) -> str: class IndexSetRepository(BaseFacade): _run: Run - def __init__(self, _run: Run, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, _run: Run, **kwargs: Unpack["InitKwargs"]) -> None: + super().__init__(**kwargs) self._run = _run def create(self, name: str) -> IndexSet: diff --git a/ixmp4/core/optimization/parameter.py b/ixmp4/core/optimization/parameter.py index 32c07295..1e9fcdaa 100644 --- a/ixmp4/core/optimization/parameter.py +++ b/ixmp4/core/optimization/parameter.py @@ -1,12 +1,19 @@ +from collections.abc import Iterable from datetime import datetime -from typing import Any, ClassVar, Iterable +from typing import TYPE_CHECKING, Any, ClassVar + +if TYPE_CHECKING: + from . import InitKwargs import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.core.base import BaseFacade, BaseModelFacade from ixmp4.data.abstract import Docs as DocsModel from ixmp4.data.abstract import Parameter as ParameterModel -from ixmp4.data.abstract import Run +from ixmp4.data.abstract import Run, Unit from ixmp4.data.abstract.optimization import Column @@ -41,12 +48,14 @@ def add(self, data: dict[str, Any] | pd.DataFrame) -> None: ).data @property - def values(self) -> list: - return self._model.data.get("values", []) + def values(self) -> list[float]: + values: list[float] = self._model.data.get("values", []) + return values @property - def units(self) -> list: - return self._model.data.get("units", []) + def units(self) -> list[Unit]: + units: list[Unit] = self._model.data.get("units", []) + return units @property def constrained_to_indexsets(self) -> list[str]: @@ -65,21 +74,21 @@ def created_by(self) -> str | None: return self._model.created_by @property - def docs(self): + def docs(self) -> str | None: try: return self.backend.optimization.parameters.docs.get(self.id).description except DocsModel.NotFound: return None @docs.setter - def docs(self, description): + def docs(self, description: str | None) -> None: if description is None: self.backend.optimization.parameters.docs.delete(self.id) else: self.backend.optimization.parameters.docs.set(self.id, description) @docs.deleter - def docs(self): + def docs(self) -> None: try: self.backend.optimization.parameters.docs.delete(self.id) # TODO: silently failing @@ -93,8 +102,8 @@ def __str__(self) -> str: class ParameterRepository(BaseFacade): _run: Run - def __init__(self, _run: Run, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, _run: Run, **kwargs: Unpack["InitKwargs"]) -> None: + super().__init__(**kwargs) self._run = _run def create( diff --git a/ixmp4/core/optimization/scalar.py b/ixmp4/core/optimization/scalar.py index b4262d80..bfd3d196 100644 --- a/ixmp4/core/optimization/scalar.py +++ b/ixmp4/core/optimization/scalar.py @@ -1,13 +1,21 @@ +from collections.abc import Iterable from datetime import datetime -from typing import ClassVar, Iterable +from typing import TYPE_CHECKING, ClassVar + +if TYPE_CHECKING: + from . import InitKwargs import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.core.base import BaseFacade, BaseModelFacade from ixmp4.core.unit import Unit from ixmp4.data.abstract import Docs as DocsModel from ixmp4.data.abstract import Run from ixmp4.data.abstract import Scalar as ScalarModel +from ixmp4.data.abstract import Unit as UnitModel class Scalar(BaseModelFacade): @@ -29,7 +37,7 @@ def value(self) -> float: return self._model.value @value.setter - def value(self, value: float): + def value(self, value: float) -> None: self._model.value = value self.backend.optimization.scalars.update( id=self._model.id, @@ -38,21 +46,19 @@ def value(self, value: float): ) @property - def unit(self): + def unit(self) -> UnitModel: """Associated unit.""" return self._model.unit @unit.setter - def unit(self, unit: str | Unit): - if isinstance(unit, Unit): - unit = unit - else: - unit_model = self.backend.units.get(unit) - unit = Unit(_backend=self.backend, _model=unit_model) + def unit(self, value: str | Unit) -> None: + if isinstance(value, str): + unit_model = self.backend.units.get(value) + value = Unit(_backend=self.backend, _model=unit_model) self._model = self.backend.optimization.scalars.update( id=self._model.id, value=self._model.value, - unit_id=unit.id, + unit_id=value.id, ) @property @@ -68,21 +74,21 @@ def created_by(self) -> str | None: return self._model.created_by @property - def docs(self): + def docs(self) -> str | None: try: return self.backend.optimization.scalars.docs.get(self.id).description except DocsModel.NotFound: return None @docs.setter - def docs(self, description): + def docs(self, description: str | None) -> None: if description is None: self.backend.optimization.scalars.docs.delete(self.id) else: self.backend.optimization.scalars.docs.set(self.id, description) @docs.deleter - def docs(self): + def docs(self) -> None: try: self.backend.optimization.scalars.docs.delete(self.id) # TODO: silently failing @@ -96,8 +102,8 @@ def __str__(self) -> str: class ScalarRepository(BaseFacade): _run: Run - def __init__(self, _run: Run, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, _run: Run, **kwargs: Unpack["InitKwargs"]) -> None: + super().__init__(**kwargs) self._run = _run def create(self, name: str, value: float, unit: str | Unit | None = None) -> Scalar: diff --git a/ixmp4/core/optimization/table.py b/ixmp4/core/optimization/table.py index dad74e6d..3374c209 100644 --- a/ixmp4/core/optimization/table.py +++ b/ixmp4/core/optimization/table.py @@ -1,8 +1,15 @@ +from collections.abc import Iterable from datetime import datetime -from typing import Any, ClassVar, Iterable +from typing import TYPE_CHECKING, Any, ClassVar + +if TYPE_CHECKING: + from . import InitKwargs import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.core.base import BaseFacade, BaseModelFacade from ixmp4.data.abstract import Docs as DocsModel from ixmp4.data.abstract import Run @@ -55,21 +62,21 @@ def created_by(self) -> str | None: return self._model.created_by @property - def docs(self): + def docs(self) -> str | None: try: return self.backend.optimization.tables.docs.get(self.id).description except DocsModel.NotFound: return None @docs.setter - def docs(self, description): + def docs(self, description: str | None) -> None: if description is None: self.backend.optimization.tables.docs.delete(self.id) else: self.backend.optimization.tables.docs.set(self.id, description) @docs.deleter - def docs(self): + def docs(self) -> None: try: self.backend.optimization.tables.docs.delete(self.id) # TODO: silently failing @@ -83,8 +90,8 @@ def __str__(self) -> str: class TableRepository(BaseFacade): _run: Run - def __init__(self, _run: Run, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, _run: Run, **kwargs: Unpack["InitKwargs"]) -> None: + super().__init__(**kwargs) self._run = _run def create( diff --git a/ixmp4/core/optimization/variable.py b/ixmp4/core/optimization/variable.py index cc14a337..5c1b4146 100644 --- a/ixmp4/core/optimization/variable.py +++ b/ixmp4/core/optimization/variable.py @@ -1,8 +1,15 @@ +from collections.abc import Iterable from datetime import datetime -from typing import Any, ClassVar, Iterable +from typing import TYPE_CHECKING, Any, ClassVar + +if TYPE_CHECKING: + from . import InitKwargs import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.core.base import BaseFacade, BaseModelFacade from ixmp4.data.abstract import Docs as DocsModel from ixmp4.data.abstract import OptimizationVariable as VariableModel @@ -48,12 +55,14 @@ def remove_data(self) -> None: ).data @property - def levels(self) -> list: - return self._model.data.get("levels", []) + def levels(self) -> list[float]: + levels: list[float] = self._model.data.get("levels", []) + return levels @property - def marginals(self) -> list: - return self._model.data.get("marginals", []) + def marginals(self) -> list[float]: + marginals: list[float] = self._model.data.get("marginals", []) + return marginals @property def constrained_to_indexsets(self) -> list[str]: @@ -76,21 +85,21 @@ def created_by(self) -> str | None: return self._model.created_by @property - def docs(self): + def docs(self) -> str | None: try: return self.backend.optimization.variables.docs.get(self.id).description except DocsModel.NotFound: return None @docs.setter - def docs(self, description): + def docs(self, description: str | None) -> None: if description is None: self.backend.optimization.variables.docs.delete(self.id) else: self.backend.optimization.variables.docs.set(self.id, description) @docs.deleter - def docs(self): + def docs(self) -> None: try: self.backend.optimization.variables.docs.delete(self.id) # TODO: silently failing @@ -104,8 +113,8 @@ def __str__(self) -> str: class VariableRepository(BaseFacade): _run: Run - def __init__(self, _run: Run, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, _run: Run, **kwargs: Unpack["InitKwargs"]) -> None: + super().__init__(**kwargs) self._run = _run def create( diff --git a/ixmp4/core/platform.py b/ixmp4/core/platform.py index 7362d580..438908e9 100644 --- a/ixmp4/core/platform.py +++ b/ixmp4/core/platform.py @@ -27,6 +27,7 @@ from ixmp4.conf import settings from ixmp4.conf.auth import BaseAuth +from ixmp4.conf.base import PlatformInfo from ixmp4.core.exceptions import PlatformNotFound from ixmp4.data.backend import Backend, RestBackend, SqlAlchemyBackend @@ -63,7 +64,7 @@ def __init__( ) -> None: if name is not None: if name in settings.toml.platforms: - config = settings.toml.get_platform(name) + config: PlatformInfo = settings.toml.get_platform(name) else: settings.check_credentials() if settings.manager is not None: @@ -71,10 +72,11 @@ def __init__( else: raise PlatformNotFound(f"Platform '{name}' was not found.") - if config.dsn.startswith("http"): - self.backend = RestBackend(config, auth=_auth) - else: - self.backend = SqlAlchemyBackend(config) # type: ignore + self.backend = ( + RestBackend(config, auth=_auth) + if config.dsn.startswith("http") + else SqlAlchemyBackend(config) + ) elif _backend is not None: self.backend = _backend else: diff --git a/ixmp4/core/region.py b/ixmp4/core/region.py index a3ea3145..ee145db2 100644 --- a/ixmp4/core/region.py +++ b/ixmp4/core/region.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Optional, Union import pandas as pd @@ -38,26 +37,26 @@ def created_at(self) -> datetime | None: def created_by(self) -> str | None: return self._model.created_by - def delete(self): + def delete(self) -> None: """Deletes the region from the database.""" self.backend.regions.delete(self._model.id) @property - def docs(self): + def docs(self) -> str | None: try: return self.backend.regions.docs.get(self.id).description except DocsModel.NotFound: return None @docs.setter - def docs(self, description): + def docs(self, description: str | None) -> None: if description is None: self.backend.regions.docs.delete(self.id) else: self.backend.regions.docs.set(self.id, description) @docs.deleter - def docs(self): + def docs(self) -> None: try: self.backend.regions.docs.delete(self.id) # TODO: silently failing @@ -69,7 +68,7 @@ def __str__(self) -> str: class RegionRepository(BaseFacade): - def _get_region_id(self, region: Optional[Union[str, int, "Region"]]) -> int | None: + def _get_region_id(self, region: str | int | Region | None) -> int | None: if region is None: return None elif isinstance(region, str): @@ -94,7 +93,7 @@ def get(self, name: str) -> Region: model = self.backend.regions.get(name) return Region(_backend=self.backend, _model=model) - def delete(self, x: Region | int | str): + def delete(self, x: Region | int | str) -> None: if isinstance(x, Region): id = x.id elif isinstance(x, int): diff --git a/ixmp4/core/run.py b/ixmp4/core/run.py index bdd8ec93..008e2812 100644 --- a/ixmp4/core/run.py +++ b/ixmp4/core/run.py @@ -1,16 +1,29 @@ from collections import UserDict -from typing import ClassVar +from typing import ClassVar, cast import numpy as np import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import TypedDict, Unpack + +from ixmp4.data.abstract import Model as ModelModel from ixmp4.data.abstract import Run as RunModel +from ixmp4.data.abstract import Scenario as ScenarioModel +from ixmp4.data.abstract.annotations import PrimitiveTypes +from ixmp4.data.abstract.run import EnumerateKwargs +from ixmp4.data.backend import Backend from .base import BaseFacade, BaseModelFacade from .iamc import RunIamcData from .optimization import OptimizationData +class RunKwargs(TypedDict): + _backend: Backend + _model: RunModel + + class Run(BaseModelFacade): _model: RunModel _meta: "RunMetaFacade" @@ -18,7 +31,7 @@ class Run(BaseModelFacade): NotFound: ClassVar = RunModel.NotFound NotUnique: ClassVar = RunModel.NotUnique - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Unpack[RunKwargs]) -> None: super().__init__(**kwargs) self.version = self._model.version @@ -28,35 +41,35 @@ def __init__(self, **kwargs) -> None: self.optimization = OptimizationData(_backend=self.backend, run=self._model) @property - def model(self): + def model(self) -> ModelModel: """Associated model.""" return self._model.model @property - def scenario(self): + def scenario(self) -> ScenarioModel: """Associated scenario.""" return self._model.scenario @property - def id(self): + def id(self) -> int: """Unique id.""" return self._model.id @property - def meta(self): + def meta(self) -> "RunMetaFacade": "Meta indicator data (`dict`-like)." return self._meta @meta.setter - def meta(self, meta): + def meta(self, meta: dict[str, PrimitiveTypes | np.generic | None]) -> None: self._meta._set(meta) - def set_as_default(self): + def set_as_default(self) -> None: """Sets this run as the default version for its `model` + `scenario` combination.""" self.backend.runs.set_as_default_version(self._model.id) - def unset_as_default(self): + def unset_as_default(self) -> None: """Unsets this run as the default version.""" self.backend.runs.unset_as_default_version(self._model.id) @@ -83,16 +96,16 @@ def get( _model = self.backend.runs.get(model, scenario, version) return Run(_backend=self.backend, _model=_model) - def list(self, default_only: bool = True, **kwargs) -> list[Run]: + def list(self, **kwargs: Unpack[EnumerateKwargs]) -> list[Run]: return [ Run(_backend=self.backend, _model=r) - for r in self.backend.runs.list(default_only=default_only, **kwargs) + for r in self.backend.runs.list(**kwargs) ] def tabulate( - self, default_only: bool = True, audit_info: bool = False, **kwargs + self, audit_info: bool = False, **kwargs: Unpack[EnumerateKwargs] ) -> pd.DataFrame: - runs = self.backend.runs.tabulate(default_only=default_only, **kwargs) + runs = self.backend.runs.tabulate(**kwargs) runs["model"] = runs["model__id"].map(self.backend.models.map()) runs["scenario"] = runs["scenario__id"].map(self.backend.scenarios.map()) columns = ["model", "scenario", "version", "is_default"] @@ -101,21 +114,21 @@ def tabulate( return runs[columns] -class RunMetaFacade(BaseFacade, UserDict): +class RunMetaFacade(BaseFacade, UserDict[str, PrimitiveTypes | None]): run: RunModel - def __init__(self, run: RunModel, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, run: RunModel, **kwargs: Backend) -> None: + super().__init__(**kwargs) self.run = run self.df, self.data = self._get() - def _get(self) -> tuple[pd.DataFrame, dict]: + def _get(self) -> tuple[pd.DataFrame, dict[str, PrimitiveTypes | None]]: df = self.backend.meta.tabulate(run_id=self.run.id, run={"default_only": False}) if df.empty: return df, {} return df, dict(zip(df["key"], df["value"])) - def _set(self, meta: dict): + def _set(self, meta: dict[str, PrimitiveTypes | np.generic | None]) -> None: df = pd.DataFrame({"key": self.data.keys()}) df["run__id"] = self.run.id self.backend.meta.bulk_delete(df) @@ -127,28 +140,30 @@ def _set(self, meta: dict): self.backend.meta.bulk_upsert(df) self.df, self.data = self._get() - def __setitem__(self, key, value: int | float | str | bool): + def __setitem__(self, key: str, value: PrimitiveTypes | np.generic | None) -> None: try: del self[key] except KeyError: pass - value = numpy_to_pytype(value) - if value is not None: - self.backend.meta.create(self.run.id, key, value) + py_value = numpy_to_pytype(value) + if py_value is not None: + self.backend.meta.create(self.run.id, key, py_value) self.df, self.data = self._get() - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: id = dict(zip(self.df["key"], self.df["id"]))[key] self.backend.meta.delete(id) self.df, self.data = self._get() -def numpy_to_pytype(value): +def numpy_to_pytype( + value: PrimitiveTypes | np.generic | None, +) -> PrimitiveTypes | None: """Cast numpy-types to basic Python types""" if value is np.nan: # np.nan is cast to 'float', not None return None elif isinstance(value, np.generic): - return value.item() + return cast(PrimitiveTypes, value.item()) else: return value diff --git a/ixmp4/core/scenario.py b/ixmp4/core/scenario.py index 1ed4bb8c..4ce254dd 100644 --- a/ixmp4/core/scenario.py +++ b/ixmp4/core/scenario.py @@ -30,21 +30,21 @@ def created_by(self) -> str | None: return self._model.created_by @property - def docs(self): + def docs(self) -> str | None: try: return self.backend.scenarios.docs.get(self.id).description except DocsModel.NotFound: return None @docs.setter - def docs(self, description): + def docs(self, description: str | None) -> None: if description is None: self.backend.scenarios.docs.delete(self.id) else: self.backend.scenarios.docs.set(self.id, description) @docs.deleter - def docs(self): + def docs(self) -> None: try: self.backend.scenarios.docs.delete(self.id) # TODO: silently failing @@ -75,9 +75,8 @@ def tabulate(self, name: str | None = None) -> pd.DataFrame: return self.backend.scenarios.tabulate(name=name) def _get_scenario_id(self, scenario: str) -> int | None: - if scenario is None: - return None - elif isinstance(scenario, str): + # NOTE leaving this check for users without mypy + if isinstance(scenario, str): obj = self.backend.scenarios.get(scenario) return obj.id else: diff --git a/ixmp4/core/unit.py b/ixmp4/core/unit.py index e95cace9..cd1266f5 100644 --- a/ixmp4/core/unit.py +++ b/ixmp4/core/unit.py @@ -29,25 +29,25 @@ def created_at(self) -> datetime | None: def created_by(self) -> str | None: return self._model.created_by - def delete(self): + def delete(self) -> None: self.backend.units.delete(self._model.id) @property - def docs(self): + def docs(self) -> str | None: try: return self.backend.units.docs.get(self.id).description except DocsModel.NotFound: return None @docs.setter - def docs(self, description): + def docs(self, description: str | None) -> None: if description is None: self.backend.units.docs.delete(self.id) else: self.backend.units.docs.set(self.id, description) @docs.deleter - def docs(self): + def docs(self) -> None: try: self.backend.units.docs.delete(self.id) # TODO: silently failing @@ -59,10 +59,7 @@ def __str__(self) -> str: class UnitRepository(BaseFacade): - def create( - self, - name: str, - ) -> Unit: + def create(self, name: str) -> Unit: if name != "" and name.strip() == "": raise ValueError("Using a space-only unit name is not allowed.") if name == "dimensionless": @@ -72,7 +69,7 @@ def create( model = self.backend.units.create(name) return Unit(_backend=self.backend, _model=model) - def delete(self, x: Unit | int | str): + def delete(self, x: Unit | int | str) -> None: if isinstance(x, Unit): id = x.id elif isinstance(x, int): @@ -97,9 +94,8 @@ def tabulate(self, name: str | None = None) -> pd.DataFrame: return self.backend.units.tabulate(name=name) def _get_unit_id(self, unit: str) -> int | None: - if unit is None: - return None - elif isinstance(unit, str): + # NOTE leaving this check for users without mypy + if isinstance(unit, str): obj = self.backend.units.get(unit) return obj.id else: diff --git a/ixmp4/core/utils.py b/ixmp4/core/utils.py index e4712258..468f1ba5 100644 --- a/ixmp4/core/utils.py +++ b/ixmp4/core/utils.py @@ -3,7 +3,7 @@ from ixmp4.data.abstract import DataPoint as DataPointModel -def substitute_type(df: pd.DataFrame, type: DataPointModel.Type | None = None): +def substitute_type(df: pd.DataFrame, type: DataPointModel.Type | None = None) -> None: if "type" not in df.columns: # `type` given explicitly if type is not None: diff --git a/ixmp4/data/abstract/__init__.py b/ixmp4/data/abstract/__init__.py index 13b1994e..adf7ef7d 100644 --- a/ixmp4/data/abstract/__init__.py +++ b/ixmp4/data/abstract/__init__.py @@ -3,6 +3,7 @@ between the database and api data models and repositories. """ +from .annotations import HasNameFilter from .base import ( BaseMeta, BaseModel, diff --git a/ixmp4/data/abstract/annotations.py b/ixmp4/data/abstract/annotations.py new file mode 100644 index 00000000..6f2cf48a --- /dev/null +++ b/ixmp4/data/abstract/annotations.py @@ -0,0 +1,144 @@ +from collections.abc import Iterable + +# TODO Use `type` when dropping Python 3.11 +from typing import TypeAlias + +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import TypedDict + +PrimitiveTypes: TypeAlias = bool | float | int | str +PrimitiveIterableTypes: TypeAlias = ( + Iterable[bool] | Iterable[float] | Iterable[int] | Iterable[str] +) + +IntFilterAlias: TypeAlias = int | Iterable[int] +StrFilterAlias: TypeAlias = str | Iterable[str] +DefaultFilterAlias: TypeAlias = IntFilterAlias | StrFilterAlias +OptimizationFilterAlias: TypeAlias = dict[str, DefaultFilterAlias | None] + +# NOTE If you want to be nitpicky, you could argue that `timeseries` have an additional +# `variable` filter, which is not clear from this Alias used for both. However, +# `variable` only adds more of the same types and we only use this for casting, so we +# are fine *for now*. +IamcFilterAlias: TypeAlias = dict[ + str, + bool | DefaultFilterAlias | dict[str, DefaultFilterAlias] | None, +] +IamcObjectFilterAlias: TypeAlias = dict[ + str, + DefaultFilterAlias + | dict[ + str, + dict[str, DefaultFilterAlias | IamcFilterAlias], + ] + | bool + | None, +] + + +class HasIdFilter(TypedDict, total=False): + id: int + id__in: Iterable[int] + + +class HasNameFilter(TypedDict, total=False): + name: str | None + name__in: Iterable[str] + name__like: str + name__ilike: str + name__notlike: str + name__notilike: str + + +class HasHierarchyFilter(TypedDict, total=False): + hierarchy: str | None + hierarchy__in: Iterable[str] + hierarchy__like: str + hierarchy__ilike: str + hierarchy__notlike: str + hierarchy__notilike: str + + +class HasRunIdFilter(TypedDict, total=False): + run_id: int | None + run_id__in: Iterable[int] + run_id__gt: int + run_id__lt: int + run_id__gte: int + run_id__lte: int + run__id: int | None + run__id__in: Iterable[int] + run__id__gt: int + run__id__lt: int + run__id__gte: int + run__id__lte: int + + +class HasUnitIdFilter(TypedDict, total=False): + unit_id: int | None + unit_id__in: Iterable[int] + unit_id__gt: int + unit_id__lt: int + unit_id__gte: int + unit_id__lte: int + unit__id: int | None + unit__id__in: Iterable[int] + unit__id__gt: int + unit__id__lt: int + unit__id__gte: int + unit__id__lte: int + + +class HasRegionFilter(HasHierarchyFilter, HasIdFilter, HasNameFilter, total=False): ... + + +class HasModelFilter(HasIdFilter, HasNameFilter, total=False): ... + + +class HasScenarioFilter(HasIdFilter, HasNameFilter, total=False): ... + + +class HasUnitFilter(HasIdFilter, HasNameFilter, total=False): ... + + +class HasVariableFilter(HasIdFilter, HasNameFilter, total=False): ... + + +class HasRunFilter(HasIdFilter, total=False): + version: int | None + default_only: bool + is_default: bool | None + model: HasModelFilter | None + scenario: HasScenarioFilter | None + + +class IamcScenarioFilter(TypedDict, total=False): + region: HasRegionFilter | None + variable: HasVariableFilter | None + unit: HasUnitFilter | None + run: HasRunFilter + + +class IamcUnitFilter(TypedDict, total=False): + region: HasRegionFilter | None + variable: HasVariableFilter | None + run: HasRunFilter + + +class IamcRunFilter(TypedDict, total=False): + region: HasRegionFilter | None + variable: HasVariableFilter | None + unit: HasUnitFilter | None + + +class IamcRegionFilter(TypedDict, total=False): + variable: HasVariableFilter | None + unit: HasUnitFilter | None + run: HasRunFilter + + +class IamcModelFilter(TypedDict, total=False): + region: HasRegionFilter | None + variable: HasVariableFilter | None + unit: HasUnitFilter | None + run: HasRunFilter diff --git a/ixmp4/data/abstract/base.py b/ixmp4/data/abstract/base.py index 6a4fe0b5..88f6263a 100644 --- a/ixmp4/data/abstract/base.py +++ b/ixmp4/data/abstract/base.py @@ -1,4 +1,10 @@ -from typing import ClassVar, Protocol, _ProtocolMeta +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, ClassVar, Protocol, _ProtocolMeta + +from typing_extensions import Unpack + +if TYPE_CHECKING: + from ixmp4.data.db.base import EnumerateKwargs import pandas as pd @@ -7,7 +13,9 @@ class BaseMeta(_ProtocolMeta): - def __init__(self, name, bases, namespace): + def __init__( + self, name: str, bases: tuple[type, ...], namespace: dict[str, Any] + ) -> None: super().__init__(name, bases, namespace) self.NotUnique = type( self.__name__ + "NotUnique", @@ -37,38 +45,38 @@ class BaseModel(Protocol, metaclass=BaseMeta): class BaseRepository(Protocol): - def __init__(self, *args, **kwargs) -> None: ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... class Retriever(BaseRepository, Protocol): - def get(self, *args, **kwargs) -> BaseModel: ... + def get(self, *args: Any, **kwargs: Any) -> BaseModel: ... class Creator(BaseRepository, Protocol): - def create(self, *args, **kwargs) -> BaseModel: ... + def create(self, *args: Any, **kwargs: Any) -> BaseModel: ... class Deleter(BaseRepository, Protocol): - def delete(self, *args, **kwargs) -> None: ... + def delete(self, *args: Any, **kwargs: Any) -> None: ... class Lister(BaseRepository, Protocol): - def list(self, *args, **kwargs) -> list: ... + def list(self, *args: Any, **kwargs: Any) -> Sequence[BaseModel]: ... class Tabulator(BaseRepository, Protocol): - def tabulate(self, *args, **kwargs) -> pd.DataFrame: ... + def tabulate(self, *args: Any, **kwargs: Any) -> pd.DataFrame: ... class Enumerator(Lister, Tabulator, Protocol): def enumerate( - self, *args, table: bool = False, **kwargs - ) -> list | pd.DataFrame: ... + self, table: bool = False, **kwargs: Unpack["EnumerateKwargs"] + ) -> Sequence[BaseModel] | pd.DataFrame: ... class BulkUpserter(BaseRepository, Protocol): - def bulk_upsert(self, *args, **kwargs) -> None: ... + def bulk_upsert(self, *args: Any, **kwargs: Any) -> None: ... class BulkDeleter(BaseRepository, Protocol): - def bulk_delete(self, *args, **kwargs) -> None: ... + def bulk_delete(self, *args: Any, **kwargs: Any) -> None: ... diff --git a/ixmp4/data/abstract/docs.py b/ixmp4/data/abstract/docs.py index 51ce2f43..2fc01a29 100644 --- a/ixmp4/data/abstract/docs.py +++ b/ixmp4/data/abstract/docs.py @@ -12,7 +12,7 @@ class Docs(base.BaseModel, Protocol): "Description of the dimension object." dimension__id: types.Integer "Foreign unique integer id of the object in the dimension's table." - dimension: types.Mapped + dimension: types.Mapped[base.BaseModel] "The documented object." # This doesn't work since each dimension has a different self.dimension object as @@ -23,13 +23,7 @@ class Docs(base.BaseModel, Protocol): # ) -# TODO: adjust all type hints once things work -class DocsRepository( - base.Retriever, - base.Deleter, - base.Enumerator, - Protocol, -): +class DocsRepository(base.Retriever, base.Deleter, base.Enumerator, Protocol): def get(self, dimension_id: int) -> Docs: """Retrieve the documentation of an object of any dimension. @@ -85,11 +79,7 @@ def delete(self, dimension_id: int) -> None: """ ... - def list( - self, - *, - dimension_id: int | None = None, - ) -> list[Docs]: + def list(self, *, dimension_id: int | None = None) -> list[Docs]: """Lists documentations. Parameters diff --git a/ixmp4/data/abstract/iamc/__init__.py b/ixmp4/data/abstract/iamc/__init__.py index fceea3d7..48e467c2 100644 --- a/ixmp4/data/abstract/iamc/__init__.py +++ b/ixmp4/data/abstract/iamc/__init__.py @@ -1,4 +1,3 @@ -# flake8: noqa from .datapoint import ( # AnnualDataPoint,; SubAnnualDataPoint,; CategoricalDataPoint, DataPoint, DataPointRepository, diff --git a/ixmp4/data/abstract/iamc/datapoint.py b/ixmp4/data/abstract/iamc/datapoint.py index db96b636..c193fe01 100644 --- a/ixmp4/data/abstract/iamc/datapoint.py +++ b/ixmp4/data/abstract/iamc/datapoint.py @@ -1,8 +1,12 @@ import enum +from collections.abc import Iterable from typing import Protocol import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import TypedDict, Unpack + from ixmp4.data import types from .. import base @@ -35,6 +39,39 @@ def __str__(self) -> str: return f"" +class EnumerateKwargs(TypedDict, total=False): + step_year: int | None + step_year__in: Iterable[int] + step_year__gt: int + step_year__lt: int + step_year__gte: int + step_year__lte: int + year: int | None + year__in: Iterable[int] + year__gt: int + year__lt: int + year__gte: int + year__lte: int + time_series_id: int | None + time_series_id__in: Iterable[int] + time_series_id__gt: int + time_series_id__lt: int + time_series_id__gte: int + time_series_id__lte: int + time_series__id: int | None + time_series__id__in: Iterable[int] + time_series__id__gt: int + time_series__id__lt: int + time_series__id__gte: int + time_series__id__lte: int + region: dict[str, str | Iterable[str]] | None + unit: dict[str, str | Iterable[str]] | None + variable: dict[str, str | Iterable[str]] | None + model: dict[str, str | Iterable[str]] | None + scenario: dict[str, str | Iterable[str]] | None + run: dict[str, bool | int | Iterable[int]] + + class DataPointRepository( base.Enumerator, base.BulkUpserter, @@ -45,7 +82,8 @@ def list( self, *, join_parameters: bool | None = False, - **kwargs, + join_runs: bool = False, + **kwargs: Unpack[EnumerateKwargs], ) -> list[DataPoint]: """Lists data points by specified criteria. This method incurrs mentionable overhead compared to :meth:`tabulate`. @@ -53,14 +91,17 @@ def list( Parameters ---------- join_parameters : bool | None - If set to `True` the resulting data frame will include parameter columns - from the associated :class:`ixmp4.data.base.TimeSeries`. + If set to `True` the resulting list will include parameter columns + from the associated :class:`ixmp4.data.abstract.iamc.timeseries.TimeSeries`. + join_runs : bool + If set to `True` the resulting list will include model & scenario name + and version id of the associated Run. kwargs : Any Additional key word arguments. Any left over kwargs will be used as filters. Returns ------- - Iterable[:class:`ixmp4.data.base.DataPoint`]: + Iterable[:class:`ixmp4.data.abstract.iamc.datapoint.DataPoint`]: List of data points. """ ... @@ -70,7 +111,7 @@ def tabulate( *, join_parameters: bool | None = False, join_runs: bool = False, - **kwargs, + **kwargs: Unpack[EnumerateKwargs], ) -> pd.DataFrame: """Tabulates data points by specified criteria. @@ -78,7 +119,7 @@ def tabulate( ---------- join_parameters : bool | None If set to `True` the resulting data frame will include parameter columns - from the associated :class:`ixmp4.data.abstract.TimeSeries`. + from the associated :class:`ixmp4.data.abstract.iamc.timeseries.TimeSeries`. join_runs : bool If set to `True` the resulting data frame will include model & scenario name and version id of the associated Run. @@ -103,7 +144,7 @@ def tabulate( """ ... - def bulk_upsert(self, df: pd.DataFrame, **kwargs) -> None: + def bulk_upsert(self, df: pd.DataFrame) -> None: """Looks which data points in the supplied data frame already exists, updates those that have changed and inserts new ones. @@ -123,7 +164,7 @@ def bulk_upsert(self, df: pd.DataFrame, **kwargs) -> None: """ ... - def bulk_delete(self, df: pd.DataFrame, **kwargs) -> None: + def bulk_delete(self, df: pd.DataFrame) -> None: """Deletes data points which match criteria in the supplied data frame. Parameters diff --git a/ixmp4/data/abstract/iamc/measurand.py b/ixmp4/data/abstract/iamc/measurand.py index 53b34139..8528df1f 100644 --- a/ixmp4/data/abstract/iamc/measurand.py +++ b/ixmp4/data/abstract/iamc/measurand.py @@ -1,4 +1,7 @@ -from typing import Protocol +from typing import TYPE_CHECKING, Protocol + +if TYPE_CHECKING: + from .. import Unit, Variable import pandas as pd @@ -12,12 +15,12 @@ class Measurand(base.BaseModel, Protocol): variable__id: types.Integer "Foreign unique integer id of a variable." - variable: types.Mapped + variable: types.Mapped["Variable"] "Associated variable." unit__id: types.Integer "Foreign unique integer id of a unit." - unit: types.Mapped + unit: types.Mapped["Unit"] "Associated unit." def __str__(self) -> str: @@ -40,10 +43,6 @@ def get_or_create(self, variable_name: str, unit__id: int) -> Measurand: except Measurand.NotFound: return self.create(variable_name, unit__id) - def list( - self, - ) -> list[Measurand]: ... + def list(self) -> list[Measurand]: ... - def tabulate( - self, - ) -> pd.DataFrame: ... + def tabulate(self) -> pd.DataFrame: ... diff --git a/ixmp4/data/abstract/iamc/timeseries.py b/ixmp4/data/abstract/iamc/timeseries.py index 265de96e..43ae8403 100644 --- a/ixmp4/data/abstract/iamc/timeseries.py +++ b/ixmp4/data/abstract/iamc/timeseries.py @@ -1,10 +1,24 @@ -from typing import Generic, Mapping, Protocol, TypeVar +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar + +if TYPE_CHECKING: + from ixmp4.data.db.timeseries import CreateKwargs import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import types from .. import base +from ..annotations import ( + HasIdFilter, + HasRegionFilter, + HasRunFilter, + HasUnitFilter, + HasVariableFilter, +) class TimeSeries(base.BaseModel, Protocol): @@ -12,7 +26,7 @@ class TimeSeries(base.BaseModel, Protocol): run__id: types.Integer "Unique run id." - parameters: Mapping + parameters: Mapping[str, Any] "A set of parameter values for the time series." def __str__(self) -> str: @@ -22,6 +36,13 @@ def __str__(self) -> str: ModelType = TypeVar("ModelType", bound=TimeSeries) +class EnumerateKwargs(HasIdFilter, total=False): + region: HasRegionFilter + unit: HasUnitFilter + variable: HasVariableFilter + run: HasRunFilter + + class TimeSeriesRepository( base.Creator, base.Retriever, @@ -30,35 +51,36 @@ class TimeSeriesRepository( Protocol, Generic[ModelType], ): - def create(self, run_id: int, parameters: Mapping) -> ModelType: + def create(self, **kwargs: Unpack["CreateKwargs"]) -> ModelType: """Retrieves a time series. Parameters ---------- - run_id : int + run__id : int Unique run id. - parameters : Mapping + parameters : Mapping[str, Any] A set of parameter values for the time series. Raises ------ - :class:`ixmp4.data.abstract.TimeSeries.NotUnique`. + :class:`ixmp4.data.abstract.iamc.timeseries.TimeSeries.NotUnique`. Returns ------- - :class:`ixmp4.data.base.TimeSeries`: + :class:`ixmp4.data.abstract.iamc.timeseries.TimeSeries`: The retrieved time series. """ ... - def get(self, run_id: int, parameters: Mapping) -> ModelType: + # NOTE this seems unused, so I'm guessing at the parameters type + def get(self, run_id: int, parameters: Mapping[str, Any]) -> ModelType: """Retrieves a time series. Parameters ---------- run_id : int Unique run id. - parameters : Mapping + parameters : Mapping[str, Any] A set of parameter values for the time series. Raises @@ -93,14 +115,14 @@ def get_by_id(self, id: int) -> ModelType: """ ... - def get_or_create(self, run_id: int, parameters: Mapping) -> ModelType: + def get_or_create(self, run_id: int, parameters: Mapping[str, Any]) -> ModelType: """Tries to retrieve a time series and creates it if it was not found. Parameters ---------- run_id : int Unique run id. - parameters : Mapping + parameters : Mapping[str, Any] A set of parameter values for the time series. Returns @@ -111,12 +133,9 @@ def get_or_create(self, run_id: int, parameters: Mapping) -> ModelType: try: return self.get(run_id, parameters) except TimeSeries.NotFound: - return self.create(run_id, parameters) + return self.create(run__id=run_id, parameters=parameters) - def list( - self, - **kwargs, - ) -> list[ModelType]: + def list(self, **kwargs: Unpack[EnumerateKwargs]) -> list[ModelType]: r"""Lists time series by specified criteria. Parameters @@ -133,7 +152,7 @@ def list( ... def tabulate( - self, *, join_parameters: bool | None = False, **kwargs + self, *, join_parameters: bool | None = False, **kwargs: Unpack[EnumerateKwargs] ) -> pd.DataFrame: r"""Tabulate time series by specified criteria. diff --git a/ixmp4/data/abstract/iamc/variable.py b/ixmp4/data/abstract/iamc/variable.py index e70181ae..d05fee48 100644 --- a/ixmp4/data/abstract/iamc/variable.py +++ b/ixmp4/data/abstract/iamc/variable.py @@ -2,9 +2,19 @@ import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import types from .. import base +from ..annotations import ( + HasIdFilter, + HasNameFilter, + HasRegionFilter, + HasRunFilter, + HasUnitFilter, +) from ..docs import DocsRepository @@ -23,6 +33,12 @@ def __str__(self) -> str: return f"" +class EnumerateKwargs(HasIdFilter, HasNameFilter, total=False): + region: HasRegionFilter + run: HasRunFilter + unit: HasUnitFilter + + class VariableRepository( base.Creator, base.Retriever, @@ -71,16 +87,14 @@ def get(self, name: str) -> Variable: """ ... - def list(self, *, name: str | None = None, **kwargs) -> list[Variable]: + def list(self, **kwargs: Unpack[EnumerateKwargs]) -> list[Variable]: r"""Lists variables by specified criteria. Parameters ---------- - name : str - The name of a variable. If supplied only one result will be returned. \*\*kwargs: any - More filter parameters as specified in - `ixmp4.data.db.iamc.variable.filters.VariableFilter`. + Any filter parameters as specified in + `ixmp4.data.db.iamc.variable.filter.VariableFilter`. Returns ------- @@ -89,16 +103,14 @@ def list(self, *, name: str | None = None, **kwargs) -> list[Variable]: """ ... - def tabulate(self, *, name: str | None = None, **kwargs) -> pd.DataFrame: + def tabulate(self, **kwargs: Unpack[EnumerateKwargs]) -> pd.DataFrame: r"""Tabulate variables by specified criteria. Parameters ---------- - name : str - The name of a variable. If supplied only one result will be returned. \*\*kwargs: any - More filter parameters as specified in - `ixmp4.data.db.iamc.variable.filters.VariableFilter`. + Any filter parameters as specified in + `ixmp4.data.db.iamc.variable.filter.VariableFilter`. Returns ------- diff --git a/ixmp4/data/abstract/meta.py b/ixmp4/data/abstract/meta.py index d10c2b35..7cb2a702 100644 --- a/ixmp4/data/abstract/meta.py +++ b/ixmp4/data/abstract/meta.py @@ -1,12 +1,17 @@ +from collections.abc import Iterable from enum import Enum from typing import ClassVar, Protocol import pandas as pd from pydantic import StrictBool, StrictFloat, StrictInt, StrictStr +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import types from . import base +from .annotations import HasIdFilter, HasRunFilter, HasRunIdFilter # as long as all of these are `Strict` the order does not matter StrictMetaValue = StrictBool | StrictInt | StrictFloat | StrictStr @@ -23,19 +28,19 @@ class Type(str, Enum): BOOL = "BOOL" @classmethod - def from_pytype(cls, type_): + def from_pytype(cls, type_: type) -> str: return RunMetaEntry._type_map[type_] run__id: types.Integer "Foreign unique integer id of a run." key: types.String "Key for the entry. Unique for each `run__id`." - type: types.String + dtype: types.String "Datatype of the entry's value." value: types.Integer | types.Float | types.Integer | types.Boolean "Value of the entry." - _type_map: ClassVar[dict] = { + _type_map: ClassVar[dict[type, str]] = { int: Type.INT, str: Type.STR, float: Type.FLOAT, @@ -47,6 +52,41 @@ def __str__(self) -> str: key={self.key} value={self.value}" +class EnumerateKwargs(HasIdFilter, HasRunIdFilter, total=False): + dtype: str + dtype__in: Iterable[str] + dtype__like: str + dtype__ilike: str + dtype__notlike: str + dtype__notilike: str + key: str + key__in: Iterable[str] + key__like: str + key__ilike: str + key__notlike: str + key__notilike: str + value_int: int + value_int_id__in: Iterable[int] + value_int_id__gt: int + value_int_id__lt: int + value_int_id__gte: int + value_int_id__lte: int + value_str: str + value_str__in: Iterable[str] + value_str__like: str + value_str__ilike: str + value_str__notlike: str + value_str__notilike: str + value_float: float + value_float_id__in: Iterable[float] + value_float_id__gt: float + value_float_id__lt: float + value_float_id__gte: float + value_float_id__lte: float + value_bool: bool + run: HasRunFilter + + class RunMetaEntryRepository( base.Creator, base.Retriever, @@ -56,12 +96,7 @@ class RunMetaEntryRepository( base.BulkDeleter, Protocol, ): - def create( - self, - run__id: int, - key: str, - value: MetaValue, - ) -> RunMetaEntry: + def create(self, run__id: int, key: str, value: MetaValue) -> RunMetaEntry: """Creates a meta indicator entry for a run. Parameters @@ -125,13 +160,14 @@ def delete(self, id: int) -> None: ... def list( - self, - **kwargs, + self, join_run_index: bool = False, **kwargs: Unpack[EnumerateKwargs] ) -> list[RunMetaEntry]: r"""Lists run's meta indicator entries by specified criteria. Parameters ---------- + join_run_index: bool, optional + Default `False`. \*\*kwargs: any Filter parameters as specified in `ixmp4.data.db.meta.filter.RunMetaEntryFilter`. @@ -144,14 +180,14 @@ def list( ... def tabulate( - self, - join_run_index: bool = False, - **kwargs, + self, join_run_index: bool = False, **kwargs: Unpack[EnumerateKwargs] ) -> pd.DataFrame: r"""Tabulates run's meta indicator entries by specified criteria. Parameters ---------- + join_run_index: bool, optional + Default `False`. \*\*kwargs: any Filter parameters as specified in `ixmp4.data.db.meta.filter.RunMetaEntryFilter`. diff --git a/ixmp4/data/abstract/model.py b/ixmp4/data/abstract/model.py index 5ae37908..653611a7 100644 --- a/ixmp4/data/abstract/model.py +++ b/ixmp4/data/abstract/model.py @@ -2,9 +2,13 @@ import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import types from . import base +from .annotations import HasIdFilter, HasNameFilter, IamcModelFilter from .docs import DocsRepository @@ -25,6 +29,10 @@ def __str__(self) -> str: return f"" +class EnumerateKwargs(HasIdFilter, HasNameFilter, total=False): + iamc: IamcModelFilter | bool + + class ModelRepository( base.Creator, base.Retriever, @@ -73,17 +81,14 @@ def get(self, name: str) -> Model: """ ... - def list( - self, - *, - name: str | None = None, - ) -> list[Model]: - """Lists models by specified criteria. + def list(self, **kwargs: Unpack[EnumerateKwargs]) -> list[Model]: + r"""Lists models by specified criteria. Parameters ---------- - name : str - The name of a model. If supplied only one result will be returned. + \*\*kwargs: any + Any filter parameters as specified in + `ixmp4.data.db.model.filter.ModelFilter`. Returns ------- @@ -92,18 +97,14 @@ def list( """ ... - def tabulate( - self, - *, - name: str | None = None, - **kwargs, - ) -> pd.DataFrame: - """Tabulate models by specified criteria. + def tabulate(self, **kwargs: Unpack[EnumerateKwargs]) -> pd.DataFrame: + r"""Tabulate models by specified criteria. Parameters ---------- - name : str - The name of a model. If supplied only one result will be returned. + \*\*kwargs: any + Any filter parameters as specified in + `ixmp4.data.db.model.filter.ModelFilter`. Returns ------- @@ -115,7 +116,7 @@ def tabulate( """ ... - def map(self, *args, **kwargs) -> dict: + def map(self, **kwargs: Unpack[EnumerateKwargs]) -> dict[int, str]: """Return a mapping of model-id to model-name. Returns @@ -123,4 +124,4 @@ def map(self, *args, **kwargs) -> dict: :class:`dict`: A dictionary `id` -> `name` """ - return dict([(m.id, m.name) for m in self.list(*args, **kwargs)]) + return dict([(m.id, m.name) for m in self.list(**kwargs)]) diff --git a/ixmp4/data/abstract/optimization/__init__.py b/ixmp4/data/abstract/optimization/__init__.py index 60603b1e..8d95e132 100644 --- a/ixmp4/data/abstract/optimization/__init__.py +++ b/ixmp4/data/abstract/optimization/__init__.py @@ -1,3 +1,6 @@ +from collections.abc import Iterable + +from ..annotations import HasIdFilter, HasNameFilter, HasRunIdFilter from .column import Column from .equation import Equation, EquationRepository from .indexset import IndexSet, IndexSetRepository @@ -5,3 +8,6 @@ from .scalar import Scalar, ScalarRepository from .table import Table, TableRepository from .variable import Variable, VariableRepository + + +class EnumerateKwargs(HasIdFilter, HasNameFilter, HasRunIdFilter, total=False): ... diff --git a/ixmp4/data/abstract/optimization/equation.py b/ixmp4/data/abstract/optimization/equation.py index 896534d6..a781cd3c 100644 --- a/ixmp4/data/abstract/optimization/equation.py +++ b/ixmp4/data/abstract/optimization/equation.py @@ -1,7 +1,14 @@ -from typing import Any, Iterable, Protocol +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Protocol + +if TYPE_CHECKING: + from . import EnumerateKwargs import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import types from .. import base @@ -124,16 +131,13 @@ def get_by_id(self, id: int) -> Equation: """ ... - def list(self, *, name: str | None = None, **kwargs) -> Iterable[Equation]: + def list(self, **kwargs: Unpack["EnumerateKwargs"]) -> Iterable[Equation]: r"""Lists Equations by specified criteria. Parameters ---------- - name : str - The name of an Equation. If supplied only one result will be returned. - # TODO: Update kwargs \*\*kwargs: any - More filter Equations as specified in + Any filter Equations as specified in `ixmp4.data.db.optimization.equation.filter.OptimizationEquationFilter`. Returns @@ -143,16 +147,13 @@ def list(self, *, name: str | None = None, **kwargs) -> Iterable[Equation]: """ ... - def tabulate(self, *, name: str | None = None, **kwargs) -> pd.DataFrame: + def tabulate(self, **kwargs: Unpack["EnumerateKwargs"]) -> pd.DataFrame: r"""Tabulate Equations by specified criteria. Parameters ---------- - name : str - The name of an Equation. If supplied only one result will be returned. - # TODO: Update kwargs \*\*kwargs: any - More filter variables as specified in + Any filter variables as specified in `ixmp4.data.db.optimization.equation.filter.OptimizationEquationFilter`. Returns diff --git a/ixmp4/data/abstract/optimization/indexset.py b/ixmp4/data/abstract/optimization/indexset.py index ae0c6e2d..82feea2d 100644 --- a/ixmp4/data/abstract/optimization/indexset.py +++ b/ixmp4/data/abstract/optimization/indexset.py @@ -1,7 +1,13 @@ -from typing import List, Protocol +from typing import TYPE_CHECKING, List, Protocol + +if TYPE_CHECKING: + from . import EnumerateKwargs import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import types from .. import base @@ -83,16 +89,13 @@ def get(self, run_id: int, name: str) -> IndexSet: """ ... - def list(self, *, name: str | None = None, **kwargs) -> list[IndexSet]: + def list(self, **kwargs: Unpack["EnumerateKwargs"]) -> list[IndexSet]: r"""Lists IndexSets by specified criteria. Parameters ---------- - name : str - The name of an IndexSet. If supplied only one result will be returned. - # TODO: Update kwargs \*\*kwargs: any - More filter parameters as specified in + Any filter parameters as specified in `ixmp4.data.db.optimization.indexset.filter.OptimizationIndexSetFilter`. Returns @@ -103,20 +106,17 @@ def list(self, *, name: str | None = None, **kwargs) -> list[IndexSet]: ... def tabulate( - self, *, name: str | None = None, include_data: bool = False, **kwargs + self, *, include_data: bool = False, **kwargs: Unpack["EnumerateKwargs"] ) -> pd.DataFrame: r"""Tabulate IndexSets by specified criteria. Parameters ---------- - name : str, optional - The name of an IndexSet. If supplied only one result will be returned. include_data : bool, optional Whether to load all IndexSet data, which reduces loading speed. Defaults to `False`. - # TODO: Update kwargs \*\*kwargs: any - More filter parameters as specified in + Any filter parameters as specified in `ixmp4.data.db.optimization.indexset.filter.OptimizationIndexSetFilter`. Returns @@ -133,7 +133,9 @@ def tabulate( ... def add_data( - self, indexset_id: int, data: float | int | List[float | int | str] | str + self, + indexset_id: int, + data: float | int | str | List[float] | List[int] | List[str], ) -> None: """Adds data to an existing IndexSet. @@ -141,7 +143,7 @@ def add_data( ---------- indexset_id : int The id of the target IndexSet. - data : float | int | List[float | int | str] | str + data : float | int | str | List[float] | List[int] | List[str] The data to be added to the IndexSet. Returns diff --git a/ixmp4/data/abstract/optimization/parameter.py b/ixmp4/data/abstract/optimization/parameter.py index 86fd4567..94fa7537 100644 --- a/ixmp4/data/abstract/optimization/parameter.py +++ b/ixmp4/data/abstract/optimization/parameter.py @@ -1,7 +1,14 @@ -from typing import Any, Iterable, Protocol +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Protocol + +if TYPE_CHECKING: + from . import EnumerateKwargs import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import types from .. import base @@ -124,16 +131,13 @@ def get_by_id(self, id: int) -> Parameter: """ ... - def list(self, *, name: str | None = None, **kwargs) -> Iterable[Parameter]: + def list(self, **kwargs: Unpack["EnumerateKwargs"]) -> Iterable[Parameter]: r"""Lists Parameters by specified criteria. Parameters ---------- - name : str - The name of a Parameter. If supplied only one result will be returned. - # TODO: Update kwargs \*\*kwargs: any - More filter parameters as specified in + Any filter parameters as specified in `ixmp4.data.db.optimization.parameter.filter.OptimizationParameterFilter`. Returns @@ -143,16 +147,13 @@ def list(self, *, name: str | None = None, **kwargs) -> Iterable[Parameter]: """ ... - def tabulate(self, *, name: str | None = None, **kwargs) -> pd.DataFrame: + def tabulate(self, **kwargs: Unpack["EnumerateKwargs"]) -> pd.DataFrame: r"""Tabulate Parameters by specified criteria. Parameters ---------- - name : str - The name of a Parameter. If supplied only one result will be returned. - # TODO: Update kwargs \*\*kwargs: any - More filter parameters as specified in + Any filter parameters as specified in `ixmp4.data.db.optimization.parameter.filter.OptimizationParameterFilter`. Returns diff --git a/ixmp4/data/abstract/optimization/scalar.py b/ixmp4/data/abstract/optimization/scalar.py index e332d168..aab3da88 100644 --- a/ixmp4/data/abstract/optimization/scalar.py +++ b/ixmp4/data/abstract/optimization/scalar.py @@ -1,11 +1,24 @@ -from typing import Iterable, Protocol +from collections.abc import Iterable +from typing import TYPE_CHECKING, Protocol + +from ..annotations import HasUnitIdFilter + +if TYPE_CHECKING: + from . import EnumerateKwargs as BaseEnumerateKwargs + + class EnumerateKwargs(BaseEnumerateKwargs, HasUnitIdFilter, total=False): ... + import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import types from .. import base from ..docs import DocsRepository +from ..unit import Unit class Scalar(base.BaseModel, Protocol): @@ -17,7 +30,7 @@ class Scalar(base.BaseModel, Protocol): """Value of the Scalar.""" unit__id: types.Integer "Foreign unique integer id of a unit." - unit: types.Mapped + unit: types.Mapped[Unit] "Associated unit." run__id: types.Integer "Foreign unique integer id of a run." @@ -132,16 +145,13 @@ def get_by_id(self, id: int) -> Scalar: """ ... - def list(self, *, name: str | None = None, **kwargs) -> Iterable[Scalar]: + def list(self, **kwargs: Unpack["EnumerateKwargs"]) -> Iterable[Scalar]: r"""Lists Scalars by specified criteria. Parameters ---------- - name : str - The name of a Scalar. If supplied only one result will be returned. - # TODO: Update kwargs \*\*kwargs: any - More filter parameters as specified in + Any filter parameters as specified in `ixmp4.data.db.optimization.scalar.filter.OptimizationScalarFilter`. Returns @@ -151,16 +161,13 @@ def list(self, *, name: str | None = None, **kwargs) -> Iterable[Scalar]: """ ... - def tabulate(self, *, name: str | None = None, **kwargs) -> pd.DataFrame: + def tabulate(self, **kwargs: Unpack["EnumerateKwargs"]) -> pd.DataFrame: r"""Tabulate Scalars by specified criteria. Parameters ---------- - name : str - The name of a Scalar. If supplied only one result will be returned. - # TODO: Update kwargs \*\*kwargs: any - More filter parameters as specified in + Any filter parameters as specified in `ixmp4.data.db.optimization.scalar.filter.OptimizationScalarFilter`. Returns diff --git a/ixmp4/data/abstract/optimization/table.py b/ixmp4/data/abstract/optimization/table.py index 067a252c..7c2a71e3 100644 --- a/ixmp4/data/abstract/optimization/table.py +++ b/ixmp4/data/abstract/optimization/table.py @@ -1,7 +1,14 @@ -from typing import Any, Iterable, Protocol +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Protocol + +if TYPE_CHECKING: + from . import EnumerateKwargs import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import types from .. import base @@ -124,16 +131,13 @@ def get_by_id(self, id: int) -> Table: """ ... - def list(self, *, name: str | None = None, **kwargs) -> Iterable[Table]: + def list(self, **kwargs: Unpack["EnumerateKwargs"]) -> Iterable[Table]: r"""Lists Tables by specified criteria. Parameters ---------- - name : str - The name of a Table. If supplied only one result will be returned. - # TODO: Update kwargs \*\*kwargs: any - More filter parameters as specified in + Any filter parameters as specified in `ixmp4.data.db.optimization.table.filter.OptimizationTableFilter`. Returns @@ -143,16 +147,13 @@ def list(self, *, name: str | None = None, **kwargs) -> Iterable[Table]: """ ... - def tabulate(self, *, name: str | None = None, **kwargs) -> pd.DataFrame: + def tabulate(self, **kwargs: Unpack["EnumerateKwargs"]) -> pd.DataFrame: r"""Tabulate Tables by specified criteria. Parameters ---------- - name : str - The name of a Table. If supplied only one result will be returned. - # TODO: Update kwargs \*\*kwargs: any - More filter parameters as specified in + Any filter parameters as specified in `ixmp4.data.db.optimization.table.filter.OptimizationTableFilter`. Returns diff --git a/ixmp4/data/abstract/optimization/variable.py b/ixmp4/data/abstract/optimization/variable.py index cae81378..39e86aaf 100644 --- a/ixmp4/data/abstract/optimization/variable.py +++ b/ixmp4/data/abstract/optimization/variable.py @@ -1,7 +1,13 @@ -from typing import Any, Iterable, Protocol +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Protocol +if TYPE_CHECKING: + from . import EnumerateKwargs import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import types from .. import base @@ -125,16 +131,13 @@ def get_by_id(self, id: int) -> Variable: """ ... - def list(self, *, name: str | None = None, **kwargs) -> Iterable[Variable]: + def list(self, **kwargs: Unpack["EnumerateKwargs"]) -> Iterable[Variable]: r"""Lists Variables by specified criteria. Parameters ---------- - name : str - The name of a Variable. If supplied only one result will be returned. - # TODO: Update kwargs \*\*kwargs: any - More filter Variables as specified in + Any filter Variables as specified in `ixmp4.data.db.optimization.variable.filter.OptimizationVariableFilter`. Returns @@ -144,16 +147,13 @@ def list(self, *, name: str | None = None, **kwargs) -> Iterable[Variable]: """ ... - def tabulate(self, *, name: str | None = None, **kwargs) -> pd.DataFrame: + def tabulate(self, **kwargs: Unpack["EnumerateKwargs"]) -> pd.DataFrame: r"""Tabulate Variables by specified criteria. Parameters ---------- - name : str - The name of a Variable. If supplied only one result will be returned. - # TODO: Update kwargs \*\*kwargs: any - More filter variables as specified in + Any filter variables as specified in `ixmp4.data.db.optimization.variable.filter.OptimizationVariableFilter`. Returns diff --git a/ixmp4/data/abstract/region.py b/ixmp4/data/abstract/region.py index c135d8a5..d474ca47 100644 --- a/ixmp4/data/abstract/region.py +++ b/ixmp4/data/abstract/region.py @@ -2,9 +2,18 @@ import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import types from . import base +from .annotations import ( + HasHierarchyFilter, + HasIdFilter, + HasNameFilter, + IamcRegionFilter, +) from .docs import DocsRepository @@ -25,6 +34,10 @@ def __str__(self) -> str: return f"" +class EnumerateKwargs(HasHierarchyFilter, HasIdFilter, HasNameFilter, total=False): + iamc: IamcRegionFilter | bool | None + + class RegionRepository( base.Creator, base.Deleter, @@ -57,7 +70,7 @@ def create(self, name: str, hierarchy: str) -> Region: """ ... - def delete(self, id: int): + def delete(self, id: int) -> None: """Deletes a region. Parameters @@ -94,11 +107,7 @@ def get(self, name: str) -> Region: """ ... - def get_or_create( - self, - name: str, - hierarchy: str | None = None, - ) -> Region: + def get_or_create(self, name: str, hierarchy: str | None = None) -> Region: try: region = self.get(name) except Region.NotFound: @@ -114,24 +123,14 @@ def get_or_create( else: return region - def list( - self, - *, - name: str | None = None, - hierarchy: str | None = None, - **kwargs, - ) -> list[Region]: + def list(self, **kwargs: Unpack[EnumerateKwargs]) -> list[Region]: r"""Lists regions by specified criteria. Parameters ---------- - name : str - The name of a region. If supplied only one result will be returned. - hierarchy : str - The hierarchy of a region. \*\*kwargs: any - More filter parameters as specified in - `ixmp4.data.db.region.filters.RegionFilter`. + Any filter parameters as specified in + `ixmp4.data.db.region.filter.RegionFilter`. Returns ------- @@ -140,24 +139,14 @@ def list( """ ... - def tabulate( - self, - *, - name: str | None = None, - hierarchy: str | None = None, - **kwargs, - ) -> pd.DataFrame: + def tabulate(self, **kwargs: Unpack[EnumerateKwargs]) -> pd.DataFrame: r"""Tabulate regions by specified criteria. Parameters ---------- - name : str - The name of a region. If supplied only one result will be returned. - hierarchy : str - The hierarchy of a region. \*\*kwargs: any - More filter parameters as specified in - `ixmp4.data.db.region.filters.RegionFilter`. + Any filter parameters as specified in + `ixmp4.data.db.region.filter.RegionFilter`. Returns ------- diff --git a/ixmp4/data/abstract/run.py b/ixmp4/data/abstract/run.py index f5a5f841..e528f5a5 100644 --- a/ixmp4/data/abstract/run.py +++ b/ixmp4/data/abstract/run.py @@ -1,11 +1,18 @@ -from typing import ClassVar, Protocol +from typing import TYPE_CHECKING, ClassVar, Protocol import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.core.exceptions import IxmpError, NoDefaultRunVersion from ixmp4.data import types from . import base +from .annotations import HasRunFilter, IamcRunFilter + +if TYPE_CHECKING: + from . import Model, Scenario class Run(base.BaseModel, Protocol): @@ -15,12 +22,12 @@ class Run(base.BaseModel, Protocol): model__id: types.Integer "Foreign unique integer id of the model." - model: types.Mapped + model: types.Mapped["Model"] "Associated model." scenario__id: types.Integer "Foreign unique integer id of the scenario." - scenario: types.Mapped + scenario: types.Mapped["Scenario"] "Associated scenario." version: types.Integer @@ -34,17 +41,17 @@ def __str__(self) -> str: is_default={self.is_default}>" +class EnumerateKwargs(HasRunFilter, total=False): + iamc: IamcRunFilter | bool | None + + class RunRepository( base.Creator, base.Retriever, base.Enumerator, Protocol, ): - def create( - self, - model_name: str, - scenario_name: str, - ) -> Run: + def create(self, model_name: str, scenario_name: str) -> Run: """Creates a run with an incremented version number or version=1 if no versions exist. Will automatically create the models and scenarios if they don't exist yet. @@ -63,12 +70,7 @@ def create( """ ... - def get( - self, - model_name: str, - scenario_name: str, - version: int, - ) -> Run: + def get(self, model_name: str, scenario_name: str, version: int) -> Run: """Retrieves a run. Parameters @@ -92,11 +94,7 @@ def get( """ ... - def get_or_create( - self, - model_name: str, - scenario_name: str, - ) -> Run: + def get_or_create(self, model_name: str, scenario_name: str) -> Run: """Tries to retrieve a run's default version and creates it if it was not found. @@ -117,11 +115,7 @@ def get_or_create( except Run.NoDefaultVersion: return self.create(model_name, scenario_name) - def get_default_version( - self, - model_name: str, - scenario_name: str, - ) -> Run: + def get_default_version(self, model_name: str, scenario_name: str) -> Run: """Retrieves a run's default version. Parameters @@ -143,23 +137,14 @@ def get_default_version( """ ... - def list( - self, - *, - version: int | None = None, - **kwargs, - ) -> list[Run]: + def list(self, **kwargs: Unpack[EnumerateKwargs]) -> list[Run]: r"""Lists runs by specified criteria. Parameters ---------- - version : int - The run's version. - default_only : bool - True by default. This function will return default runs only if true. \*\*kwargs: any - More filter parameters as specified in - `ixmp4.data.db.run.filters.RunFilter`. + Any filter parameters as specified in + `ixmp4.data.db.run.filter.RunFilter`. Returns ------- @@ -168,23 +153,14 @@ def list( """ ... - def tabulate( - self, - *, - version: int | None = None, - **kwargs, - ) -> pd.DataFrame: + def tabulate(self, **kwargs: Unpack[EnumerateKwargs]) -> pd.DataFrame: r"""Tabulate runs by specified criteria. Parameters ---------- - version : int - The run's version. - default_only : bool - True by default. This function will return default runs only if true. \*\*kwargs: any - More filter parameters as specified in - `ixmp4.data.db.run.filters.RunFilter`. + Any filter parameters as specified in + `ixmp4.data.db.run.filter.RunFilter`. Returns ------- diff --git a/ixmp4/data/abstract/scenario.py b/ixmp4/data/abstract/scenario.py index a1ab8dd2..341ebf23 100644 --- a/ixmp4/data/abstract/scenario.py +++ b/ixmp4/data/abstract/scenario.py @@ -2,9 +2,13 @@ import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import types from . import base +from .annotations import HasIdFilter, HasNameFilter, IamcScenarioFilter from .docs import DocsRepository @@ -23,6 +27,10 @@ def __str__(self) -> str: return f"" +class EnumerateKwargs(HasIdFilter, HasNameFilter, total=False): + iamc: IamcScenarioFilter | bool + + class ScenarioRepository(base.Creator, base.Retriever, base.Enumerator, Protocol): docs: DocsRepository @@ -66,17 +74,14 @@ def get(self, name: str) -> Scenario: """ ... - def list( - self, - *, - name: str | None = None, - ) -> list[Scenario]: - """Lists scenarios by specified criteria. + def list(self, **kwargs: Unpack[EnumerateKwargs]) -> list[Scenario]: + r"""Lists scenarios by specified criteria. Parameters ---------- - name : str - The name of a scenario. If supplied only one result will be returned. + \*\*kwargs: any + Any filter parameters as specified in + `ixmp4.data.db.scenario.filter.ScenarioFilter`. Returns ------- @@ -85,13 +90,14 @@ def list( """ ... - def tabulate(self, *, name: str | None = None, **kwargs) -> pd.DataFrame: - """Tabulate scenarios by specified criteria. + def tabulate(self, **kwargs: Unpack[EnumerateKwargs]) -> pd.DataFrame: + r"""Tabulate scenarios by specified criteria. Parameters ---------- - name : str - The name of a scenario. If supplied only one result will be returned. + \*\*kwargs: any + Any filter parameters as specified in + `ixmp4.data.db.scenario.filter.ScenarioFilter`. Returns ------- @@ -102,7 +108,7 @@ def tabulate(self, *, name: str | None = None, **kwargs) -> pd.DataFrame: """ ... - def map(self, *args, **kwargs) -> dict: + def map(self, **kwargs: Unpack[EnumerateKwargs]) -> dict[int, str]: """Return a mapping of scenario-id to scenario-name. Returns @@ -110,4 +116,4 @@ def map(self, *args, **kwargs) -> dict: :class:`dict` A dictionary `id` -> `name` """ - return dict([(s.id, s.name) for s in self.list(*args, **kwargs)]) + return dict([(s.id, s.name) for s in self.list(**kwargs)]) diff --git a/ixmp4/data/abstract/unit.py b/ixmp4/data/abstract/unit.py index b1a28fb8..27e2e561 100644 --- a/ixmp4/data/abstract/unit.py +++ b/ixmp4/data/abstract/unit.py @@ -2,9 +2,13 @@ import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import types from . import base +from .annotations import HasIdFilter, HasNameFilter, IamcUnitFilter from .docs import DocsRepository @@ -22,6 +26,10 @@ def __str__(self) -> str: return f"" +class EnumerateKwargs(HasIdFilter, HasNameFilter, total=False): + iamc: IamcUnitFilter | bool + + class UnitRepository( base.Creator, base.Deleter, @@ -97,7 +105,7 @@ def get_or_create(self, name: str) -> Unit: except Unit.NotFound: return self.create(name) - def delete(self, id: int): + def delete(self, id: int) -> None: """Deletes a unit. Parameters @@ -114,21 +122,14 @@ def delete(self, id: int): """ ... - def list( - self, - *, - name: str | None = None, - **kwargs, - ) -> list[Unit]: + def list(self, **kwargs: Unpack[EnumerateKwargs]) -> list[Unit]: r"""Lists units by specified criteria. Parameters ---------- - name : str - The name of a unit. If supplied only one result will be returned. \*\*kwargs: any - More filter parameters as specified in - `ixmp4.data.db.unit.filters.UnitFilter`. + Any filter parameters as specified in + `ixmp4.data.db.unit.filter.UnitFilter`. Returns ------- @@ -137,21 +138,14 @@ def list( """ ... - def tabulate( - self, - *, - name: str | None = None, - **kwargs, - ) -> pd.DataFrame: + def tabulate(self, **kwargs: Unpack[EnumerateKwargs]) -> pd.DataFrame: r"""Tabulate units by specified criteria. Parameters ---------- - name : str - The name of a unit. If supplied only one result will be returned. \*\*kwargs: any - More filter parameters as specified in - `ixmp4.data.db.unit.filters.UnitFilter`. + Any filter parameters as specified in + `ixmp4.data.db.unit.filter.UnitFilter`. Returns ------- diff --git a/ixmp4/data/api/base.py b/ixmp4/data/api/base.py index fa1403cd..5978fc2f 100644 --- a/ixmp4/data/api/base.py +++ b/ixmp4/data/api/base.py @@ -1,16 +1,12 @@ import logging import time +from collections.abc import Callable, Generator, Iterable, Mapping from concurrent import futures +from datetime import datetime from json.decoder import JSONDecodeError -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Generic, - Type, - TypeVar, -) + +# TODO Use `type` instead of TypeAlias when dropping Python 3.11 +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypeVar, cast import httpx import pandas as pd @@ -18,6 +14,9 @@ from pydantic import ConfigDict, Field, GetCoreSchemaHandler from pydantic_core import CoreSchema, core_schema +# TODO Import this from typing when dropping support for Python 3.11 +from typing_extensions import TypedDict, Unpack + from ixmp4.conf import settings from ixmp4.core.exceptions import ( ApiEncumbered, @@ -26,12 +25,33 @@ UnknownApiError, registry, ) +from ixmp4.data import abstract if TYPE_CHECKING: from ixmp4.data.backend.api import RestBackend logger = logging.getLogger(__name__) +JsonType: TypeAlias = Mapping[ + str, + Iterable[float] + | Iterable[int] + | Iterable[str] + | Mapping[str, Any] + | abstract.annotations.PrimitiveTypes + | None, +] +ParamType: TypeAlias = dict[ + str, bool | int | str | list[int] | Mapping[str, Any] | None +] +_RequestParamType: TypeAlias = Mapping[ + str, + abstract.annotations.PrimitiveTypes + | abstract.annotations.PrimitiveIterableTypes + | Mapping[str, Any] + | None, +] + class BaseModel(PydanticBaseModel): NotFound: ClassVar[type[IxmpError]] @@ -40,26 +60,65 @@ class BaseModel(PydanticBaseModel): model_config = ConfigDict(from_attributes=True) -def df_to_dict(df: pd.DataFrame) -> dict: +class DataFrameDict(TypedDict): + index: list[int] | list[str] + columns: list[str] + dtypes: list[str] + # NOTE This is deliberately slightly out of sync with DataFrame.data below + # (cf positioning of int and float) to demonstrate that only the one below seems to + # affect our tests by causing ValidationErrors + data: list[ + list[ + abstract.annotations.PrimitiveTypes + | datetime + | dict[str, Any] + | list[float] + | list[int] + | list[str] + | None + ] + ] + + +def df_to_dict(df: pd.DataFrame) -> DataFrameDict: columns = [] dtypes = [] for c in df.columns: columns.append(c) dtypes.append(df[c].dtype.name) - return { - "index": df.index.to_list(), - "columns": columns, - "dtypes": dtypes, - "data": df.values.tolist(), - } + return DataFrameDict( + index=df.index.to_list(), + columns=columns, + dtypes=dtypes, + data=df.values.tolist(), + ) class DataFrame(PydanticBaseModel): - index: list | None = Field(None) + index: list[int] | list[str] | None = Field(None) columns: list[str] | None dtypes: list[str] | None - data: list | None + # TODO The order is important here at the moment, in particular having int before + # float! This should likely not be the case, but using StrictInt and StrictFloat + # from pydantic only created even more errors. + data: ( + list[ + list[ + bool + | datetime + | int + | float + | str + | dict[str, Any] + | list[float] + | list[int] + | list[str] + | None + ] + ] + | None + ) model_config = ConfigDict(json_encoders={pd.Timestamp: lambda x: x.isoformat()}) @@ -68,14 +127,10 @@ def __get_pydantic_core_schema__( cls, source_type: Any, handler: GetCoreSchemaHandler ) -> CoreSchema: return core_schema.no_info_before_validator_function(cls.validate, handler(cls)) - # yield cls.validate @classmethod - def validate(cls, df: pd.DataFrame | dict): - if isinstance(df, pd.DataFrame): - return df_to_dict(df) - else: - return df + def validate(cls, df: pd.DataFrame | DataFrameDict) -> DataFrameDict: + return df_to_dict(df) if isinstance(df, pd.DataFrame) else df def to_pandas(self) -> pd.DataFrame: df = pd.DataFrame( @@ -86,27 +141,37 @@ def to_pandas(self) -> pd.DataFrame: if self.columns and self.dtypes: for c, dt in zip(self.columns, self.dtypes): # there seems to be a type incompatbility between StrDtypeArg and str - df[c] = df[c].astype(dt) # type: ignore + df[c] = df[c].astype(dt) # type: ignore[call-overload] return df +class _RequestKwargs(TypedDict, total=False): + params: _RequestParamType | None + json: JsonType | None + max_retries: int + + +class RequestKwargs(TypedDict, total=False): + content: str + + ModelType = TypeVar("ModelType", bound=BaseModel) class BaseRepository(Generic[ModelType]): - model_class: Type[ModelType] + model_class: type[ModelType] prefix: ClassVar[str] enumeration_method: str = "PATCH" backend: "RestBackend" - def __init__(self, backend: "RestBackend", *args, **kwargs) -> None: + def __init__(self, backend: "RestBackend") -> None: self.backend = backend - def sanitize_params(self, params: dict): + def sanitize_params(self, params: Mapping[str, Any]) -> dict[str, Any]: return {k: params[k] for k in params if params[k] is not None} - def get_remote_exception(self, res: httpx.Response, status_code: int): + def get_remote_exception(self, res: httpx.Response, status_code: int) -> IxmpError: try: json = res.json() except (ValueError, JSONDecodeError): @@ -130,11 +195,11 @@ def _request( self, method: str, path: str, - params: dict | None = None, - json: dict | None = None, + params: _RequestParamType | None = None, + json: JsonType | None = None, max_retries: int = settings.client_max_request_retries, - **kwargs, - ) -> dict | list | None: + **kwargs: Unpack[RequestKwargs], + ) -> dict[str, Any] | list[Any] | None: """Sends a request and handles potential error responses. Re-raises a remote `IxmpError` if thrown and transferred from the backend. Handles read timeouts and rate limiting responses via retries with backoffs. @@ -142,7 +207,7 @@ def _request( but has a status code less than 300. """ - def retry(max_retries=max_retries) -> dict | list | None: + def retry(max_retries: int = max_retries) -> dict[str, Any] | list[Any] | None: if max_retries == 0: logger.error(f"API Encumbered: '{self.backend.info.dsn}'") raise ApiEncumbered( @@ -186,8 +251,8 @@ def retry(max_retries=max_retries) -> dict | list | None: def _handle_response( self, res: httpx.Response, - retry: Callable[..., dict | list | None], - ) -> dict | list | None: + retry: Callable[..., dict[str, Any] | list[Any] | None], + ) -> dict[str, Any] | list[Any] | None: if res.status_code in [ 429, # Too Many Requests 420, # Enhance Your Calm @@ -203,7 +268,9 @@ def _handle_response( raise self.get_remote_exception(res, res.status_code) else: try: - return res.json() + # res.json just returns Any... + json_decoded: dict[str, Any] | list[Any] = res.json() + return json_decoded except JSONDecodeError: if res.status_code < 300 and res.text == "": return None @@ -211,46 +278,48 @@ def _handle_response( pass raise UnknownApiError(res.text) - def _get_by_id(self, id: int, *args, **kwargs) -> dict[str, Any]: + def _get_by_id(self, id: int) -> dict[str, Any]: # we can assume this type on create endpoints - return self._request("GET", self.prefix + str(id) + "/", **kwargs) # type: ignore + return self._request("GET", self.prefix + str(id) + "/") # type: ignore[return-value] def _request_enumeration( self, table: bool = False, - params: dict | None = None, - json: dict | None = None, - ): + params: ParamType | None = None, + json: JsonType | None = None, + ) -> dict[str, Any] | list[Any]: """Convenience method for requests to the enumeration endpoint.""" if params is None: params = {} + # See https://github.com/iiasa/ixmp4/pull/129#discussion_r1841829519 for why we + # are keeping this assumption + # we can assume these types on enumeration endpoints return self._request( self.enumeration_method, self.prefix, params={**params, "table": table}, json=json, - ) + ) # type: ignore[return-value] def _dispatch_pagination_requests( self, total: int, start: int, limit: int, - params: dict | None, - json: dict | None, - ) -> list[list | dict]: + params: ParamType | None, + json: JsonType | None, + ) -> list[list[Any]] | list[dict[str, Any]]: """Uses the backends executor to send many pagination requests concurrently.""" - requests: list[futures.Future] = [] + requests: list[futures.Future[dict[str, Any]]] = [] for req_offset in range(start, total, limit): - if params is not None: - req_params = params.copy() - else: - req_params = {} + req_params = params.copy() if params is not None else {} req_params.update({"limit": limit, "offset": req_offset}) - futu = self.backend.executor.submit( - self._request, + # Based on usage below, we seem to rely on self._request always returning a + # dict[str, Any] here + futu: futures.Future[dict[str, Any]] = self.backend.executor.submit( + self._request, # type: ignore [arg-type] self.enumeration_method, self.prefix, params=req_params, @@ -259,15 +328,16 @@ def _dispatch_pagination_requests( requests.append(futu) results = futures.wait(requests) responses = [f.result() for f in results.done] + # This seems to imply that type(responses) == list[dict[str, Any]] return [r.pop("results") for r in responses] def _handle_pagination( self, - data: dict, + data: dict[str, Any], table: bool = False, - params: dict | None = None, - json: dict | None = None, - ) -> list[list] | list[dict]: + params: ParamType | None = None, + json: JsonType | None = None, + ) -> list[list[Any]] | list[dict[str, Any]]: """Handles paginated response and sends subsequent requests if necessary. Returns aggregated pages as a list.""" @@ -289,51 +359,81 @@ def _handle_pagination( return [data.pop("results")] + results def _list( - self, params: dict | None = None, json: dict | None = None, **kwargs + self, + params: ParamType | None = None, + json: JsonType | None = None, ) -> list[ModelType]: data = self._request_enumeration(params=params, table=False, json=json) if isinstance(data, dict): # we can assume this type on list endpoints - pages: list[list] = self._handle_pagination( + pages: list[list[Any]] = self._handle_pagination( data, table=False, params=params, json=json - ) # type: ignore + ) # type: ignore[assignment] results = [i for page in pages for i in page] else: results = data return [self.model_class(**i) for i in results] def _tabulate( - self, params: dict | None = {}, json: dict | None = None, **kwargs + self, + params: ParamType | None = {}, + json: JsonType | None = None, ) -> pd.DataFrame: - data = self._request_enumeration(table=True, params=params, json=json) + # we can assume this type on table endpoints + data: dict[str, Any] = self._request_enumeration( + table=True, params=params, json=json + ) # type: ignore[assignment] pagination = data.get("pagination", None) if pagination is not None: # we can assume this type on table endpoints - pages: list[dict] = self._handle_pagination( + pages: list[dict[str, Any]] = self._handle_pagination( data, table=True, params=params, json=json, - ) # type: ignore + ) # type: ignore[assignment] dfs = [DataFrame(**page).to_pandas() for page in pages] return pd.concat(dfs) else: return DataFrame(**data).to_pandas() - def _create(self, *args, **kwargs) -> dict[str, Any]: + def _create( + self, + *args: Unpack[tuple[str]], + **kwargs: Unpack[_RequestKwargs], + ) -> dict[str, Any]: # we can assume this type on create endpoints - return self._request("POST", *args, **kwargs) # type: ignore + return self._request("POST", *args, **kwargs) # type: ignore[return-value] - def _delete(self, id: int): + def _delete(self, id: int) -> None: self._request("DELETE", f"{self.prefix}{str(id)}/") +class GetKwargs(TypedDict, total=False): + dimension_id: int + run_ids: list[int] + parameters: Mapping[str, Any] + name: str + run_id: int + key: str + model: dict[str, str] + scenario: dict[str, str] + version: int + default_only: bool + is_default: bool | None + + class Retriever(BaseRepository[ModelType]): - def get(self, **kwargs) -> ModelType: - if self.enumeration_method == "GET": - list_ = self._list(params=kwargs) - else: - list_ = self._list(json=kwargs) + def get(self, **kwargs: Unpack[GetKwargs]) -> ModelType: + _kwargs = cast( + dict[str, bool | int | str | list[int] | Mapping[str, Any] | None], + kwargs, + ) + list_ = ( + self._list(params=_kwargs) + if self.enumeration_method == "GET" + else self._list(json=_kwargs) + ) try: [obj] = list_ @@ -345,7 +445,16 @@ def get(self, **kwargs) -> ModelType: class Creator(BaseRepository[ModelType]): - def create(self, **kwargs) -> ModelType: + def create( + self, + **kwargs: int + | str + | Mapping[str, Any] + | abstract.MetaValue + | list[str] + | float + | None, + ) -> ModelType: res = self._create( self.prefix, json=kwargs, @@ -354,53 +463,63 @@ def create(self, **kwargs) -> ModelType: class Deleter(BaseRepository[ModelType]): - def delete(self, id: int): + def delete(self, id: int) -> None: self._delete(id) +class ListKwargs(TypedDict, total=False): + run_id: int + name: str + + class Lister(BaseRepository[ModelType]): - def list(self, *args, **kwargs) -> list[ModelType]: - return self._list(json=kwargs) + def list(self, **kwargs: Unpack[ListKwargs]) -> list[ModelType]: + return self._list(json=kwargs) # type: ignore[arg-type] class Tabulator(BaseRepository[ModelType]): - def tabulate(self, *args, **kwargs) -> pd.DataFrame: + def tabulate(self, **kwargs: Any) -> pd.DataFrame: return self._tabulate(json=kwargs) class Enumerator(Lister[ModelType], Tabulator[ModelType]): def enumerate( - self, *args, table: bool = False, **kwargs + self, table: bool = False, **kwargs: Any ) -> list[ModelType] | pd.DataFrame: - if table: - return self.tabulate(*args, **kwargs) - else: - return self.list(*args, **kwargs) + return self.tabulate(**kwargs) if table else self.list(**kwargs) class BulkOperator(BaseRepository[ModelType]): - def yield_chunks(self, df: pd.DataFrame, chunk_size: int): + def yield_chunks( + self, df: pd.DataFrame, chunk_size: int + ) -> Generator[pd.DataFrame, Any, None]: for _, chunk in df.groupby(df.index // chunk_size): yield chunk +class BulkUpsertKwargs(TypedDict, total=False): + create_related: bool + + class BulkUpserter(BulkOperator[ModelType]): def bulk_upsert( self, df: pd.DataFrame, chunk_size: int = settings.client_default_upload_chunk_size, - **kwargs, - ): + **kwargs: Unpack[BulkUpsertKwargs], + ) -> None: for chunk in self.yield_chunks(df, chunk_size): self.bulk_upsert_chunk(chunk, **kwargs) - def bulk_upsert_chunk(self, df: pd.DataFrame, **kwargs) -> None: + def bulk_upsert_chunk( + self, df: pd.DataFrame, **kwargs: Unpack[BulkUpsertKwargs] + ) -> None: dict_ = df_to_dict(df) json_ = DataFrame(**dict_).model_dump_json() self._request( "POST", self.prefix + "bulk/", - params=kwargs, + params=cast(dict[str, bool | None], kwargs), content=json_, ) @@ -410,12 +529,14 @@ def bulk_delete( self, df: pd.DataFrame, chunk_size: int = settings.client_default_upload_chunk_size, - **kwargs, - ): + # NOTE nothing in our code base supplies kwargs here + **kwargs: Any, + ) -> None: for chunk in self.yield_chunks(df, chunk_size): self.bulk_delete_chunk(chunk, **kwargs) - def bulk_delete_chunk(self, df: pd.DataFrame, **kwargs) -> None: + # NOTE this only gets kwargs from bulk_delete() + def bulk_delete_chunk(self, df: pd.DataFrame, **kwargs: Any) -> None: dict_ = df_to_dict(df) json_ = DataFrame(**dict_).model_dump_json() self._request( diff --git a/ixmp4/data/api/docs.py b/ixmp4/data/api/docs.py index b834b5f7..f90e5444 100644 --- a/ixmp4/data/api/docs.py +++ b/ixmp4/data/api/docs.py @@ -1,4 +1,6 @@ -from typing import ClassVar, Type +from typing import ClassVar + +import pandas as pd from ixmp4.data import abstract @@ -22,7 +24,7 @@ class DocsRepository( base.Enumerator[Docs], abstract.DocsRepository, ): - model_class: Type[Docs] + model_class: type[Docs] enumeration_method = "GET" def get(self, dimension_id: int) -> Docs: @@ -38,6 +40,11 @@ def set(self, dimension_id: int, description: str) -> Docs: ) return Docs(**res) + # NOTE This is not used anywhere, but without it, mypy complains that the base + # definitions of enumerate() are incompatible + def enumerate(self, dimension_id: int | None = None) -> list[Docs] | pd.DataFrame: + return super().enumerate(dimension_id=dimension_id) + def list(self, *, dimension_id: int | None = None) -> list[Docs]: return super()._list(params={"dimension_id": dimension_id}) diff --git a/ixmp4/data/api/iamc/__init__.py b/ixmp4/data/api/iamc/__init__.py index bd3fdad6..d06eac11 100644 --- a/ixmp4/data/api/iamc/__init__.py +++ b/ixmp4/data/api/iamc/__init__.py @@ -1,5 +1,3 @@ -# flake8: noqa - from .datapoint import ( # AnnualDataPoint,; SubAnnualDataPoint,; CategoricalDataPoint, DataPoint, DataPointRepository, diff --git a/ixmp4/data/api/iamc/datapoint.py b/ixmp4/data/api/iamc/datapoint.py index 72b569ae..5e8fb310 100644 --- a/ixmp4/data/api/iamc/datapoint.py +++ b/ixmp4/data/api/iamc/datapoint.py @@ -1,8 +1,13 @@ from datetime import datetime -from typing import ClassVar + +# TODO Use `type` instead of TypeAlias when dropping Python 3.11 +from typing import ClassVar, TypeAlias, cast import pandas as pd +# TODO Import this from typing when dropping support for Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import abstract from .. import base @@ -23,6 +28,14 @@ class DataPoint(base.BaseModel): step_datetime: datetime | None +JsonType: TypeAlias = dict[ + str, + abstract.annotations.IntFilterAlias + | dict[str, bool | abstract.annotations.DefaultFilterAlias] + | None, +] + + class DataPointRepository( base.Enumerator[DataPoint], base.BulkUpserter[DataPoint], @@ -32,17 +45,19 @@ class DataPointRepository( model_class = DataPoint prefix = "iamc/datapoints/" - def enumerate(self, **kwargs) -> list[DataPoint] | pd.DataFrame: + def enumerate( + self, **kwargs: Unpack[abstract.iamc.datapoint.EnumerateKwargs] + ) -> list[DataPoint] | pd.DataFrame: return super().enumerate(**kwargs) def list( self, join_parameters: bool | None = None, - join_runs: bool | None = None, - **kwargs, + join_runs: bool = False, + **kwargs: Unpack[abstract.iamc.datapoint.EnumerateKwargs], ) -> list[DataPoint]: return super()._list( - json=kwargs, + json=cast(JsonType, kwargs), params={ "join_parameters": join_parameters, "join_runs": join_runs, @@ -52,19 +67,19 @@ def list( def tabulate( self, join_parameters: bool | None = None, - join_runs: bool | None = None, - **kwargs, + join_runs: bool = False, + **kwargs: Unpack[abstract.iamc.datapoint.EnumerateKwargs], ) -> pd.DataFrame: return super()._tabulate( - json=kwargs, + json=cast(JsonType, kwargs), params={ "join_parameters": join_parameters, "join_runs": join_runs, }, ) - def bulk_upsert(self, df: pd.DataFrame, **kwargs) -> None: + def bulk_upsert(self, df: pd.DataFrame) -> None: super().bulk_upsert(df) - def bulk_delete(self, df: pd.DataFrame, **kwargs) -> None: + def bulk_delete(self, df: pd.DataFrame) -> None: super().bulk_delete(df) diff --git a/ixmp4/data/api/iamc/timeseries.py b/ixmp4/data/api/iamc/timeseries.py index 9501a3f8..408e5bc0 100644 --- a/ixmp4/data/api/iamc/timeseries.py +++ b/ixmp4/data/api/iamc/timeseries.py @@ -1,7 +1,11 @@ -from typing import ClassVar, Mapping +from collections.abc import Mapping +from typing import Any, ClassVar, cast import pandas as pd +# TODO Import this from typing when dropping support for Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import abstract from .. import base @@ -14,7 +18,7 @@ class TimeSeries(base.BaseModel): id: int run__id: int - parameters: Mapping + parameters: Mapping[str, Any] class TimeSeriesRepository( @@ -22,27 +26,35 @@ class TimeSeriesRepository( base.Retriever[TimeSeries], base.Enumerator[TimeSeries], base.BulkUpserter[TimeSeries], - abstract.TimeSeriesRepository, + abstract.TimeSeriesRepository[abstract.TimeSeries], ): model_class = TimeSeries prefix = "iamc/timeseries/" - def create(self, run_id: int, parameters: Mapping) -> TimeSeries: + def create(self, run_id: int, parameters: Mapping[str, Any]) -> TimeSeries: return super().create(run_id=run_id, parameters=parameters) - def get(self, run_id: int, parameters: Mapping) -> TimeSeries: + def get(self, run_id: int, parameters: Mapping[str, Any]) -> TimeSeries: return super().get(run_ids=[run_id], parameters=parameters) - def enumerate(self, **kwargs) -> list[TimeSeries] | pd.DataFrame: + def enumerate( + self, **kwargs: Unpack[abstract.iamc.timeseries.EnumerateKwargs] + ) -> list[TimeSeries] | pd.DataFrame: return super().enumerate(**kwargs) - def list(self, **kwargs) -> list[TimeSeries]: - return super()._list(json=kwargs) + def list( + self, **kwargs: Unpack[abstract.iamc.timeseries.EnumerateKwargs] + ) -> list[TimeSeries]: + json = cast(abstract.annotations.IamcFilterAlias, kwargs) + return super()._list(json=json) - def tabulate(self, join_parameters: bool | None = None, **kwargs) -> pd.DataFrame: - return super()._tabulate( - json=kwargs, params={"join_parameters": join_parameters} - ) + def tabulate( + self, + join_parameters: bool | None = None, + **kwargs: Unpack[abstract.iamc.timeseries.EnumerateKwargs], + ) -> pd.DataFrame: + json = cast(abstract.annotations.IamcFilterAlias, kwargs) + return super()._tabulate(json=json, params={"join_parameters": join_parameters}) def get_by_id(self, id: int) -> TimeSeries: res = self._get_by_id(id) diff --git a/ixmp4/data/api/iamc/variable.py b/ixmp4/data/api/iamc/variable.py index f51c3562..a3b45b06 100644 --- a/ixmp4/data/api/iamc/variable.py +++ b/ixmp4/data/api/iamc/variable.py @@ -1,9 +1,15 @@ from datetime import datetime -from typing import ClassVar +from typing import TYPE_CHECKING, ClassVar, cast + +if TYPE_CHECKING: + from ixmp4.data.backend.api import RestBackend import pandas as pd from pydantic import Field +# TODO Import this from typing when dropping support for Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import abstract from .. import base @@ -36,8 +42,8 @@ class VariableRepository( model_class = Variable prefix = "iamc/variables/" - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, *args: Unpack[tuple["RestBackend"]]) -> None: + super().__init__(*args) self.docs = VariableDocsRepository(self.backend) def create( @@ -49,11 +55,19 @@ def create( def get(self, name: str) -> Variable: return super().get(name=name) - def enumerate(self, **kwargs) -> list[Variable] | pd.DataFrame: + def enumerate( + self, **kwargs: Unpack[abstract.iamc.variable.EnumerateKwargs] + ) -> list[Variable] | pd.DataFrame: return super().enumerate(**kwargs) - def list(self, **kwargs) -> list[Variable]: - return super()._list(json=kwargs) - - def tabulate(self, **kwargs) -> pd.DataFrame: - return super()._tabulate(json=kwargs) + def list( + self, **kwargs: Unpack[abstract.iamc.variable.EnumerateKwargs] + ) -> list[Variable]: + json = cast(abstract.annotations.IamcFilterAlias, kwargs) + return super()._list(json=json) + + def tabulate( + self, **kwargs: Unpack[abstract.iamc.variable.EnumerateKwargs] + ) -> pd.DataFrame: + json = cast(abstract.annotations.IamcFilterAlias, kwargs) + return super()._tabulate(json=json) diff --git a/ixmp4/data/api/meta.py b/ixmp4/data/api/meta.py index d674b3f4..eb5fb908 100644 --- a/ixmp4/data/api/meta.py +++ b/ixmp4/data/api/meta.py @@ -1,7 +1,13 @@ -from typing import ClassVar +from collections.abc import Iterable + +# TODO Use `type` instead of TypeAlias when dropping Python 3.11 +from typing import ClassVar, TypeAlias, cast import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import abstract from . import base @@ -15,10 +21,27 @@ class RunMetaEntry(base.BaseModel): id: int run__id: int key: str - type: str + dtype: str value: abstract.StrictMetaValue +# TODO This is tantalizingly close to the run JsonType, but not quite there. +JsonType: TypeAlias = dict[ + str, + bool + | float + | Iterable[float] + | abstract.annotations.DefaultFilterAlias + | dict[ + str, + bool + | abstract.annotations.IntFilterAlias + | dict[str, abstract.annotations.DefaultFilterAlias], + ] + | None, +] + + class RunMetaEntryRepository( base.Creator[RunMetaEntry], base.Retriever[RunMetaEntry], @@ -31,12 +54,7 @@ class RunMetaEntryRepository( model_class = RunMetaEntry prefix = "meta/" - def create( - self, - run__id: int, - key: str, - value: abstract.MetaValue, - ) -> RunMetaEntry: + def create(self, run__id: int, key: str, value: abstract.MetaValue) -> RunMetaEntry: return super().create(run__id=run__id, key=key, value=value) def get(self, run__id: int, key: str) -> RunMetaEntry: @@ -45,14 +63,27 @@ def get(self, run__id: int, key: str) -> RunMetaEntry: def delete(self, id: int) -> None: super().delete(id) - def enumerate(self, **kwargs) -> list[RunMetaEntry] | pd.DataFrame: + def enumerate( + self, **kwargs: Unpack[abstract.meta.EnumerateKwargs] + ) -> list[RunMetaEntry] | pd.DataFrame: return super().enumerate(**kwargs) - def list(self, join_run_index: bool | None = None, **kwargs) -> list[RunMetaEntry]: - return super()._list(json=kwargs, params={"join_run_index": join_run_index}) - - def tabulate(self, join_run_index: bool | None = None, **kwargs) -> pd.DataFrame: - return super()._tabulate(json=kwargs, params={"join_run_index": join_run_index}) + def list( + self, + join_run_index: bool | None = None, + **kwargs: Unpack[abstract.meta.EnumerateKwargs], + ) -> list[RunMetaEntry]: + # base functions require dict, but TypedDict just inherits from Mapping + json = cast(JsonType, kwargs) + return super()._list(json=json, params={"join_run_index": join_run_index}) + + def tabulate( + self, + join_run_index: bool | None = None, + **kwargs: Unpack[abstract.meta.EnumerateKwargs], + ) -> pd.DataFrame: + json = cast(JsonType, kwargs) + return super()._tabulate(json=json, params={"join_run_index": join_run_index}) def bulk_upsert(self, df: pd.DataFrame) -> None: super().bulk_upsert(df) diff --git a/ixmp4/data/api/model.py b/ixmp4/data/api/model.py index 05ec2071..0dfd41d5 100644 --- a/ixmp4/data/api/model.py +++ b/ixmp4/data/api/model.py @@ -1,9 +1,15 @@ from datetime import datetime -from typing import ClassVar +from typing import TYPE_CHECKING, ClassVar, cast + +if TYPE_CHECKING: + from ixmp4.data.backend.api import RestBackend import pandas as pd from pydantic import Field +# TODO Import this from typing when dropping support for Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import abstract from . import base @@ -31,27 +37,30 @@ class ModelRepository( model_class = Model prefix = "models/" - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, *args: Unpack[tuple["RestBackend"]]) -> None: + super().__init__(*args) self.docs = ModelDocsRepository(self.backend) - def create( - self, - name: str, - ) -> Model: + def create(self, name: str) -> Model: return super().create(name=name) def get(self, name: str) -> Model: return super().get(name=name) - def enumerate(self, **kwargs) -> list[Model] | pd.DataFrame: + def enumerate( + self, **kwargs: Unpack[abstract.model.EnumerateKwargs] + ) -> list[Model] | pd.DataFrame: return super().enumerate(**kwargs) - def list(self, **kwargs) -> list[Model]: - return super()._list(json=kwargs) + def list(self, **kwargs: Unpack[abstract.model.EnumerateKwargs]) -> list[Model]: + json = cast(abstract.annotations.IamcObjectFilterAlias, kwargs) + return super()._list(json=json) - def tabulate(self, **kwargs) -> pd.DataFrame: - return super()._tabulate(json=kwargs) + def tabulate( + self, **kwargs: Unpack[abstract.model.EnumerateKwargs] + ) -> pd.DataFrame: + json = cast(abstract.annotations.IamcObjectFilterAlias, kwargs) + return super()._tabulate(json=json) class ModelDocsRepository(DocsRepository): diff --git a/ixmp4/data/api/optimization/equation.py b/ixmp4/data/api/optimization/equation.py index f25aefa4..8a45328e 100644 --- a/ixmp4/data/api/optimization/equation.py +++ b/ixmp4/data/api/optimization/equation.py @@ -1,8 +1,15 @@ +from collections.abc import Iterable from datetime import datetime -from typing import Any, ClassVar, Iterable +from typing import TYPE_CHECKING, Any, ClassVar, cast + +if TYPE_CHECKING: + from ixmp4.data.backend.api import RestBackend import pandas as pd +# TODO Import this from typing when dropping support for Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import abstract from .. import base @@ -39,9 +46,9 @@ class EquationRepository( model_class = Equation prefix = "optimization/equations/" - def __init__(self, backend, *args, **kwargs) -> None: - super().__init__(backend, *args, **kwargs) - self.docs = EquationDocsRepository(backend) + def __init__(self, *args: Unpack[tuple["RestBackend"]]) -> None: + super().__init__(*args) + self.docs = EquationDocsRepository(self.backend) def create( self, @@ -60,7 +67,8 @@ def create( def add_data(self, equation_id: int, data: dict[str, Any] | pd.DataFrame) -> None: if isinstance(data, pd.DataFrame): # data will always contains str, not only Hashable - data: dict[str, Any] = data.to_dict(orient="list") # type: ignore + dict_data: dict[str, Any] = data.to_dict(orient="list") # type: ignore[assignment] + data = dict_data kwargs = {"data": data} self._request( method="PATCH", path=self.prefix + str(equation_id) + "/data/", json=kwargs @@ -76,11 +84,19 @@ def get_by_id(self, id: int) -> Equation: res = self._get_by_id(id) return Equation(**res) - def list(self, *args, **kwargs) -> Iterable[Equation]: - return super().list(*args, **kwargs) - - def tabulate(self, *args, **kwargs) -> pd.DataFrame: - return super().tabulate(*args, **kwargs) - - def enumerate(self, *args, **kwargs) -> Iterable[Equation] | pd.DataFrame: - return super().enumerate(*args, **kwargs) + def list( + self, **kwargs: Unpack[abstract.optimization.EnumerateKwargs] + ) -> Iterable[Equation]: + json = cast(abstract.annotations.OptimizationFilterAlias, kwargs) + return super()._list(json=json) + + def tabulate( + self, **kwargs: Unpack[abstract.optimization.EnumerateKwargs] + ) -> pd.DataFrame: + json = cast(abstract.annotations.OptimizationFilterAlias, kwargs) + return super()._tabulate(json=json) + + def enumerate( + self, **kwargs: Unpack[abstract.optimization.EnumerateKwargs] + ) -> Iterable[Equation] | pd.DataFrame: + return super().enumerate(**kwargs) diff --git a/ixmp4/data/api/optimization/indexset.py b/ixmp4/data/api/optimization/indexset.py index 890237ec..7e3b4a8c 100644 --- a/ixmp4/data/api/optimization/indexset.py +++ b/ixmp4/data/api/optimization/indexset.py @@ -1,7 +1,14 @@ from datetime import datetime -from typing import ClassVar, List +from typing import TYPE_CHECKING, ClassVar, List, cast + +if TYPE_CHECKING: + from ixmp4.data.backend.api import RestBackend import pandas as pd +from pydantic import StrictFloat, StrictInt, StrictStr + +# TODO Import this from typing when dropping support for Python 3.11 +from typing_extensions import Unpack from ixmp4.data import abstract @@ -16,7 +23,7 @@ class IndexSet(base.BaseModel): id: int name: str - data: float | int | str | list[int | float | str] | None + data: float | int | str | list[int] | list[float] | list[str] | None run__id: int created_at: datetime | None @@ -37,33 +44,44 @@ class IndexSetRepository( model_class = IndexSet prefix = "optimization/indexsets/" - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, *args: Unpack[tuple["RestBackend"]]) -> None: + super().__init__(*args) self.docs = IndexSetDocsRepository(self.backend) - def create( - self, - run_id: int, - name: str, - ) -> IndexSet: + def create(self, run_id: int, name: str) -> IndexSet: return super().create(run_id=run_id, name=name) def get(self, run_id: int, name: str) -> IndexSet: return super().get(name=name, run_id=run_id) - def enumerate(self, **kwargs) -> list[IndexSet] | pd.DataFrame: + def enumerate( + self, **kwargs: Unpack[abstract.optimization.EnumerateKwargs] + ) -> list[IndexSet] | pd.DataFrame: return super().enumerate(**kwargs) - def list(self, **kwargs) -> list[IndexSet]: - return super()._list(json=kwargs) + def list( + self, **kwargs: Unpack[abstract.optimization.EnumerateKwargs] + ) -> list[IndexSet]: + json = cast(abstract.annotations.OptimizationFilterAlias, kwargs) + return super()._list(json=json) - def tabulate(self, include_data: bool = False, **kwargs) -> pd.DataFrame: - return super()._tabulate(json=kwargs, params={"include_data": include_data}) + def tabulate( + self, + include_data: bool = False, + **kwargs: Unpack[abstract.optimization.EnumerateKwargs], + ) -> pd.DataFrame: + json = cast(abstract.annotations.OptimizationFilterAlias, kwargs) + return super()._tabulate(json=json, params={"include_data": include_data}) def add_data( self, indexset_id: int, - data: float | int | str | List[float | int | str], + data: StrictFloat + | StrictInt + | StrictStr + | List[StrictFloat] + | List[StrictInt] + | List[StrictStr], ) -> None: kwargs = {"indexset_id": indexset_id, "data": data} self._request("PATCH", self.prefix + str(indexset_id) + "/", json=kwargs) diff --git a/ixmp4/data/api/optimization/parameter.py b/ixmp4/data/api/optimization/parameter.py index 32b21cc1..eb3a94ae 100644 --- a/ixmp4/data/api/optimization/parameter.py +++ b/ixmp4/data/api/optimization/parameter.py @@ -1,8 +1,15 @@ +from collections.abc import Iterable from datetime import datetime -from typing import Any, ClassVar, Iterable +from typing import TYPE_CHECKING, Any, ClassVar, cast + +if TYPE_CHECKING: + from ixmp4.data.backend.api import RestBackend import pandas as pd +# TODO Import this from typing when dropping support for Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import abstract from .. import base @@ -39,9 +46,9 @@ class ParameterRepository( model_class = Parameter prefix = "optimization/parameters/" - def __init__(self, backend, *args, **kwargs) -> None: - super().__init__(backend, *args, **kwargs) - self.docs = ParameterDocsRepository(backend) + def __init__(self, *args: Unpack[tuple["RestBackend"]]) -> None: + super().__init__(*args) + self.docs = ParameterDocsRepository(self.backend) def create( self, @@ -60,7 +67,8 @@ def create( def add_data(self, parameter_id: int, data: dict[str, Any] | pd.DataFrame) -> None: if isinstance(data, pd.DataFrame): # data will always contains str, not only Hashable - data: dict[str, Any] = data.to_dict(orient="list") # type: ignore + dict_data: dict[str, Any] = data.to_dict(orient="list") # type: ignore[assignment] + data = dict_data kwargs = {"data": data} self._request( method="PATCH", path=self.prefix + str(parameter_id) + "/data/", json=kwargs @@ -73,11 +81,19 @@ def get_by_id(self, id: int) -> Parameter: res = self._get_by_id(id) return Parameter(**res) - def list(self, *args, **kwargs) -> Iterable[Parameter]: - return super().list(*args, **kwargs) - - def tabulate(self, *args, **kwargs) -> pd.DataFrame: - return super().tabulate(*args, **kwargs) - - def enumerate(self, *args, **kwargs) -> Iterable[Parameter] | pd.DataFrame: - return super().enumerate(*args, **kwargs) + def list( + self, **kwargs: Unpack[abstract.optimization.EnumerateKwargs] + ) -> Iterable[Parameter]: + json = cast(abstract.annotations.OptimizationFilterAlias, kwargs) + return super()._list(json=json) + + def tabulate( + self, **kwargs: Unpack[abstract.optimization.EnumerateKwargs] + ) -> pd.DataFrame: + json = cast(abstract.annotations.OptimizationFilterAlias, kwargs) + return super()._tabulate(json=json) + + def enumerate( + self, **kwargs: Unpack[abstract.optimization.EnumerateKwargs] + ) -> Iterable[Parameter] | pd.DataFrame: + return super().enumerate(**kwargs) diff --git a/ixmp4/data/api/optimization/scalar.py b/ixmp4/data/api/optimization/scalar.py index 29462fd1..edae5b87 100644 --- a/ixmp4/data/api/optimization/scalar.py +++ b/ixmp4/data/api/optimization/scalar.py @@ -1,8 +1,15 @@ +from collections.abc import Iterable, Mapping from datetime import datetime -from typing import Any, ClassVar, Iterable, Mapping +from typing import TYPE_CHECKING, Any, ClassVar, cast + +if TYPE_CHECKING: + from ixmp4.data.backend.api import RestBackend import pandas as pd +# TODO Import this from typing when dropping support for Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import abstract from .. import base @@ -40,17 +47,11 @@ class ScalarRepository( model_class = Scalar prefix = "optimization/scalars/" - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, *args: Unpack[tuple["RestBackend"]]) -> None: + super().__init__(*args) self.docs = ScalarDocsRepository(self.backend) - def create( - self, - name: str, - value: float, - unit_name: str, - run_id: int, - ) -> Scalar: + def create(self, name: str, value: float, unit_name: str, run_id: int) -> Scalar: return super().create( name=name, value=value, unit_name=unit_name, run_id=run_id ) @@ -66,7 +67,7 @@ def update( "value": value, "unit_id": unit_id, }, - ) # type: ignore + ) # type: ignore[assignment] return self.model_class(**res) def get(self, run_id: int, name: str) -> Scalar: @@ -76,11 +77,19 @@ def get_by_id(self, id: int) -> Scalar: res = self._get_by_id(id) return Scalar(**res) - def list(self, *args, **kwargs) -> Iterable[Scalar]: - return super().list(*args, **kwargs) - - def tabulate(self, *args, **kwargs) -> pd.DataFrame: - return super().tabulate(*args, **kwargs) - - def enumerate(self, *args, **kwargs) -> Iterable[Scalar] | pd.DataFrame: - return super().enumerate(*args, **kwargs) + def list( + self, **kwargs: Unpack["abstract.optimization.scalar.EnumerateKwargs"] + ) -> Iterable[Scalar]: + json = cast(abstract.annotations.OptimizationFilterAlias, kwargs) + return super()._list(json=json) + + def tabulate( + self, **kwargs: Unpack["abstract.optimization.scalar.EnumerateKwargs"] + ) -> pd.DataFrame: + json = cast(abstract.annotations.OptimizationFilterAlias, kwargs) + return super()._tabulate(json=json) + + def enumerate( + self, **kwargs: Unpack["abstract.optimization.scalar.EnumerateKwargs"] + ) -> Iterable[Scalar] | pd.DataFrame: + return super().enumerate(**kwargs) diff --git a/ixmp4/data/api/optimization/table.py b/ixmp4/data/api/optimization/table.py index 8f884d3c..b6924990 100644 --- a/ixmp4/data/api/optimization/table.py +++ b/ixmp4/data/api/optimization/table.py @@ -1,8 +1,15 @@ +from collections.abc import Iterable from datetime import datetime -from typing import Any, ClassVar, Iterable +from typing import TYPE_CHECKING, Any, ClassVar, cast + +if TYPE_CHECKING: + from ixmp4.data.backend.api import RestBackend import pandas as pd +# TODO Import this from typing when dropping support for Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import abstract from .. import base @@ -39,9 +46,9 @@ class TableRepository( model_class = Table prefix = "optimization/tables/" - def __init__(self, backend, *args, **kwargs) -> None: - super().__init__(backend, *args, **kwargs) - self.docs = TableDocsRepository(backend) + def __init__(self, *args: Unpack[tuple["RestBackend"]]) -> None: + super().__init__(*args) + self.docs = TableDocsRepository(self.backend) def create( self, @@ -60,7 +67,8 @@ def create( def add_data(self, table_id: int, data: dict[str, Any] | pd.DataFrame) -> None: if isinstance(data, pd.DataFrame): # data will always contains str, not only Hashable - data: dict[str, Any] = data.to_dict(orient="list") # type: ignore + dict_data: dict[str, Any] = data.to_dict(orient="list") # type: ignore[assignment] + data = dict_data kwargs = {"data": data} self._request( method="PATCH", path=self.prefix + str(table_id) + "/data/", json=kwargs @@ -73,11 +81,19 @@ def get_by_id(self, id: int) -> Table: res = self._get_by_id(id) return Table(**res) - def list(self, *args, **kwargs) -> Iterable[Table]: - return super().list(*args, **kwargs) - - def tabulate(self, *args, **kwargs) -> pd.DataFrame: - return super().tabulate(*args, **kwargs) - - def enumerate(self, *args, **kwargs) -> Iterable[Table] | pd.DataFrame: - return super().enumerate(*args, **kwargs) + def list( + self, **kwargs: Unpack[abstract.optimization.EnumerateKwargs] + ) -> Iterable[Table]: + json = cast(abstract.annotations.OptimizationFilterAlias, kwargs) + return super()._list(json=json) + + def tabulate( + self, **kwargs: Unpack[abstract.optimization.EnumerateKwargs] + ) -> pd.DataFrame: + json = cast(abstract.annotations.OptimizationFilterAlias, kwargs) + return super()._tabulate(json=json) + + def enumerate( + self, **kwargs: Unpack[abstract.optimization.EnumerateKwargs] + ) -> Iterable[Table] | pd.DataFrame: + return super().enumerate(**kwargs) diff --git a/ixmp4/data/api/optimization/variable.py b/ixmp4/data/api/optimization/variable.py index 504c39a3..5f176341 100644 --- a/ixmp4/data/api/optimization/variable.py +++ b/ixmp4/data/api/optimization/variable.py @@ -1,8 +1,15 @@ +from collections.abc import Iterable from datetime import datetime -from typing import Any, ClassVar, Iterable +from typing import TYPE_CHECKING, Any, ClassVar, cast + +if TYPE_CHECKING: + from ixmp4.data.backend.api import RestBackend import pandas as pd +# TODO Import this from typing when dropping support for Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import abstract from .. import base @@ -39,9 +46,9 @@ class VariableRepository( model_class = Variable prefix = "optimization/variables/" - def __init__(self, backend, *args, **kwargs) -> None: - super().__init__(backend, *args, **kwargs) - self.docs = VariableDocsRepository(backend) + def __init__(self, *args: Unpack[tuple["RestBackend"]]) -> None: + super().__init__(*args) + self.docs = VariableDocsRepository(self.backend) def create( self, @@ -60,7 +67,8 @@ def create( def add_data(self, variable_id: int, data: dict[str, Any] | pd.DataFrame) -> None: if isinstance(data, pd.DataFrame): # data will always contains str, not only Hashable - data: dict[str, Any] = data.to_dict(orient="list") # type: ignore + dict_data: dict[str, Any] = data.to_dict(orient="list") # type: ignore[assignment] + data = dict_data kwargs = {"data": data} self._request( method="PATCH", path=self.prefix + str(variable_id) + "/data/", json=kwargs @@ -76,11 +84,19 @@ def get_by_id(self, id: int) -> Variable: res = self._get_by_id(id) return Variable(**res) - def list(self, *args, **kwargs) -> Iterable[Variable]: - return super().list(*args, **kwargs) - - def tabulate(self, *args, **kwargs) -> pd.DataFrame: - return super().tabulate(*args, **kwargs) - - def enumerate(self, *args, **kwargs) -> Iterable[Variable] | pd.DataFrame: - return super().enumerate(*args, **kwargs) + def list( + self, **kwargs: Unpack[abstract.optimization.EnumerateKwargs] + ) -> Iterable[Variable]: + json = cast(abstract.annotations.OptimizationFilterAlias, kwargs) + return super()._list(json=json) + + def tabulate( + self, **kwargs: Unpack[abstract.optimization.EnumerateKwargs] + ) -> pd.DataFrame: + json = cast(abstract.annotations.OptimizationFilterAlias, kwargs) + return super()._tabulate(json=json) + + def enumerate( + self, **kwargs: Unpack[abstract.optimization.EnumerateKwargs] + ) -> Iterable[Variable] | pd.DataFrame: + return super().enumerate(**kwargs) diff --git a/ixmp4/data/api/region.py b/ixmp4/data/api/region.py index 8f011a33..d9b38de3 100644 --- a/ixmp4/data/api/region.py +++ b/ixmp4/data/api/region.py @@ -1,9 +1,15 @@ from datetime import datetime -from typing import ClassVar +from typing import TYPE_CHECKING, ClassVar, cast + +if TYPE_CHECKING: + from ixmp4.data.backend.api import RestBackend import pandas as pd from pydantic import Field +# TODO Import this from typing when dropping support for Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import abstract from . import base @@ -44,15 +50,11 @@ class RegionRepository( model_class = Region prefix = "regions/" - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, *args: Unpack[tuple["RestBackend"]]) -> None: + super().__init__(*args) self.docs = RegionDocsRepository(self.backend) - def create( - self, - name: str, - hierarchy: str, - ) -> Region: + def create(self, name: str, hierarchy: str) -> Region: return super().create(name=name, hierarchy=hierarchy) def delete(self, id: int) -> None: @@ -61,11 +63,17 @@ def delete(self, id: int) -> None: def get(self, name: str) -> Region: return super().get(name=name) - def enumerate(self, **kwargs) -> list[Region] | pd.DataFrame: + def enumerate( + self, **kwargs: Unpack[abstract.region.EnumerateKwargs] + ) -> list[Region] | pd.DataFrame: return super().enumerate(**kwargs) - def list(self, **kwargs) -> list[Region]: - return super()._list(json=kwargs) + def list(self, **kwargs: Unpack[abstract.region.EnumerateKwargs]) -> list[Region]: + json = cast(abstract.annotations.IamcObjectFilterAlias, kwargs) + return super()._list(json=json) - def tabulate(self, **kwargs) -> pd.DataFrame: - return super()._tabulate(json=kwargs) + def tabulate( + self, **kwargs: Unpack[abstract.region.EnumerateKwargs] + ) -> pd.DataFrame: + json = cast(abstract.annotations.IamcObjectFilterAlias, kwargs) + return super()._tabulate(json=json) diff --git a/ixmp4/data/api/run.py b/ixmp4/data/api/run.py index 9b1485ec..374d0ebc 100644 --- a/ixmp4/data/api/run.py +++ b/ixmp4/data/api/run.py @@ -1,8 +1,12 @@ -from typing import ClassVar +# TODO Use `type` instead of TypeAlias when dropping Python 3.11 +from typing import ClassVar, TypeAlias, cast import pandas as pd from pydantic import Field +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import abstract from . import base @@ -29,6 +33,19 @@ class Run(base.BaseModel): is_default: bool +JsonType: TypeAlias = dict[ + str, + bool + | abstract.annotations.IntFilterAlias + | dict[ + str, + abstract.annotations.DefaultFilterAlias + | dict[str, abstract.annotations.DefaultFilterAlias], + ] + | None, +] + + class RunRepository( base.Creator[Run], base.Retriever[Run], @@ -50,14 +67,18 @@ def get(self, model_name: str, scenario_name: str, version: int) -> Run: is_default=None, ) - def enumerate(self, **kwargs) -> list[Run] | pd.DataFrame: + def enumerate( + self, **kwargs: Unpack[abstract.run.EnumerateKwargs] + ) -> list[Run] | pd.DataFrame: return super().enumerate(**kwargs) - def list(self, **kwargs) -> list[Run]: - return super()._list(json=kwargs) + def list(self, **kwargs: Unpack[abstract.run.EnumerateKwargs]) -> list[Run]: + json = cast(JsonType, kwargs) + return super()._list(json=json) - def tabulate(self, **kwargs) -> pd.DataFrame: - return super()._tabulate(json=kwargs) + def tabulate(self, **kwargs: Unpack[abstract.run.EnumerateKwargs]) -> pd.DataFrame: + json = cast(JsonType, kwargs) + return super()._tabulate(json=json) def get_default_version(self, model_name: str, scenario_name: str) -> Run: try: diff --git a/ixmp4/data/api/scenario.py b/ixmp4/data/api/scenario.py index 8a33dcf7..1521c076 100644 --- a/ixmp4/data/api/scenario.py +++ b/ixmp4/data/api/scenario.py @@ -1,9 +1,15 @@ from datetime import datetime -from typing import ClassVar +from typing import TYPE_CHECKING, ClassVar, cast + +if TYPE_CHECKING: + from ixmp4.data.backend.api import RestBackend import pandas as pd from pydantic import Field +# TODO Import this from typing when dropping support for Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import abstract from . import base @@ -35,24 +41,29 @@ class ScenarioRepository( model_class = Scenario prefix = "scenarios/" - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, *args: Unpack[tuple["RestBackend"]]) -> None: + super().__init__(*args) self.docs = ScenarioDocsRepository(self.backend) - def create( - self, - name: str, - ) -> Scenario: + def create(self, name: str) -> Scenario: return super().create(name=name) def get(self, name: str) -> Scenario: return super().get(name=name) - def enumerate(self, **kwargs) -> list[Scenario] | pd.DataFrame: + def enumerate( + self, **kwargs: Unpack[abstract.scenario.EnumerateKwargs] + ) -> list[Scenario] | pd.DataFrame: return super().enumerate(**kwargs) - def list(self, **kwargs) -> list[Scenario]: - return super()._list(json=kwargs) - - def tabulate(self, **kwargs) -> pd.DataFrame: - return super()._tabulate(json=kwargs) + def list( + self, **kwargs: Unpack[abstract.scenario.EnumerateKwargs] + ) -> list[Scenario]: + json = cast(abstract.annotations.IamcObjectFilterAlias, kwargs) + return super()._list(json=json) + + def tabulate( + self, **kwargs: Unpack[abstract.scenario.EnumerateKwargs] + ) -> pd.DataFrame: + json = cast(abstract.annotations.IamcObjectFilterAlias, kwargs) + return super()._tabulate(json=json) diff --git a/ixmp4/data/api/unit.py b/ixmp4/data/api/unit.py index 4ec40b53..8024534c 100644 --- a/ixmp4/data/api/unit.py +++ b/ixmp4/data/api/unit.py @@ -1,9 +1,15 @@ from datetime import datetime -from typing import ClassVar +from typing import TYPE_CHECKING, ClassVar, cast + +if TYPE_CHECKING: + from ixmp4.data.backend.api import RestBackend import pandas as pd from pydantic import Field +# TODO Import this from typing when dropping support for Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import abstract from . import base @@ -36,14 +42,11 @@ class UnitRepository( model_class = Unit prefix = "units/" - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, *args: Unpack[tuple["RestBackend"]]) -> None: + super().__init__(*args) self.docs = UnitDocsRepository(self.backend) - def create( - self, - name: str, - ) -> Unit: + def create(self, name: str) -> Unit: return super().create(name=name) def delete(self, id: int) -> None: @@ -56,11 +59,15 @@ def get_by_id(self, id: int) -> Unit: res = self._get_by_id(id) return Unit(**res) - def enumerate(self, **kwargs) -> list[Unit] | pd.DataFrame: + def enumerate( + self, **kwargs: Unpack[abstract.unit.EnumerateKwargs] + ) -> list[Unit] | pd.DataFrame: return super().enumerate(**kwargs) - def list(self, **kwargs) -> list[Unit]: - return super()._list(json=kwargs) + def list(self, **kwargs: Unpack[abstract.unit.EnumerateKwargs]) -> list[Unit]: + json = cast(abstract.annotations.IamcObjectFilterAlias, kwargs) + return super()._list(json=json) - def tabulate(self, **kwargs) -> pd.DataFrame: - return super()._tabulate(json=kwargs) + def tabulate(self, **kwargs: Unpack[abstract.unit.EnumerateKwargs]) -> pd.DataFrame: + json = cast(abstract.annotations.IamcObjectFilterAlias, kwargs) + return super()._tabulate(json=json) diff --git a/ixmp4/data/auth/context.py b/ixmp4/data/auth/context.py index 4ff30672..8163a725 100644 --- a/ixmp4/data/auth/context.py +++ b/ixmp4/data/auth/context.py @@ -1,4 +1,5 @@ import re +from typing import Any, TypeVar import pandas as pd @@ -17,7 +18,7 @@ def __init__( self.manager = manager self.platform = platform - def tabulate_permissions(self): + def tabulate_permissions(self) -> pd.DataFrame: df = self.manager.fetch_user_permissions( self.user, self.platform, jti=self.user.jti ) @@ -38,7 +39,9 @@ def convert_to_like(m: str) -> str: df["like"] = df["model"].apply(convert_to_like) return df - def apply(self, access_type: str, exc: db.sql.Select) -> db.sql.Select: + ApplyType = TypeVar("ApplyType", bound=db.sql.Select[tuple[Any]]) + + def apply(self, access_type: str, exc: ApplyType) -> ApplyType: if self.is_managed: return exc if self.user.is_superuser: @@ -47,7 +50,7 @@ def apply(self, access_type: str, exc: db.sql.Select) -> db.sql.Select: if utils.is_joined(exc, Model): perms = self.tabulate_permissions() if perms.empty: - return exc.where(False) # type: ignore + return exc.where(db.false()) if access_type == "edit": perms = perms.where(perms["access_type"] == "EDIT").dropna() # `*` is used as wildcard in permission logic, replaced by sql-wildcard `%` diff --git a/ixmp4/data/auth/decorators.py b/ixmp4/data/auth/decorators.py index cae1927a..460a45cf 100644 --- a/ixmp4/data/auth/decorators.py +++ b/ixmp4/data/auth/decorators.py @@ -1,25 +1,32 @@ +from collections.abc import Callable from functools import wraps -from typing import TYPE_CHECKING, Callable, Protocol +from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar from ixmp4.core.exceptions import Forbidden, ProgrammingError if TYPE_CHECKING: - from ..backend.db import SqlAlchemyBackend + from ..db.base import BaseModel, BaseRepository +P = ParamSpec("P") -class Guardable(Protocol): - # NOTE: Eager checking for api backends may be desirable - # at some point - backend: "SqlAlchemyBackend" +ReturnT = TypeVar("ReturnT") -def guard(access: str) -> Callable: +def guard( + access: str, +) -> Callable[ + [Callable[Concatenate[Any, P], ReturnT]], Callable[Concatenate[Any, P], ReturnT] +]: if access not in ["edit", "manage", "view"]: raise ProgrammingError("Guard access must be 'edit', 'manage' or 'view'.") - def decorator(func): + def decorator( + func: Callable[Concatenate[Any, P], ReturnT], + ) -> Callable[Concatenate[Any, P], ReturnT]: @wraps(func) - def guarded_func(self: Guardable, *args, **kwargs): + def guarded_func( + self: "BaseRepository[BaseModel]", /, *args: P.args, **kwargs: P.kwargs + ) -> ReturnT: if self.backend.auth_context is not None: if access == "view" and self.backend.auth_context.is_viewable: return func(self, *args, **kwargs) diff --git a/ixmp4/data/backend/__init__.py b/ixmp4/data/backend/__init__.py index c41ad799..dad0c72c 100644 --- a/ixmp4/data/backend/__init__.py +++ b/ixmp4/data/backend/__init__.py @@ -1,5 +1,3 @@ -# flake8: noqa - from .api import RestBackend, RestTestBackend from .base import Backend from .db import SqlAlchemyBackend, SqliteTestBackend diff --git a/ixmp4/data/backend/api.py b/ixmp4/data/backend/api.py index beda8741..0ee96761 100644 --- a/ixmp4/data/backend/api.py +++ b/ixmp4/data/backend/api.py @@ -1,6 +1,9 @@ import logging from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .db import SqlAlchemyBackend import httpx import pandas as pd from fastapi.testclient import TestClient @@ -55,7 +58,7 @@ def __init__( logger.info("Platform notice: >\n" + info.notice) self.create_repositories() - def make_client(self, rest_url: str, auth: BaseAuth | None): + def make_client(self, rest_url: str, auth: BaseAuth | None) -> None: auth = self.get_auth(rest_url, auth) self.client = httpx.Client( @@ -65,7 +68,9 @@ def make_client(self, rest_url: str, auth: BaseAuth | None): auth=auth, ) - def get_auth(self, rest_url: str, override_auth: BaseAuth | None) -> BaseAuth: + def get_auth( + self, rest_url: str, override_auth: BaseAuth | None + ) -> BaseAuth | None: root = httpx.get(rest_url, follow_redirects=True) if root.status_code != 200: logger.error("Root API response not OK: " + root.text) @@ -109,7 +114,7 @@ def get_auth(self, rest_url: str, override_auth: BaseAuth | None) -> BaseAuth: else: return override_auth - def create_repositories(self): + def create_repositories(self) -> None: self.iamc.datapoints = DataPointRepository(self) self.iamc.timeseries = TimeSeriesRepository(self) self.iamc.variables = VariableRepository(self) @@ -148,14 +153,12 @@ def create_repositories(self): class RestTestBackend(RestBackend): - def __init__(self, db_backend, *args, **kwargs) -> None: + def __init__(self, db_backend: "SqlAlchemyBackend") -> None: self.db_backend = db_backend self.auth_params = (test_user, mock_manager, test_platform) - super().__init__( - test_platform, SelfSignedAuth(settings.secret_hs256), *args, **kwargs - ) + super().__init__(test_platform, SelfSignedAuth(settings.secret_hs256)) - def make_client(self, rest_url: str, auth: BaseAuth): + def make_client(self, rest_url: str, auth: BaseAuth) -> None: self.client = TestClient( app=app, base_url=rest_url, @@ -172,12 +175,12 @@ def make_client(self, rest_url: str, auth: BaseAuth): self.db_backend, self.auth_params ) - def close(self): + def close(self) -> None: self.client.close() self.executor.shutdown(cancel_futures=True) - def setup(self): + def setup(self) -> None: pass - def teardown(self): + def teardown(self) -> None: pass diff --git a/ixmp4/data/backend/base.py b/ixmp4/data/backend/base.py index ae3fb46d..15fb689c 100644 --- a/ixmp4/data/backend/base.py +++ b/ixmp4/data/backend/base.py @@ -12,6 +12,7 @@ ScalarRepository, ScenarioRepository, TableRepository, + TimeSeries, TimeSeriesRepository, UnitRepository, VariableRepository, @@ -20,7 +21,7 @@ class IamcSubobject(object): datapoints: DataPointRepository - timeseries: TimeSeriesRepository + timeseries: TimeSeriesRepository[TimeSeries] variables: VariableRepository @@ -53,5 +54,5 @@ def close(self) -> None: """Closes the connection to the database.""" ... - def __str__(self): + def __str__(self) -> str: return f"" diff --git a/ixmp4/data/backend/db.py b/ixmp4/data/backend/db.py index 25f2981c..8b2f3bd5 100644 --- a/ixmp4/data/backend/db.py +++ b/ixmp4/data/backend/db.py @@ -1,12 +1,15 @@ import logging +from collections.abc import Generator from contextlib import contextmanager from functools import lru_cache -from typing import Generator from sqlalchemy.engine import Engine, create_engine from sqlalchemy.orm.session import sessionmaker from sqlalchemy.pool import NullPool, StaticPool +# TODO Import this from typing when dropping support for Python 3.11 +from typing_extensions import Unpack + from ixmp4.conf.base import PlatformInfo from ixmp4.conf.manager import ManagerConfig, ManagerPlatformInfo from ixmp4.conf.user import User @@ -83,7 +86,7 @@ def __init__(self, info: PlatformInfo) -> None: self.make_repositories() self.event_handler = SqlaEventHandler(self) - def check_dsn(self, dsn: str): + def check_dsn(self, dsn: str) -> str: if dsn.startswith("postgresql://"): logger.debug( "Replacing the platform dsn prefix to use the new `psycopg` driver." @@ -91,11 +94,11 @@ def check_dsn(self, dsn: str): dsn = dsn.replace("postgresql://", "postgresql+psycopg://") return dsn - def make_engine(self, dsn: str): + def make_engine(self, dsn: str) -> None: self.engine = cached_create_engine(dsn) self.session = self.Session(bind=self.engine) - def make_repositories(self): + def make_repositories(self) -> None: self.iamc.datapoints = DataPointRepository(self) self.iamc.timeseries = TimeSeriesRepository(self) self.iamc.variables = VariableRepository(self) @@ -127,34 +130,29 @@ def auth( yield self.auth_context self.auth_context = None - def _create_all(self): + def _create_all(self) -> None: BaseModel.metadata.create_all(bind=self.engine) - def _drop_all(self): + def _drop_all(self) -> None: BaseModel.metadata.drop_all(bind=self.engine, checkfirst=True) - def setup(self): + def setup(self) -> None: self._create_all() - def teardown(self): + def teardown(self) -> None: self.session.rollback() self._drop_all() - self.engine = None - self.session = None - def close(self): + def close(self) -> None: self.session.close() self.engine.dispose() class SqliteTestBackend(SqlAlchemyBackend): - def __init__(self, *args, **kwargs) -> None: - super().__init__( - *args, - **kwargs, - ) + def __init__(self, *args: Unpack[tuple[PlatformInfo]]) -> None: + super().__init__(*args) - def make_engine(self, dsn: str): + def make_engine(self, dsn: str) -> None: self.engine = create_engine( dsn, connect_args={"check_same_thread": False}, @@ -164,12 +162,9 @@ def make_engine(self, dsn: str): class PostgresTestBackend(SqlAlchemyBackend): - def __init__(self, *args, **kwargs) -> None: - super().__init__( - *args, - **kwargs, - ) + def __init__(self, *args: Unpack[tuple[PlatformInfo]]) -> None: + super().__init__(*args) - def make_engine(self, dsn: str): + def make_engine(self, dsn: str) -> None: self.engine = create_engine(dsn, poolclass=NullPool) self.session = self.Session(bind=self.engine) diff --git a/ixmp4/data/db/base.py b/ixmp4/data/db/base.py index 89c4ceae..491c8323 100644 --- a/ixmp4/data/db/base.py +++ b/ixmp4/data/db/base.py @@ -2,22 +2,12 @@ import logging import sqlite3 -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Generic, - Iterable, - Iterator, - Tuple, - TypeVar, - cast, -) +from collections.abc import Callable, Iterable, Iterator, Mapping +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast import numpy as np import pandas as pd -from sqlalchemy import event, text +from sqlalchemy import TextClause, event, text from sqlalchemy.engine import Engine from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.exc import IntegrityError, NoResultFound @@ -26,6 +16,9 @@ from sqlalchemy.orm.session import Session from sqlalchemy.sql.schema import Identity, MetaData +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import NotRequired, TypedDict, Unpack + from ixmp4 import db from ixmp4.core.exceptions import Forbidden, IxmpError, ProgrammingError from ixmp4.data import abstract, types @@ -38,15 +31,18 @@ @event.listens_for(Engine, "connect") -def set_sqlite_pragma(dbapi_connection, connection_record): +def set_sqlite_pragma( + dbapi_connection: sqlite3.Connection, connection_record: Any +) -> None: if isinstance(dbapi_connection, sqlite3.Connection): cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() -@compiles(Identity, "sqlite") -def visit_identity(element, compiler, **kwargs): +# NOTE compiles from sqlalchemy is untyped, not much we can do here +@compiles(Identity, "sqlite") # type: ignore[misc,no-untyped-call] +def visit_identity(element: Any, compiler: Any, **kwargs: Any) -> TextClause: return text("") @@ -69,7 +65,7 @@ def __tablename__(cls: "BaseModel") -> str: info={"skip_autogenerate": True}, ) - def __str__(self): + def __str__(self) -> str: return self.__class__.__name__ @@ -83,6 +79,7 @@ def __str__(self): } ) +SelectType = TypeVar("SelectType", bound=db.sql.Select[tuple[BaseModel, ...]]) ModelType = TypeVar("ModelType", bound=BaseModel) @@ -91,10 +88,10 @@ class BaseRepository(Generic[ModelType]): backend: "SqlAlchemyBackend" session: Session dialect: Dialect - bundle: Bundle + bundle: Bundle[Any] model_class: type[ModelType] - def __init__(self, backend: "SqlAlchemyBackend", *args, **kwargs) -> None: + def __init__(self, backend: "SqlAlchemyBackend") -> None: self.backend = backend self.session = backend.session self.engine = backend.engine @@ -104,22 +101,37 @@ def __init__(self, backend: "SqlAlchemyBackend", *args, **kwargs) -> None: else: raise ProgrammingError("Database session is closed.") - self.bundle: Bundle = Bundle( + self.bundle: Bundle[Any] = Bundle( self.model_class.__name__, *db.utils.get_columns(self.model_class).values() ) - super().__init__(*args, **kwargs) + super().__init__() class Retriever(BaseRepository[ModelType], abstract.Retriever): - def get(self, *args, **kwargs) -> ModelType: + def get(self, *args: Any, **kwargs: Any) -> ModelType: raise NotImplementedError +class CreateKwargs(TypedDict, total=False): + dimension_id: int + description: str + name: str + model_name: str + scenario_name: str + hierarchy: str + run__id: int + key: str + value: abstract.annotations.PrimitiveTypes + parameters: Mapping[str, Any] + run_id: int + unit_name: str | None + + class Creator(BaseRepository[ModelType], abstract.Creator): - def add(self, *args, **kwargs) -> ModelType: + def add(self, *args: Any, **kwargs: Any) -> ModelType: raise NotImplementedError - def create(self, *args, **kwargs) -> ModelType: + def create(self, *args: Any, **kwargs: Unpack[CreateKwargs]) -> ModelType: model = self.add(*args, **kwargs) try: self.session.commit() @@ -131,10 +143,8 @@ def create(self, *args, **kwargs) -> ModelType: class Deleter(BaseRepository[ModelType]): - def delete(self, id: int): - exc: db.sql.Delete = db.delete(self.model_class).where( - self.model_class.id == id - ) + def delete(self, id: int) -> None: + exc = db.delete(self.model_class).where(self.model_class.id == id) try: self.session.execute( @@ -147,11 +157,38 @@ def delete(self, id: int): raise self.model_class.DeletionPrevented +class CheckAccessKwargs(TypedDict, total=False): + run: abstract.annotations.HasRunFilter + is_default: bool | None + default_only: bool | None + + +class SelectCountKwargs(abstract.HasNameFilter, total=False): + default_only: bool | None + dimension_id: int | None + id__in: set[int] + is_default: bool | None + join_parameters: bool | None + join_runs: bool | None + iamc: abstract.annotations.IamcFilterAlias + model: abstract.annotations.HasModelFilter + region: abstract.annotations.HasRegionFilter + run: abstract.annotations.HasRunFilter + scenario: abstract.annotations.HasScenarioFilter + unit: abstract.annotations.HasUnitFilter + variable: abstract.annotations.HasVariableFilter + + class Selecter(BaseRepository[ModelType]): - filter_class: type[filters.BaseFilter] + filter_class: type[filters.BaseFilter] | None - def check_access(self, ids: set[int], access_type: str = "view", **kwargs): - exc = self.select( + def check_access( + self, + ids: set[int], + access_type: str = "view", + **kwargs: Unpack[CheckAccessKwargs], + ) -> None: + exc = self.select_for_count( _exc=db.select(db.func.count()).select_from(self.model_class), id__in=ids, _access_type=access_type, @@ -166,25 +203,55 @@ def check_access(self, ids: set[int], access_type: str = "view", **kwargs): ) raise Forbidden(f"Permission check failed for access type '{access_type}'.") - def join_auth(self, exc: db.sql.Select) -> db.sql.Select: + SelectAuthType = TypeVar("SelectAuthType", bound=db.sql.Select[tuple[Any]]) + + def join_auth(self, exc: SelectAuthType) -> SelectAuthType: return exc - def apply_auth(self, exc: db.sql.Select, access_type: str) -> db.sql.Select: + def apply_auth(self, exc: SelectAuthType, access_type: str) -> SelectAuthType: if self.backend.auth_context is not None: if not self.backend.auth_context.is_managed: exc = self.join_auth(exc) exc = self.backend.auth_context.apply(access_type, exc) return exc + def select_for_count( + self, + _exc: db.sql.Select[tuple[int]], + _filter: filters.BaseFilter | None = None, + _skip_filter: bool = False, + _access_type: str = "view", + **kwargs: Unpack[SelectCountKwargs], + ) -> db.sql.Select[tuple[int]]: + if self.filter_class is None: + cls_name = self.__class__.__name__ + raise NotImplementedError( + f"Provide `{cls_name}.filter_class` or reimplement `{cls_name}.select`." + ) + + _exc = self.apply_auth(_exc, _access_type) + if _filter is not None and not _skip_filter: + filter_instance = _filter + _exc = filter_instance.join(_exc, session=self.session) + _exc = filter_instance.apply(_exc, self.model_class, self.session) + elif not _skip_filter: + kwarg_filter: filters.BaseFilter = self.filter_class(**kwargs) + _exc = kwarg_filter.join(_exc, session=self.session) + _exc = kwarg_filter.apply(_exc, self.model_class, self.session) + return _exc + def select( self, _filter: filters.BaseFilter | None = None, - _exc: db.sql.Select | None = None, + _exc: db.sql.Select[tuple[ModelType]] | None = None, _access_type: str = "view", - _post_filter: Callable[[db.sql.Select], db.sql.Select] | None = None, + _post_filter: Callable[ + [db.sql.Select[tuple[ModelType]]], db.sql.Select[tuple[ModelType]] + ] + | None = None, _skip_filter: bool = False, - **kwargs, - ) -> db.sql.Select: + **kwargs: Any, + ) -> db.sql.Select[tuple[ModelType]]: if self.filter_class is None: cls_name = self.__class__.__name__ raise NotImplementedError( @@ -197,12 +264,11 @@ def select( _exc = self.apply_auth(_exc, _access_type) if _filter is not None and not _skip_filter: - # for some reason checkers resolve the type of `_filter` to `Unknown` - filter_instance: filters.BaseFilter = _filter + filter_instance = _filter _exc = filter_instance.join(_exc, session=self.session) _exc = filter_instance.apply(_exc, self.model_class, self.session) elif not _skip_filter: - kwarg_filter = self.filter_class(**kwargs) + kwarg_filter: filters.BaseFilter = self.filter_class(**kwargs) _exc = kwarg_filter.join(_exc, session=self.session) _exc = kwarg_filter.apply(_exc, self.model_class, self.session) @@ -212,7 +278,7 @@ def select( class Lister(Selecter[ModelType]): - def list(self, *args, **kwargs) -> list[ModelType]: + def list(self, *args: Any, **kwargs: Any) -> list[ModelType]: _exc = self.select(*args, **kwargs) _exc = _exc.order_by(self.model_class.id.asc()) result = self.session.execute(_exc).scalars().all() @@ -220,12 +286,7 @@ def list(self, *args, **kwargs) -> list[ModelType]: class Tabulator(Selecter[ModelType]): - def tabulate( - self, - *args, - _raw: bool = False, - **kwargs, - ) -> pd.DataFrame: + def tabulate(self, *args: Any, _raw: bool = False, **kwargs: Any) -> pd.DataFrame: _exc = self.select(*args, **kwargs) _exc = _exc.order_by(self.model_class.id.asc()) @@ -236,31 +297,55 @@ def tabulate( raise ProgrammingError("Database session is closed.") +class PaginateKwargs(TypedDict): + _filter: filters.BaseFilter + include_data: NotRequired[bool] + join_parameters: NotRequired[bool | None] + join_runs: NotRequired[bool | None] + join_run_index: NotRequired[bool | None] + table: bool + + +class EnumerateKwargs(TypedDict): + _filter: filters.BaseFilter + include_data: NotRequired[bool] + join_parameters: NotRequired[bool | None] + join_runs: NotRequired[bool | None] + join_run_index: NotRequired[bool | None] + _post_filter: Callable[ + [db.sql.Select[tuple[ModelType]]], db.sql.Select[tuple[ModelType]] + ] + + +class CountKwargs(abstract.HasNameFilter, total=False): + dimension_id: int | None + _filter: filters.BaseFilter + join_parameters: bool | None + join_runs: bool | None + iamc: abstract.annotations.IamcFilterAlias + model: abstract.annotations.HasModelFilter + region: abstract.annotations.HasRegionFilter + run: abstract.annotations.HasRunFilter + scenario: abstract.annotations.HasScenarioFilter + unit: abstract.annotations.HasUnitFilter + variable: abstract.annotations.HasVariableFilter + + class Enumerator(Lister[ModelType], Tabulator[ModelType]): def enumerate( - self, *args, table: bool = False, **kwargs + self, table: bool = False, **kwargs: Unpack[EnumerateKwargs] ) -> list[ModelType] | pd.DataFrame: - if table: - return self.tabulate(*args, **kwargs) - else: - return self.list(*args, **kwargs) + return self.tabulate(**kwargs) if table else self.list(**kwargs) def paginate( - self, - *args, - limit: int = 1000, - offset: int = 0, - **kwargs, + self, limit: int = 1000, offset: int = 0, **kwargs: Unpack[PaginateKwargs] ) -> list[ModelType] | pd.DataFrame: return self.enumerate( - *args, **kwargs, _post_filter=lambda e: e.offset(offset).limit(limit) + **kwargs, _post_filter=lambda e: e.offset(offset).limit(limit) ) - def count( - self, - **kwargs, - ) -> int: - _exc = self.select( + def count(self, **kwargs: Unpack[CountKwargs]) -> int: + _exc = self.select_for_count( _exc=db.select(db.func.count(self.model_class.id.distinct())), **kwargs, ) @@ -284,7 +369,7 @@ def merge_existing( set(existing_df.columns) & set(df.columns) & set(columns.keys()) ) # all cols which exist in both dfs and the db model - set(self.model_class.updateable_columns) # no updateable columns - - set(primary_key_columns) # no pk columns + - set(primary_key_columns.keys()) # no pk columns ) # = all columns that are constant and provided during creation return df.merge( @@ -311,7 +396,7 @@ def drop_merge_artifacts( def split_by_max_unique_values( self, df: pd.DataFrame, columns: Iterable[str], mu: int - ) -> Tuple[pd.DataFrame, pd.DataFrame]: + ) -> tuple[pd.DataFrame, pd.DataFrame]: df_len = len(df.index) chunk_size = df_len remaining_df = pd.DataFrame() @@ -403,7 +488,7 @@ def bulk_upsert_chunk(self, df: pd.DataFrame) -> None: self.session.commit() - def bulk_insert(self, df: pd.DataFrame, **kwargs) -> None: + def bulk_insert(self, df: pd.DataFrame, **kwargs: Any) -> None: # to_dict returns a more general list[Mapping[Hashable, Unknown]] if "id" in df.columns: raise ProgrammingError("You may not insert the 'id' column.") @@ -418,7 +503,7 @@ def bulk_insert(self, df: pd.DataFrame, **kwargs) -> None: except IntegrityError as e: raise self.model_class.NotUnique(*e.args) - def bulk_update(self, df: pd.DataFrame, **kwargs) -> None: + def bulk_update(self, df: pd.DataFrame, **kwargs: Any) -> None: # to_dict returns a more general list[Mapping[Hashable, Unknown]] m = cast(list[dict[str, Any]], df.to_dict("records")) self.session.execute( diff --git a/ixmp4/data/db/docs.py b/ixmp4/data/db/docs.py index 39fa97d4..75ea573e 100644 --- a/ixmp4/data/db/docs.py +++ b/ixmp4/data/db/docs.py @@ -1,8 +1,10 @@ from typing import ClassVar, TypeVar -import pandas as pd from sqlalchemy.exc import NoResultFound +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import TypedDict, Unpack + from ixmp4 import db from ixmp4.data import abstract, types @@ -17,9 +19,9 @@ class AbstractDocs(base.BaseModel): __abstract__ = True - description: types.Mapped + description: types.Mapped[str] - dimension__id: types.Mapped + dimension__id: types.Mapped[int] def docs_model(model: type[base.BaseModel]) -> type[AbstractDocs]: @@ -40,6 +42,10 @@ def docs_model(model: type[base.BaseModel]) -> type[AbstractDocs]: DocsType = TypeVar("DocsType", bound=AbstractDocs) +class ListKwargs(TypedDict, total=False): + dimension_id: int | None + + class BaseDocsRepository( base.Creator[DocsType], base.Retriever[DocsType], @@ -50,8 +56,11 @@ class BaseDocsRepository( dimension_model_class: ClassVar[type[base.BaseModel]] def select( - self, *, _exc: db.sql.Select | None = None, dimension_id: int | None = None - ) -> db.sql.Select: + self, + *, + _exc: db.sql.Select[tuple[DocsType]] | None = None, + dimension_id: int | None = None, + ) -> db.sql.Select[tuple[DocsType]]: if _exc is None: _exc = db.select(self.model_class) @@ -60,6 +69,14 @@ def select( return _exc + def select_for_count( + self, _exc: db.sql.Select[tuple[int]], dimension_id: int | None = None + ) -> db.sql.Select[tuple[int]]: + if dimension_id is not None: + _exc = _exc.where(self.model_class.dimension__id == dimension_id) + + return _exc + def add(self, dimension_id: int, description: str) -> DocsType: docs = self.model_class(description=description, dimension__id=dimension_id) self.session.add(docs) @@ -69,8 +86,7 @@ def add(self, dimension_id: int, description: str) -> DocsType: def get(self, dimension_id: int) -> DocsType: exc = self.select(dimension_id=dimension_id) try: - docs = self.session.execute(exc).scalar_one() - return docs + return self.session.execute(exc).scalar_one() except NoResultFound: raise self.model_class.NotFound @@ -88,7 +104,7 @@ def set(self, dimension_id: int, description: str) -> DocsType: @guard("edit") def delete(self, dimension_id: int) -> None: - exc: db.sql.Delete = db.delete(self.model_class).where( + exc = db.delete(self.model_class).where( self.model_class.dimension__id == dimension_id ) @@ -104,9 +120,5 @@ def delete(self, dimension_id: int) -> None: raise self.model_class.NotFound @guard("view") - def tabulate(self, *args, **kwargs) -> pd.DataFrame: - return super().tabulate(*args, **kwargs) - - @guard("view") - def list(self, *args, **kwargs) -> list[DocsType]: - return super().list(*args, **kwargs) + def list(self, **kwargs: Unpack[ListKwargs]) -> list[DocsType]: + return super().list(**kwargs) diff --git a/ixmp4/data/db/events.py b/ixmp4/data/db/events.py index 0438ccbb..511f9901 100644 --- a/ixmp4/data/db/events.py +++ b/ixmp4/data/db/events.py @@ -1,8 +1,9 @@ import logging +from collections.abc import Callable, Generator, Mapping, Sequence from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast +from typing import TYPE_CHECKING, Any, cast -from sqlalchemy import Connection, event, sql +from sqlalchemy import Connection, Result, event, sql from sqlalchemy.orm import Mapper, ORMExecuteState from ixmp4 import db @@ -21,7 +22,9 @@ class SqlaEventHandler(object): def __init__(self, backend: "SqlAlchemyBackend") -> None: self.backend = backend - self.listeners = [ + self.listeners: list[ + tuple[tuple[Any, str, Callable[..., None]], dict[str, bool]] + ] = [ ((backend.session, "do_orm_execute", self.receive_do_orm_execute), {}), ( (base.BaseModel, "before_insert", self.receive_before_insert), @@ -34,28 +37,30 @@ def __init__(self, backend: "SqlAlchemyBackend") -> None: ] self.add_listeners() - def add_listeners(self): + def add_listeners(self) -> None: + # Somehow, mypy knows how long args is, but tries to insert kwargs as argument + # 2 and 3 for args, kwargs in self.listeners: - event.listen(*args, **kwargs) + event.listen(*args, **kwargs) # type: ignore[arg-type] - def remove_listeners(self): + def remove_listeners(self) -> None: for args, kwargs in self.listeners: if event.contains(*args): event.remove(*args) @contextmanager - def pause(self): + def pause(self) -> Generator[None, Any, None]: """Temporarily removes all event listeners for the enclosed scope.""" self.remove_listeners() yield self.add_listeners() - def set_logger(self, state): + def set_logger(self, state: tuple[int, int, int] | ORMExecuteState) -> None: self.logger = logging.getLogger(__name__ + "." + str(id(state))) def receive_before_insert( - self, mapper: Mapper, connection: Connection, target: base.BaseModel - ): + self, mapper: Mapper[Any], connection: Connection, target: base.BaseModel + ) -> None: """Handles the insert event when creating data like this: ``` model = Model(**kwargs) @@ -73,8 +78,8 @@ def receive_before_insert( target.set_creation_info(self.backend.auth_context) def receive_before_update( - self, mapper: Mapper, connection: Connection, target: base.BaseModel - ): + self, mapper: Mapper[Any], connection: Connection, target: base.BaseModel + ) -> None: """Handles the update event when changing data like this: ``` model = query_model() @@ -91,7 +96,7 @@ def receive_before_update( self.logger.debug(f"Setting update info for: {target}") target.set_update_info(self.backend.auth_context) - def receive_do_orm_execute(self, orm_execute_state: ORMExecuteState): + def receive_do_orm_execute(self, orm_execute_state: ORMExecuteState) -> None: """Handles ORM execution events like: ``` exc = select/update/delete(Model) @@ -105,25 +110,25 @@ def receive_do_orm_execute(self, orm_execute_state: ORMExecuteState): self.set_logger(orm_execute_state) self.logger.debug("Received 'do_orm_execute' event.") if orm_execute_state.is_select: - return self.receive_select(orm_execute_state) + self.receive_select(orm_execute_state) else: if orm_execute_state.is_insert: self.logger.debug("Operation: 'insert'") - return self.receive_insert(orm_execute_state) + self.receive_insert(orm_execute_state) if orm_execute_state.is_update: self.logger.debug("Operation: 'update'") - return self.receive_update(orm_execute_state) + self.receive_update(orm_execute_state) if orm_execute_state.is_delete: self.logger.debug("Operation: 'delete'") - return self.receive_delete(orm_execute_state) + self.receive_delete(orm_execute_state) else: self.logger.debug(f"Ignoring operation: {orm_execute_state}") - def receive_select(self, oes: ORMExecuteState): + def receive_select(self, oes: ORMExecuteState) -> None: # select = cast(sql.Select, oes.statement) pass - def receive_insert(self, oes: ORMExecuteState): + def receive_insert(self, oes: ORMExecuteState) -> Result[Any] | None: insert = cast(sql.Insert, oes.statement) entity = insert.entity_description type_ = entity["type"] @@ -135,11 +140,13 @@ def receive_insert(self, oes: ORMExecuteState): "created_at": type_.get_timestamp(), } + assert oes.parameters is not None return oes.invoke_statement( params=self.get_extra_params(oes.parameters, creation_info) ) + return None - def receive_update(self, oes: ORMExecuteState): + def receive_update(self, oes: ORMExecuteState) -> Result[Any] | None: update = cast(sql.Update, oes.statement) entity = update.entity_description type_ = entity["type"] @@ -159,12 +166,14 @@ def receive_update(self, oes: ORMExecuteState): return oes.invoke_statement(statement=new_statement) else: raise ProgrammingError(f"Cannot handle update statement: {update}") + return None - def receive_delete(self, oes: ORMExecuteState): + def receive_delete(self, oes: ORMExecuteState) -> None: # delete = cast(sql.Delete, oes.statement) pass - def select_affected(self, statement: sql.Delete | sql.Update) -> sql.Select: + # Can't find any reference use cases + def select_affected(self, statement: sql.Delete | sql.Update) -> sql.Select[Any]: entity = statement.entity_description wc = statement.whereclause exc = db.select(entity["entity"]) @@ -172,8 +181,9 @@ def select_affected(self, statement: sql.Delete | sql.Update) -> sql.Select: exc.where(wc.expression) return exc - def get_extra_params(self, params, extra): - if isinstance(params, Sequence): - return [extra] * len(params) - else: - return extra + def get_extra_params( + self, + params: Sequence[Mapping[str, Any]] | Mapping[str, Any], + extra: dict[str, Any], + ) -> list[dict[str, Any]] | dict[str, Any]: + return [extra] * len(params) if isinstance(params, Sequence) else extra diff --git a/ixmp4/data/db/filters/meta.py b/ixmp4/data/db/filters/meta.py index 9fb28256..897a3018 100644 --- a/ixmp4/data/db/filters/meta.py +++ b/ixmp4/data/db/filters/meta.py @@ -1,6 +1,6 @@ from typing import ClassVar -from ixmp4.db import filters +from ixmp4.db import Session, filters, sql from .. import RunMetaEntry @@ -9,8 +9,8 @@ class RunMetaEntryFilter(filters.BaseFilter, metaclass=filters.FilterMeta): sqla_model: ClassVar[type] = RunMetaEntry id: filters.Id - type: filters.String - run__id: filters.Integer = filters.Field(None, alias="run_id") + dtype: filters.String + run__id: filters.Integer | None = filters.Field(None, alias="run_id") key: filters.String @@ -19,5 +19,8 @@ class RunMetaEntryFilter(filters.BaseFilter, metaclass=filters.FilterMeta): value_float: filters.Float value_bool: filters.Boolean - def join(self, exc, **kwargs): + # NOTE specific type hint here is based on usage; adapt accordingly + def join( + self, exc: sql.Select[tuple[RunMetaEntry]], session: Session | None = None + ) -> sql.Select[tuple[RunMetaEntry]]: return exc diff --git a/ixmp4/data/db/filters/model.py b/ixmp4/data/db/filters/model.py index 055262a4..0d2cb839 100644 --- a/ixmp4/data/db/filters/model.py +++ b/ixmp4/data/db/filters/model.py @@ -1,17 +1,19 @@ from typing import ClassVar -from ixmp4.db import filters, utils +from ixmp4.db import Session, filters, utils from .. import Model, Run +from ..base import SelectType class ModelFilter(filters.BaseFilter, metaclass=filters.FilterMeta): - id: filters.Id - name: filters.String + id: filters.Id | None = filters.Field(None) + name: filters.String | None = filters.Field(None) sqla_model: ClassVar[type] = Model - def join(self, exc, **kwargs): + # TODO using general form here as this seems to be callable on non-Model tables + def join(self, exc: SelectType, session: Session | None = None) -> SelectType: if not utils.is_joined(exc, Model): exc = exc.join(Model, Run.model) return exc diff --git a/ixmp4/data/db/filters/optimizationcolumn.py b/ixmp4/data/db/filters/optimizationcolumn.py index 0b234432..62dd702a 100644 --- a/ixmp4/data/db/filters/optimizationcolumn.py +++ b/ixmp4/data/db/filters/optimizationcolumn.py @@ -1,6 +1,6 @@ from typing import ClassVar -from ixmp4.db import filters +from ixmp4.db import filters, sql from .. import Column, Run @@ -8,10 +8,11 @@ class OptimizationColumnFilter(filters.BaseFilter, metaclass=filters.FilterMeta): id: filters.Id name: filters.String - run__id: filters.Integer = filters.Field(None, alias="run_id") + run__id: filters.Integer | None = filters.Field(None, alias="run_id") sqla_model: ClassVar[type] = Column - def join(self, exc, **kwargs): - exc = exc.join(Run, onclause=Column.run__id == Run.id) + # Not fixing this since I think we don't need columns + def join(self, exc: sql.Select, **kwargs) -> sql.Select: # type: ignore[no-untyped-def,type-arg] + exc = exc.join(Run, onclause=Column.run__id == Run.id) # type: ignore[attr-defined] return exc diff --git a/ixmp4/data/db/filters/optimizationequation.py b/ixmp4/data/db/filters/optimizationequation.py index 4073247a..40b76676 100644 --- a/ixmp4/data/db/filters/optimizationequation.py +++ b/ixmp4/data/db/filters/optimizationequation.py @@ -1,17 +1,19 @@ from typing import ClassVar -from ixmp4.db import filters +from ixmp4.db import Session, filters, sql from .. import Equation, Run class OptimizationEquationFilter(filters.BaseFilter, metaclass=filters.FilterMeta): - id: filters.Id - name: filters.String - run__id: filters.Integer = filters.Field(None, alias="run_id") + id: filters.Id | None = filters.Field(None) + name: filters.String | None = filters.Field(None) + run__id: filters.Integer | None = filters.Field(None, alias="run_id") sqla_model: ClassVar[type] = Equation - def join(self, exc, **kwargs): + def join( + self, exc: sql.Select[tuple[Equation]], session: Session | None = None + ) -> sql.Select[tuple[Equation]]: exc = exc.join(Run, onclause=Equation.run__id == Run.id) return exc diff --git a/ixmp4/data/db/filters/optimizationindexset.py b/ixmp4/data/db/filters/optimizationindexset.py index a737a513..3592d557 100644 --- a/ixmp4/data/db/filters/optimizationindexset.py +++ b/ixmp4/data/db/filters/optimizationindexset.py @@ -1,17 +1,19 @@ from typing import ClassVar -from ixmp4.db import filters +from ixmp4.db import Session, filters, sql from .. import IndexSet, Run class OptimizationIndexSetFilter(filters.BaseFilter, metaclass=filters.FilterMeta): - id: filters.Id - name: filters.String - run__id: filters.Integer = filters.Field(None, alias="run_id") + id: filters.Id | None = filters.Field(None) + name: filters.String | None = filters.Field(None) + run__id: filters.Integer | None = filters.Field(None, alias="run_id") sqla_model: ClassVar[type] = IndexSet - def join(self, exc, **kwargs): + def join( + self, exc: sql.Select[tuple[IndexSet]], session: Session | None = None + ) -> sql.Select[tuple[IndexSet]]: exc = exc.join(Run, onclause=IndexSet.run__id == Run.id) return exc diff --git a/ixmp4/data/db/filters/optimizationparameter.py b/ixmp4/data/db/filters/optimizationparameter.py index 5fe142a6..4ccb9f45 100644 --- a/ixmp4/data/db/filters/optimizationparameter.py +++ b/ixmp4/data/db/filters/optimizationparameter.py @@ -1,17 +1,19 @@ from typing import ClassVar -from ixmp4.db import filters +from ixmp4.db import Session, filters, sql from .. import Parameter, Run class OptimizationParameterFilter(filters.BaseFilter, metaclass=filters.FilterMeta): - id: filters.Id - name: filters.String - run__id: filters.Integer = filters.Field(None, alias="run_id") + id: filters.Id | None = filters.Field(None) + name: filters.String | None = filters.Field(None) + run__id: filters.Integer | None = filters.Field(None, alias="run_id") sqla_model: ClassVar[type] = Parameter - def join(self, exc, **kwargs): + def join( + self, exc: sql.Select[tuple[Parameter]], session: Session | None = None + ) -> sql.Select[tuple[Parameter]]: exc = exc.join(Run, onclause=Parameter.run__id == Run.id) return exc diff --git a/ixmp4/data/db/filters/optimizationscalar.py b/ixmp4/data/db/filters/optimizationscalar.py index ca6aeedb..f5938f56 100644 --- a/ixmp4/data/db/filters/optimizationscalar.py +++ b/ixmp4/data/db/filters/optimizationscalar.py @@ -1,19 +1,21 @@ from typing import ClassVar -from ixmp4.db import filters +from ixmp4.db import Session, filters, sql from .. import Run, Scalar, Unit class OptimizationScalarFilter(filters.BaseFilter, metaclass=filters.FilterMeta): - id: filters.Id - name: filters.String - run__id: filters.Integer = filters.Field(None, alias="run_id") - unit__id: filters.Integer = filters.Field(None, alias="unit_id") + id: filters.Id | None = filters.Field(None) + name: filters.String | None = filters.Field(None) + run__id: filters.Integer | None = filters.Field(None, alias="run_id") + unit__id: filters.Integer | None = filters.Field(None, alias="unit_id") sqla_model: ClassVar[type] = Scalar - def join(self, exc, **kwargs): + def join( + self, exc: sql.Select[tuple[Scalar]], session: Session | None = None + ) -> sql.Select[tuple[Scalar]]: exc = exc.join(Run, onclause=Scalar.run__id == Run.id) exc = exc.join(Unit, onclause=Scalar.unit__id == Unit.id) return exc diff --git a/ixmp4/data/db/filters/optimizationtable.py b/ixmp4/data/db/filters/optimizationtable.py index a8aba970..e97e5b7b 100644 --- a/ixmp4/data/db/filters/optimizationtable.py +++ b/ixmp4/data/db/filters/optimizationtable.py @@ -1,17 +1,19 @@ from typing import ClassVar -from ixmp4.db import filters +from ixmp4.db import Session, filters, sql from .. import Run, Table class OptimizationTableFilter(filters.BaseFilter, metaclass=filters.FilterMeta): - id: filters.Id - name: filters.String - run__id: filters.Integer = filters.Field(None, alias="run_id") + id: filters.Id | None = filters.Field(None) + name: filters.String | None = filters.Field(None) + run__id: filters.Integer | None = filters.Field(None, alias="run_id") sqla_model: ClassVar[type] = Table - def join(self, exc, **kwargs): + def join( + self, exc: sql.Select[tuple[Table]], session: Session | None = None + ) -> sql.Select[tuple[Table]]: exc = exc.join(Run, onclause=Table.run__id == Run.id) return exc diff --git a/ixmp4/data/db/filters/optimizationvariable.py b/ixmp4/data/db/filters/optimizationvariable.py index 63d9a5c6..f5e81869 100644 --- a/ixmp4/data/db/filters/optimizationvariable.py +++ b/ixmp4/data/db/filters/optimizationvariable.py @@ -1,17 +1,21 @@ from typing import ClassVar -from ixmp4.db import filters +from ixmp4.db import Session, filters, sql from .. import OptimizationVariable, Run class OptimizationVariableFilter(filters.BaseFilter, metaclass=filters.FilterMeta): - id: filters.Id - name: filters.String - run__id: filters.Integer = filters.Field(None, alias="run_id") + id: filters.Id | None = filters.Field(None) + name: filters.String | None = filters.Field(None) + run__id: filters.Integer | None = filters.Field(None, alias="run_id") sqla_model: ClassVar[type] = OptimizationVariable - def join(self, exc, **kwargs): + def join( + self, + exc: sql.Select[tuple[OptimizationVariable]], + session: Session | None = None, + ) -> sql.Select[tuple[OptimizationVariable]]: exc = exc.join(Run, onclause=OptimizationVariable.run__id == Run.id) return exc diff --git a/ixmp4/data/db/filters/region.py b/ixmp4/data/db/filters/region.py index 34eabaec..5d75d6c9 100644 --- a/ixmp4/data/db/filters/region.py +++ b/ixmp4/data/db/filters/region.py @@ -1,17 +1,19 @@ from typing import ClassVar -from ixmp4.db import filters +from ixmp4.db import Session, filters, sql from .. import Region, TimeSeries class RegionFilter(filters.BaseFilter, metaclass=filters.FilterMeta): - id: filters.Id - name: filters.String - hierarchy: filters.String + id: filters.Id | None = filters.Field(None) + name: filters.String | None = filters.Field(None) + hierarchy: filters.String | None = filters.Field(None) sqla_model: ClassVar[type] = Region - def join(self, exc, **kwargs): + def join( + self, exc: sql.Select[tuple[Region]], session: Session | None = None + ) -> sql.Select[tuple[Region]]: exc = exc.join(Region, TimeSeries.region) return exc diff --git a/ixmp4/data/db/filters/run.py b/ixmp4/data/db/filters/run.py index 879f73f2..df3b6f83 100644 --- a/ixmp4/data/db/filters/run.py +++ b/ixmp4/data/db/filters/run.py @@ -1,30 +1,34 @@ -from typing import ClassVar +from typing import Any, ClassVar -from ixmp4.db import filters, utils +from ixmp4.db import Session, filters, sql, typing_column, utils from .. import Run +from ..base import SelectType from ..iamc import TimeSeries from .model import ModelFilter from .scenario import ScenarioFilter class RunFilter(filters.BaseFilter, metaclass=filters.FilterMeta): - id: filters.Id - version: filters.Integer + id: filters.Id | None = filters.Field(None) + version: filters.Integer | None = filters.Field(None) default_only: filters.Boolean = filters.Field(True) - is_default: filters.Boolean - model: ModelFilter - scenario: ScenarioFilter + is_default: filters.Boolean | None = filters.Field(None) + model: ModelFilter | None = filters.Field(None) + scenario: ScenarioFilter | None = filters.Field(None) sqla_model: ClassVar[type] = Run - def filter_default_only(self, exc, c, v, **kwargs): - if v: - return exc.where(Run.is_default) - else: - return exc + def filter_default_only( + self, + exc: sql.Select[tuple[Run]], + c: typing_column[Any], # Any since it is unused + v: bool, + session: Session | None = None, + ) -> sql.Select[tuple[Run]]: + return exc.where(Run.is_default) if v else exc - def join(self, exc, **kwargs): + def join(self, exc: SelectType, session: Session | None = None) -> SelectType: if not utils.is_joined(exc, Run): exc = exc.join(Run, TimeSeries.run) return exc diff --git a/ixmp4/data/db/filters/scenario.py b/ixmp4/data/db/filters/scenario.py index bf8afdf8..fffb63d0 100644 --- a/ixmp4/data/db/filters/scenario.py +++ b/ixmp4/data/db/filters/scenario.py @@ -1,17 +1,18 @@ from typing import ClassVar -from ixmp4.db import filters, utils +from ixmp4.db import Session, filters, utils from .. import Run, Scenario +from ..base import SelectType class ScenarioFilter(filters.BaseFilter, metaclass=filters.FilterMeta): - id: filters.Id - name: filters.String + id: filters.Id | None = filters.Field(None) + name: filters.String | None = filters.Field(None) sqla_model: ClassVar[type] = Scenario - def join(self, exc, **kwargs): + def join(self, exc: SelectType, session: Session | None = None) -> SelectType: if not utils.is_joined(exc, Scenario): exc = exc.join(Scenario, Run.scenario) return exc diff --git a/ixmp4/data/db/filters/timeseries.py b/ixmp4/data/db/filters/timeseries.py index 455e5098..39b1dbcf 100644 --- a/ixmp4/data/db/filters/timeseries.py +++ b/ixmp4/data/db/filters/timeseries.py @@ -1,16 +1,17 @@ from typing import ClassVar -from ixmp4.db import filters, utils +from ixmp4.db import Session, filters, utils from .. import Run, TimeSeries +from ..base import SelectType class TimeSeriesFilter(filters.BaseFilter, metaclass=filters.FilterMeta): - id: filters.Id + id: filters.Id | None = filters.Field(None) sqla_model: ClassVar[type] = TimeSeries - def join(self, exc, **kwargs): + def join(self, exc: SelectType, session: Session | None = None) -> SelectType: if not utils.is_joined(exc, TimeSeries): exc = exc.join(TimeSeries, onclause=TimeSeries.run__id == Run.id) return exc diff --git a/ixmp4/data/db/filters/unit.py b/ixmp4/data/db/filters/unit.py index c0bc5a4c..8b6c3ed1 100644 --- a/ixmp4/data/db/filters/unit.py +++ b/ixmp4/data/db/filters/unit.py @@ -1,17 +1,18 @@ from typing import ClassVar -from ixmp4.db import filters, utils +from ixmp4.db import Session, filters, utils from .. import Measurand, TimeSeries, Unit +from ..base import SelectType class UnitFilter(filters.BaseFilter, metaclass=filters.FilterMeta): - id: filters.Id - name: filters.String + id: filters.Id | None = filters.Field(None) + name: filters.String | None = filters.Field(None) sqla_model: ClassVar[type] = Unit - def join(self, exc, **kwargs): + def join(self, exc: SelectType, session: Session | None = None) -> SelectType: if not utils.is_joined(exc, Measurand): exc = exc.join(Measurand, TimeSeries.measurand) exc = exc.join(Unit, Measurand.unit) diff --git a/ixmp4/data/db/filters/variable.py b/ixmp4/data/db/filters/variable.py index df136a3d..cf1d435a 100644 --- a/ixmp4/data/db/filters/variable.py +++ b/ixmp4/data/db/filters/variable.py @@ -1,17 +1,18 @@ from typing import ClassVar -from ixmp4.db import filters, utils +from ixmp4.db import Session, filters, utils from .. import Measurand, TimeSeries, Variable +from ..base import SelectType class VariableFilter(filters.BaseFilter, metaclass=filters.FilterMeta): - id: filters.Id - name: filters.String + id: filters.Id | None = filters.Field(None) + name: filters.String | None = filters.Field(None) sqla_model: ClassVar[type] = Variable - def join(self, exc, **kwargs): + def join(self, exc: SelectType, session: Session | None = None) -> SelectType: if not utils.is_joined(exc, Measurand): exc = exc.join(Measurand, TimeSeries.measurand) exc = exc.join(Variable, Measurand.variable) diff --git a/ixmp4/data/db/iamc/__init__.py b/ixmp4/data/db/iamc/__init__.py index f32c3ee5..ce0bb90f 100644 --- a/ixmp4/data/db/iamc/__init__.py +++ b/ixmp4/data/db/iamc/__init__.py @@ -1,4 +1,3 @@ -# flake8: noqa from .datapoint import ( # AnnualDataPoint,; SubAnnualDataPoint,; CategoricalDataPoint DataPoint, DataPointRepository, diff --git a/ixmp4/data/db/iamc/base.py b/ixmp4/data/db/iamc/base.py index 98b56581..5611ccf5 100644 --- a/ixmp4/data/db/iamc/base.py +++ b/ixmp4/data/db/iamc/base.py @@ -1,4 +1,3 @@ -# flake8: noqa from ..base import BaseModel as RootBaseModel from ..base import ( BulkDeleter, diff --git a/ixmp4/data/db/iamc/datapoint/__init__.py b/ixmp4/data/db/iamc/datapoint/__init__.py index edff1cdd..c0a94471 100644 --- a/ixmp4/data/db/iamc/datapoint/__init__.py +++ b/ixmp4/data/db/iamc/datapoint/__init__.py @@ -1,3 +1,2 @@ -# flake8: noqa from .model import DataPoint, UniversalDataPoint, get_datapoint_model from .repository import DataPointRepository diff --git a/ixmp4/data/db/iamc/datapoint/filter.py b/ixmp4/data/db/iamc/datapoint/filter.py index 36436d33..e34b67c0 100644 --- a/ixmp4/data/db/iamc/datapoint/filter.py +++ b/ixmp4/data/db/iamc/datapoint/filter.py @@ -1,5 +1,6 @@ -from typing import ClassVar +from typing import Any, ClassVar +from ixmp4.data.db.base import SelectType from ixmp4.data.db.iamc.datapoint import get_datapoint_model from ixmp4.data.db.iamc.measurand import Measurand from ixmp4.data.db.iamc.timeseries import TimeSeries @@ -9,7 +10,7 @@ from ixmp4.data.db.run import Run from ixmp4.data.db.scenario import Scenario from ixmp4.data.db.unit import Unit -from ixmp4.db import filters, utils +from ixmp4.db import Session, filters, sql, typing_column, utils class RegionFilter(filters.BaseFilter, metaclass=filters.FilterMeta): @@ -18,7 +19,7 @@ class RegionFilter(filters.BaseFilter, metaclass=filters.FilterMeta): sqla_model: ClassVar[type] = Region - def join(self, exc, session): + def join(self, exc: SelectType, session: Session) -> SelectType: model = get_datapoint_model(session) if not utils.is_joined(exc, TimeSeries): exc = exc.join(TimeSeries, onclause=model.time_series__id == TimeSeries.id) @@ -32,7 +33,7 @@ class UnitFilter(filters.BaseFilter, metaclass=filters.FilterMeta): sqla_model: ClassVar[type] = Unit - def join(self, exc, session): + def join(self, exc: SelectType, session: Session) -> SelectType: model = get_datapoint_model(session) if not utils.is_joined(exc, TimeSeries): exc = exc.join(TimeSeries, onclause=model.time_series__id == TimeSeries.id) @@ -49,7 +50,7 @@ class VariableFilter(filters.BaseFilter, metaclass=filters.FilterMeta): sqla_model: ClassVar[type] = Variable - def join(self, exc, session): + def join(self, exc: SelectType, session: Session) -> SelectType: model = get_datapoint_model(session) if not utils.is_joined(exc, TimeSeries): exc = exc.join(TimeSeries, onclause=model.time_series__id == TimeSeries.id) @@ -65,7 +66,7 @@ class ModelFilter(filters.BaseFilter, metaclass=filters.FilterMeta): sqla_model: ClassVar[type] = Model - def join(self, exc, session): + def join(self, exc: SelectType, session: Session) -> SelectType: model = get_datapoint_model(session) if not utils.is_joined(exc, TimeSeries): exc = exc.join(TimeSeries, onclause=model.time_series__id == TimeSeries.id) @@ -81,7 +82,7 @@ class ScenarioFilter(filters.BaseFilter, metaclass=filters.FilterMeta): sqla_model: ClassVar[type] = Scenario - def join(self, exc, session): + def join(self, exc: SelectType, session: Session) -> SelectType: model = get_datapoint_model(session) if not utils.is_joined(exc, TimeSeries): exc = exc.join(TimeSeries, onclause=model.time_series__id == TimeSeries.id) @@ -93,12 +94,12 @@ def join(self, exc, session): class RunFilter(filters.BaseFilter, metaclass=filters.FilterMeta): - id: filters.Id + id: filters.Id | None = filters.Field(None) default_only: filters.Boolean = filters.Field(True) sqla_model: ClassVar[type] = Run - def join(self, exc, session): + def join(self, exc: SelectType, session: Session) -> SelectType: model = get_datapoint_model(session) if not utils.is_joined(exc, TimeSeries): exc = exc.join(TimeSeries, model.time_series__id == TimeSeries.id) @@ -106,15 +107,18 @@ def join(self, exc, session): exc = exc.join(Run, TimeSeries.run) return exc - def filter_default_only(self, exc, c, v, **kwargs): - if v: - return exc.where(Run.is_default) - else: - return exc + def filter_default_only( + self, + exc: sql.Select[tuple[Run]], + c: typing_column[Any], # Any since it is unused + v: bool | None, + session: Session | None = None, + ) -> sql.Select[tuple[Run]]: + return exc.where(Run.is_default) if v else exc class DataPointFilter(filters.BaseFilter, metaclass=filters.FilterMeta): - """This class is used for filtering data points + """This class is used for filtering data points. All parameters are optional. Use the field name (or the field alias) directly for equality comparisons. For performing an SQL IN operation @@ -161,11 +165,11 @@ class DataPointFilter(filters.BaseFilter, metaclass=filters.FilterMeta): >>> iamc.tabulate(**filter) """ - step_year: filters.Integer = filters.Field(None, alias="year") - time_series__id: filters.Id = filters.Field(None, alias="time_series_id") - region: RegionFilter - unit: UnitFilter - variable: VariableFilter - model: ModelFilter - scenario: ScenarioFilter + step_year: filters.Integer | None = filters.Field(None, alias="year") + time_series__id: filters.Id | None = filters.Field(None, alias="time_series_id") + region: RegionFilter | None = filters.Field(None) + unit: UnitFilter | None = filters.Field(None) + variable: VariableFilter | None = filters.Field(None) + model: ModelFilter | None = filters.Field(None) + scenario: ScenarioFilter | None = filters.Field(None) run: RunFilter = filters.Field(RunFilter()) diff --git a/ixmp4/data/db/iamc/datapoint/model.py b/ixmp4/data/db/iamc/datapoint/model.py index 7b903256..1c659324 100644 --- a/ixmp4/data/db/iamc/datapoint/model.py +++ b/ixmp4/data/db/iamc/datapoint/model.py @@ -20,7 +20,7 @@ class DataPoint(base.BaseModel): updateable_columns = ["value"] @declared_attr - def time_series__id(cls): + def time_series__id(cls) -> db.MappedColumn[int]: return db.Column( db.Integer, db.ForeignKey("iamc_timeseries.id"), @@ -48,5 +48,5 @@ def __tablename__(cls) -> str: ) -def get_datapoint_model(session) -> type[DataPoint]: +def get_datapoint_model(session: db.Session) -> type[DataPoint]: return UniversalDataPoint diff --git a/ixmp4/data/db/iamc/datapoint/repository.py b/ixmp4/data/db/iamc/datapoint/repository.py index 753deaf2..481e3b65 100644 --- a/ixmp4/data/db/iamc/datapoint/repository.py +++ b/ixmp4/data/db/iamc/datapoint/repository.py @@ -1,10 +1,14 @@ -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd import pandera as pa +from pandera.engines import pandas_engine from pandera.typing import DataFrame, Series +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4 import db from ixmp4.core.decorators import check_types from ixmp4.core.exceptions import InconsistentIamcType, ProgrammingError @@ -15,6 +19,7 @@ from ixmp4.data.db.run import Run, RunRepository from ixmp4.data.db.scenario import Scenario from ixmp4.data.db.unit import Unit +from ixmp4.db.filters import BaseFilter from .. import base from ..measurand import Measurand @@ -24,11 +29,16 @@ from .filter import DataPointFilter from .model import DataPoint +if TYPE_CHECKING: + from ixmp4.data.backend.db import SqlAlchemyBackend + class RemoveDataPointFrameSchema(pa.DataFrameModel): type: Series[pa.String] | None = pa.Field(isin=[t for t in DataPoint.Type]) step_year: Series[pa.Int] | None = pa.Field(coerce=True, nullable=True) - step_datetime: Series[pa.DateTime] | None = pa.Field(coerce=True, nullable=True) + step_datetime: Series[pandas_engine.DateTime] | None = pa.Field( + coerce=True, nullable=True + ) step_category: Series[pa.String] | None = pa.Field(nullable=True) time_series__id: Series[pa.Int] = pa.Field(coerce=True) @@ -73,6 +83,12 @@ def infer_content(df: pd.DataFrame, col: str) -> str: raise InconsistentIamcType +class EnumerateKwargs(abstract.iamc.datapoint.EnumerateKwargs, total=False): + join_parameters: bool | None + join_runs: bool + _filter: BaseFilter + + class DataPointRepository( base.Enumerator[DataPoint], base.BulkUpserter[DataPoint], @@ -83,18 +99,20 @@ class DataPointRepository( timeseries: TimeSeriesRepository runs: RunRepository - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: "SqlAlchemyBackend") -> None: backend, *_ = args # A different table was used for ORACLE databases (deprecated since ixmp4 0.3.0) self.model_class = get_datapoint_model(backend.session) - self.timeseries = TimeSeriesRepository(*args, **kwargs) - self.runs = RunRepository(*args, **kwargs) + self.timeseries = TimeSeriesRepository(*args) + self.runs = RunRepository(*args) self.filter_class = DataPointFilter - super().__init__(*args, **kwargs) + super().__init__(*args) - def join_auth(self, exc: db.sql.Select) -> db.sql.Select: + def join_auth( + self, exc: db.sql.Select[tuple[DataPoint]] + ) -> db.sql.Select[tuple[DataPoint]]: if not db.utils.is_joined(exc, TimeSeries): exc = exc.join( TimeSeries, onclause=self.model_class.time_series__id == TimeSeries.id @@ -106,8 +124,12 @@ def join_auth(self, exc: db.sql.Select) -> db.sql.Select: return exc - def select_joined_parameters(self, join_runs=False): - bundle = [] + def select_joined_parameters( + self, join_runs: bool = False + ) -> db.sql.Select[tuple[DataPoint]]: + # NOTE Not quite sure about this bundle, seems to possibly take all types of + # all model classes? + bundle: list[db.Label[str] | db.Label[int] | db.Bundle[Any]] = [] if join_runs: bundle.extend( [ @@ -151,9 +173,9 @@ def select( join_parameters: bool | None = False, join_runs: bool = False, _filter: DataPointFilter | None = None, - _exc: db.sql.Select | None = None, + _exc: db.sql.Select[tuple[DataPoint]] | None = None, **kwargs: Any, - ) -> db.sql.Select: + ) -> db.sql.Select[tuple[DataPoint]]: if _exc is not None: exc = _exc elif join_parameters: @@ -164,12 +186,12 @@ def select( return super().select(_exc=exc, _filter=_filter, **kwargs) @guard("view") - def list(self, *args, **kwargs) -> list[DataPoint]: - return super().list(*args, **kwargs) + def list(self, **kwargs: Unpack[EnumerateKwargs]) -> list[DataPoint]: + return super().list(**kwargs) @guard("view") def tabulate( - self, *args: Any, _raw: bool | None = False, **kwargs: Any + self, *args: Any, _raw: bool | None = False, **kwargs: Unpack[EnumerateKwargs] ) -> pd.DataFrame: if _raw: return super().tabulate(*args, **kwargs) @@ -180,7 +202,7 @@ def tabulate( ) return df.dropna(axis="columns", how="all") - def check_df_access(self, df: pd.DataFrame): + def check_df_access(self, df: pd.DataFrame) -> None: if self.backend.auth_context is not None: ts_ids = set(df["time_series__id"].unique().tolist()) self.timeseries.check_access( @@ -192,29 +214,28 @@ def check_df_access(self, df: pd.DataFrame): @check_types @guard("edit") def bulk_upsert(self, df: DataFrame[AddDataPointFrameSchema]) -> None: - return super().bulk_upsert(df) + super().bulk_upsert(df) @check_types @guard("edit") def bulk_insert(self, df: DataFrame[AddDataPointFrameSchema]) -> None: self.check_df_access(df) - return super().bulk_insert(df) + super().bulk_insert(df) @check_types @guard("edit") def bulk_update(self, df: DataFrame[UpdateDataPointFrameSchema]) -> None: self.check_df_access(df) - return super().bulk_update(df) + super().bulk_update(df) @check_types @guard("edit") def bulk_delete(self, df: DataFrame[RemoveDataPointFrameSchema]) -> None: self.check_df_access(df) - res = super().bulk_delete(df) + super().bulk_delete(df) self.delete_orphans() - return res - def delete_orphans(self): + def delete_orphans(self) -> None: exc = db.delete(TimeSeries).where( ~db.exists( db.select(self.model_class.id).where( diff --git a/ixmp4/data/db/iamc/measurand.py b/ixmp4/data/db/iamc/measurand.py index fb99773f..ec19c59a 100644 --- a/ixmp4/data/db/iamc/measurand.py +++ b/ixmp4/data/db/iamc/measurand.py @@ -1,8 +1,11 @@ -from typing import ClassVar +from typing import TYPE_CHECKING, ClassVar import pandas as pd from sqlalchemy.exc import NoResultFound +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4 import db from ixmp4.data import types from ixmp4.data.abstract import iamc as abstract @@ -10,6 +13,10 @@ from ixmp4.data.db import mixins from ..unit import Unit + +if TYPE_CHECKING: + from ixmp4.data.backend.db import SqlAlchemyBackend + from . import base from .variable import Variable @@ -50,12 +57,12 @@ class MeasurandRepository( ): model_class = Measurand - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, *args: "SqlAlchemyBackend") -> None: + super().__init__(*args) @guard("view") def get(self, variable_name: str, unit__id: int) -> Measurand: - exc: db.sql.Select = ( + exc = ( db.select(Measurand) .join(Measurand.variable) .where(Measurand.unit__id == unit__id) @@ -68,7 +75,7 @@ def get(self, variable_name: str, unit__id: int) -> Measurand: raise Measurand.NotFound def add(self, variable_name: str, unit__id: int) -> Measurand: - q_exc: db.sql.Select = db.select(Variable).where(Variable.name == variable_name) + q_exc = db.select(Variable).where(Variable.name == variable_name) try: variable = self.session.execute(q_exc).scalar_one() except NoResultFound: @@ -80,16 +87,16 @@ def add(self, variable_name: str, unit__id: int) -> Measurand: return measurand @guard("edit") - def create(self, *args, **kwargs) -> Measurand: - return super().create(*args, **kwargs) + def create(self, *args: Unpack[tuple[str, int]]) -> Measurand: + return super().create(*args) - def select(self, *args, **kwargs) -> db.sql.Select: + def select(self) -> db.sql.Select[tuple[Measurand]]: return db.select(Measurand) @guard("view") - def list(self, *args, **kwargs) -> list[Measurand]: - return super().list(*args, **kwargs) + def list(self) -> list[Measurand]: + return super().list() @guard("view") - def tabulate(self, *args, **kwargs) -> pd.DataFrame: - return super().tabulate(*args, **kwargs) + def tabulate(self) -> pd.DataFrame: + return super().tabulate() diff --git a/ixmp4/data/db/iamc/timeseries/filter.py b/ixmp4/data/db/iamc/timeseries/filter.py index f0e318af..4e245371 100644 --- a/ixmp4/data/db/iamc/timeseries/filter.py +++ b/ixmp4/data/db/iamc/timeseries/filter.py @@ -1,13 +1,16 @@ from ixmp4.data.db import filters as base -from ixmp4.db import filters +from ixmp4.db import Session, filters, sql + +from .model import TimeSeries class TimeSeriesFilter(base.TimeSeriesFilter, metaclass=filters.FilterMeta): - id: filters.Id run: base.RunFilter = filters.Field(default=base.RunFilter()) - region: base.RegionFilter | None - variable: base.VariableFilter | None - unit: base.UnitFilter | None + region: base.RegionFilter | None = filters.Field(None) + variable: base.VariableFilter | None = filters.Field(None) + unit: base.UnitFilter | None = filters.Field(None) - def join(self, exc, **kwargs): + def join( + self, exc: sql.Select[tuple[TimeSeries]], session: Session | None = None + ) -> sql.Select[tuple[TimeSeries]]: return exc diff --git a/ixmp4/data/db/iamc/timeseries/model.py b/ixmp4/data/db/iamc/timeseries/model.py index f1cc21db..168ae0c2 100644 --- a/ixmp4/data/db/iamc/timeseries/model.py +++ b/ixmp4/data/db/iamc/timeseries/model.py @@ -1,4 +1,4 @@ -from typing import Mapping +from collections.abc import Mapping from ixmp4 import db from ixmp4.data import types @@ -27,7 +27,7 @@ class TimeSeries(BaseTimeSeries, base.BaseModel): ) @property - def parameters(self) -> Mapping: + def parameters(self) -> Mapping[str, str]: return { "region": self.region.name, "unit": self.measurand.unit.name, diff --git a/ixmp4/data/db/iamc/timeseries/repository.py b/ixmp4/data/db/iamc/timeseries/repository.py index 74b8e70b..5e469364 100644 --- a/ixmp4/data/db/iamc/timeseries/repository.py +++ b/ixmp4/data/db/iamc/timeseries/repository.py @@ -1,18 +1,21 @@ -from typing import Any +from typing import TYPE_CHECKING import numpy as np import pandas as pd -from sqlalchemy import select +from sqlalchemy import Select, select from sqlalchemy.orm import Bundle +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + from ixmp4.data import abstract from ixmp4.data.auth.decorators import guard +from ixmp4.data.db.base import BaseModel from ixmp4.data.db.iamc.measurand import Measurand from ixmp4.data.db.region import Region, RegionRepository from ixmp4.data.db.run import RunRepository -from ixmp4.data.db.timeseries import ( - TimeSeriesRepository as BaseTimeSeriesRepository, -) +from ixmp4.data.db.timeseries import CreateKwargs, EnumerateKwargs, GetKwargs +from ixmp4.data.db.timeseries import TimeSeriesRepository as BaseTimeSeriesRepository from ixmp4.data.db.unit import Unit, UnitRepository from ixmp4.data.db.utils import map_existing @@ -20,9 +23,13 @@ from ..variable import Variable from .model import TimeSeries +if TYPE_CHECKING: + from ixmp4.data.backend.db import SqlAlchemyBackend + class TimeSeriesRepository( - BaseTimeSeriesRepository[TimeSeries], abstract.TimeSeriesRepository + BaseTimeSeriesRepository[TimeSeries], + abstract.TimeSeriesRepository[abstract.TimeSeries], ): model_class = TimeSeries @@ -30,22 +37,37 @@ class TimeSeriesRepository( measurands: MeasurandRepository units: UnitRepository - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: "SqlAlchemyBackend") -> None: from .filter import TimeSeriesFilter self.filter_class = TimeSeriesFilter - self.runs = RunRepository(*args, **kwargs) - self.regions = RegionRepository(*args, **kwargs) - self.measurands = MeasurandRepository(*args, **kwargs) - self.units = UnitRepository(*args, **kwargs) - super().__init__(*args, **kwargs) + self.runs = RunRepository(*args) + self.regions = RegionRepository(*args) + self.measurands = MeasurandRepository(*args) + self.units = UnitRepository(*args) + super().__init__(*args) + + # TODO Why do I have to essentially copy this and get_by_id() from db/timeseries? + # Mypy complains about incompatible definitions of create() and get_by_id(). + @guard("edit") + def create(self, **kwargs: Unpack[CreateKwargs]) -> TimeSeries: + return super().create(**kwargs) @guard("view") - def get(self, run_id: int, **kwargs: Any) -> TimeSeries: + def get(self, run_id: int, **kwargs: Unpack[GetKwargs]) -> TimeSeries: return super().get(run_id, **kwargs) - def select_joined_parameters(self): + @guard("view") + def get_by_id(self, id: int) -> TimeSeries: + obj = self.session.get(self.model_class, id) + + if obj is None: + raise self.model_class.NotFound + + return obj + + def select_joined_parameters(self) -> Select[tuple[BaseModel, ...]]: return ( select( self.bundle, @@ -69,12 +91,14 @@ def select_joined_parameters(self): ) @guard("view") - def list(self, **kwargs) -> list[TimeSeries]: - return super().list(**kwargs) + def list(self, **kwargs: Unpack[EnumerateKwargs]) -> list[TimeSeries]: + timeseries_list: list[TimeSeries] = super().list(**kwargs) + return timeseries_list @guard("view") - def tabulate(self, **kwargs) -> pd.DataFrame: - return super().tabulate(**kwargs) + def tabulate(self, **kwargs: Unpack[EnumerateKwargs]) -> pd.DataFrame: + df: pd.DataFrame = super().tabulate(**kwargs) + return df @guard("edit") def bulk_upsert(self, df: pd.DataFrame, create_related: bool = False) -> None: @@ -92,7 +116,7 @@ def bulk_upsert(self, df: pd.DataFrame, create_related: bool = False) -> None: df = df.drop_duplicates() super().bulk_upsert(df) - def map_regions(self, df: pd.DataFrame): + def map_regions(self, df: pd.DataFrame) -> pd.DataFrame: existing_regions = self.regions.tabulate(name__in=df["region"].unique()) df, missing = map_existing( df, @@ -117,7 +141,7 @@ def map_measurands(self, df: pd.DataFrame) -> pd.DataFrame: df["measurand__id"] = np.nan - def map_measurand(df): + def map_measurand(df: pd.DataFrame) -> pd.DataFrame: variable_name, unit__id = df.name measurand = self.measurands.get_or_create( variable_name=variable_name, unit__id=int(unit__id) @@ -129,10 +153,11 @@ def map_measurand(df): # ensure compatibility with pandas < 2.2 # TODO remove legacy-handling when dropping support for pandas < 2.2 - if pd.__version__[0:3] in ["2.0", "2.1"]: - apply_args = dict() - else: - apply_args = dict(include_groups=False) + apply_args = ( + dict() + if pd.__version__[0:3] in ["2.0", "2.1"] + else dict(include_groups=False) + ) return pd.DataFrame( df.groupby(["variable", "unit__id"], group_keys=False).apply( diff --git a/ixmp4/data/db/iamc/variable/__init__.py b/ixmp4/data/db/iamc/variable/__init__.py index befe2fb2..bc381054 100644 --- a/ixmp4/data/db/iamc/variable/__init__.py +++ b/ixmp4/data/db/iamc/variable/__init__.py @@ -1,3 +1,2 @@ -# flake8: noqa from .model import Variable from .repository import VariableRepository diff --git a/ixmp4/data/db/iamc/variable/docs.py b/ixmp4/data/db/iamc/variable/docs.py index 3db542a6..f3678753 100644 --- a/ixmp4/data/db/iamc/variable/docs.py +++ b/ixmp4/data/db/iamc/variable/docs.py @@ -1,8 +1,12 @@ +from typing import Any + from ixmp4.data.db.docs import BaseDocsRepository, docs_model from .model import Variable +VariableDocs = docs_model(Variable) + -class VariableDocsRepository(BaseDocsRepository): - model_class = docs_model(Variable) # VariableDocs +class VariableDocsRepository(BaseDocsRepository[Any]): + model_class = VariableDocs dimension_model_class = Variable diff --git a/ixmp4/data/db/iamc/variable/filter.py b/ixmp4/data/db/iamc/variable/filter.py index a16112e4..4eb83550 100644 --- a/ixmp4/data/db/iamc/variable/filter.py +++ b/ixmp4/data/db/iamc/variable/filter.py @@ -1,19 +1,20 @@ from ixmp4.data.db import filters as base +from ixmp4.data.db.base import SelectType from ixmp4.data.db.iamc.timeseries import TimeSeries -from ixmp4.db import filters, utils +from ixmp4.db import Session, filters, utils from ..measurand import Measurand from . import Variable class VariableFilter(base.VariableFilter, metaclass=filters.FilterMeta): - region: base.RegionFilter | None - unit: base.UnitFilter | None + region: base.RegionFilter | None = filters.Field(None) + unit: base.UnitFilter | None = filters.Field(None) run: base.RunFilter = filters.Field( default=base.RunFilter(id=None, version=None, is_default=True) ) - def join(self, exc, session=None): + def join(self, exc: SelectType, session: Session | None = None) -> SelectType: if not utils.is_joined(exc, Measurand): exc = exc.join(Measurand, Measurand.variable__id == Variable.id) diff --git a/ixmp4/data/db/iamc/variable/repository.py b/ixmp4/data/db/iamc/variable/repository.py index ca467353..a80e706f 100644 --- a/ixmp4/data/db/iamc/variable/repository.py +++ b/ixmp4/data/db/iamc/variable/repository.py @@ -1,13 +1,31 @@ +from typing import TYPE_CHECKING + import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import TypedDict, Unpack + from ixmp4 import db from ixmp4.data.abstract import iamc as abstract +from ixmp4.data.abstract.annotations import HasNameFilter from ixmp4.data.auth.decorators import guard +from ixmp4.db.filters import BaseFilter from .. import base from .docs import VariableDocsRepository from .model import Variable +if TYPE_CHECKING: + from ixmp4.data.backend.db import SqlAlchemyBackend + + +class EnumerateKwargs(HasNameFilter, total=False): + _filter: BaseFilter + + +class CreateKwargs(TypedDict, total=False): + name: str + class VariableRepository( base.Creator[Variable], @@ -17,9 +35,9 @@ class VariableRepository( ): model_class = Variable - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.docs = VariableDocsRepository(*args, **kwargs) + def __init__(self, *args: "SqlAlchemyBackend") -> None: + super().__init__(*args) + self.docs = VariableDocsRepository(*args) from .filter import VariableFilter @@ -39,13 +57,13 @@ def get(self, name: str) -> Variable: raise Variable.NotFound @guard("edit") - def create(self, *args, **kwargs) -> Variable: + def create(self, *args: str, **kwargs: Unpack[CreateKwargs]) -> Variable: return super().create(*args, **kwargs) @guard("view") - def list(self, *args, **kwargs) -> list[Variable]: - return super().list(*args, **kwargs) + def list(self, **kwargs: Unpack[EnumerateKwargs]) -> list[Variable]: + return super().list(**kwargs) @guard("view") - def tabulate(self, *args, **kwargs) -> pd.DataFrame: - return super().tabulate(*args, **kwargs) + def tabulate(self, **kwargs: Unpack[EnumerateKwargs]) -> pd.DataFrame: + return super().tabulate(**kwargs) diff --git a/ixmp4/data/db/meta/filter.py b/ixmp4/data/db/meta/filter.py index 9c149dca..52945351 100644 --- a/ixmp4/data/db/meta/filter.py +++ b/ixmp4/data/db/meta/filter.py @@ -1,12 +1,14 @@ from ixmp4.data.db import filters as base from ixmp4.data.db.run import Run -from ixmp4.db import filters, utils +from ixmp4.db import Session, filters, sql, utils from .model import RunMetaEntry class RunFilter(base.RunFilter, metaclass=filters.FilterMeta): - def join(self, exc, **kwargs): + def join( + self, exc: sql.Select[tuple[RunMetaEntry]], session: Session | None = None + ) -> sql.Select[tuple[RunMetaEntry]]: if not utils.is_joined(exc, Run): exc = exc.join(Run, onclause=RunMetaEntry.run__id == Run.id) return exc diff --git a/ixmp4/data/db/meta/model.py b/ixmp4/data/db/meta/model.py index 28d2bb1a..3ab4c707 100644 --- a/ixmp4/data/db/meta/model.py +++ b/ixmp4/data/db/meta/model.py @@ -1,4 +1,7 @@ -from typing import ClassVar +from typing import ClassVar, cast + +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import TypedDict, Unpack from ixmp4 import db from ixmp4.core.exceptions import InvalidRunMeta @@ -7,6 +10,12 @@ from .. import base +class InitKwargs(TypedDict): + run__id: int + key: str + value: abstract.annotations.PrimitiveTypes + + class RunMetaEntry(base.BaseModel): NotFound: ClassVar = abstract.RunMetaEntry.NotFound NotUnique: ClassVar = abstract.RunMetaEntry.NotUnique @@ -14,7 +23,7 @@ class RunMetaEntry(base.BaseModel): Type: ClassVar = abstract.RunMetaEntry.Type - _column_map = { + _column_map: dict[str, str] = { abstract.RunMetaEntry.Type.INT: "value_int", abstract.RunMetaEntry.Type.STR: "value_str", abstract.RunMetaEntry.Type.FLOAT: "value_float", @@ -28,7 +37,7 @@ class RunMetaEntry(base.BaseModel): ), ) updateable_columns = [ - "type", + "dtype", "value_int", "value_str", "value_float", @@ -48,7 +57,7 @@ class RunMetaEntry(base.BaseModel): ) key: types.String = db.Column(db.String(1023), nullable=False) - type: types.String = db.Column(db.String(20), nullable=False) + dtype: types.String = db.Column(db.String(20), nullable=False) value_int: types.Integer = db.Column(db.Integer, nullable=True) value_str: types.String = db.Column(db.String(1023), nullable=True) @@ -57,12 +66,14 @@ class RunMetaEntry(base.BaseModel): @property def value(self) -> abstract.MetaValue: - type_ = RunMetaEntry.Type(self.type) + type_ = RunMetaEntry.Type(self.dtype) col = self._column_map[type_] - return getattr(self, col) + value: abstract.MetaValue = getattr(self, col) + return value - def __init__(self, *args, **kwargs) -> None: - value = kwargs.pop("value") + def __init__(self, **kwargs: Unpack[InitKwargs]) -> None: + _kwargs = cast(dict[str, abstract.annotations.PrimitiveTypes], kwargs) + value = _kwargs.pop("value") value_type = type(value) try: type_ = RunMetaEntry.Type.from_pytype(value_type) @@ -71,6 +82,6 @@ def __init__(self, *args, **kwargs) -> None: raise InvalidRunMeta( f"Invalid type `{value_type}` for value of `RunMetaEntry`." ) - kwargs["type"] = type_ - kwargs[col] = value - super().__init__(*args, **kwargs) + _kwargs["dtype"] = type_ + _kwargs[col] = value + super().__init__(**_kwargs) diff --git a/ixmp4/data/db/meta/repository.py b/ixmp4/data/db/meta/repository.py index 09318d3c..8257be85 100644 --- a/ixmp4/data/db/meta/repository.py +++ b/ixmp4/data/db/meta/repository.py @@ -1,10 +1,16 @@ -from typing import Union +from typing import TYPE_CHECKING, Any, cast + +if TYPE_CHECKING: + from ixmp4.data.backend.db import SqlAlchemyBackend import pandas as pd import pandera as pa from pandera.typing import DataFrame, Series from sqlalchemy.exc import NoResultFound +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import TypedDict, Unpack + from ixmp4 import db from ixmp4.core.decorators import check_types from ixmp4.core.exceptions import InvalidRunMeta @@ -13,6 +19,7 @@ from ixmp4.data.db.model import Model from ixmp4.data.db.run import Run from ixmp4.data.db.scenario import Scenario +from ixmp4.db.filters import BaseFilter from .. import base from .model import RunMetaEntry @@ -37,6 +44,16 @@ class UpdateRunMetaEntryFrameSchema(AddRunMetaEntryFrameSchema): id: Series[pa.Int] = pa.Field(coerce=True) +class EnumerateKwargs(TypedDict, total=False): + _filter: BaseFilter + + +class CreateKwargs(TypedDict, total=False): + run__id: int + key: str + value: abstract.annotations.PrimitiveTypes + + class RunMetaEntryRepository( base.Creator[RunMetaEntry], base.Enumerator[RunMetaEntry], @@ -46,14 +63,15 @@ class RunMetaEntryRepository( ): model_class = RunMetaEntry - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, *args: "SqlAlchemyBackend") -> None: + super().__init__(*args) + from .filter import RunMetaEntryFilter self.filter_class = RunMetaEntryFilter def add( - self, run__id: int, key: str, value: Union[str, int, bool, float] + self, run__id: int, key: str, value: abstract.annotations.PrimitiveTypes ) -> RunMetaEntry: if self.backend.auth_context is not None: self.backend.runs.check_access( @@ -69,7 +87,7 @@ def add( self.session.add(entry) return entry - def check_df_access(self, df: pd.DataFrame): + def check_df_access(self, df: pd.DataFrame) -> None: if self.backend.auth_context is not None: ts_ids = set(df["run__id"].unique().tolist()) self.backend.runs.check_access( @@ -80,7 +98,9 @@ def check_df_access(self, df: pd.DataFrame): ) @guard("edit") - def create(self, *args, **kwargs) -> RunMetaEntry: + def create( + self, *args: abstract.annotations.PrimitiveTypes, **kwargs: Unpack[CreateKwargs] + ) -> RunMetaEntry: return super().create(*args, **kwargs) @guard("view") @@ -93,18 +113,13 @@ def get(self, run__id: int, key: str) -> RunMetaEntry: default_only=False, ) - exc = self.select( - run_id=run__id, - key=key, - ) + exc = self.select(run_id=run__id, key=key) try: - return self.session.execute(exc).scalar_one() + runmetaentry = self.session.execute(exc).scalar_one() + return runmetaentry except NoResultFound: - raise RunMetaEntry.NotFound( - run__id=run__id, - key=key, - ) + raise RunMetaEntry.NotFound(run__id=run__id, key=key) @guard("edit") def delete(self, id: int) -> None: @@ -113,9 +128,7 @@ def delete(self, id: int) -> None: pre_exc = db.select(RunMetaEntry).where(RunMetaEntry.id == id) meta = self.session.execute(pre_exc).scalar_one() except NoResultFound: - raise RunMetaEntry.NotFound( - id=id, - ) + raise RunMetaEntry.NotFound(id=id) self.backend.runs.check_access( {meta.run__id}, access_type="edit", @@ -129,11 +142,11 @@ def delete(self, id: int) -> None: self.session.execute(exc) self.session.commit() except NoResultFound: - raise RunMetaEntry.NotFound( - id=id, - ) + raise RunMetaEntry.NotFound(id=id) - def join_auth(self, exc: db.sql.Select) -> db.sql.Select: + def join_auth( + self, exc: db.sql.Select[tuple[RunMetaEntry]] + ) -> db.sql.Select[tuple[RunMetaEntry]]: if not db.utils.is_joined(exc, Run): exc = exc.join(Run, RunMetaEntry.run) if not db.utils.is_joined(exc, Model): @@ -141,7 +154,7 @@ def join_auth(self, exc: db.sql.Select) -> db.sql.Select: return super().join_auth(exc) - def select_with_run_index(self) -> db.sql.Select: + def select_with_run_index(self) -> db.sql.Select[tuple[str, str, int, Any]]: _exc = db.select( Model.name.label("model_name"), Scenario.name.label("scenario_name"), @@ -156,23 +169,22 @@ def select_with_run_index(self) -> db.sql.Select: ) @guard("view") - def list(self, *args, **kwargs) -> list[RunMetaEntry]: - return super().list(*args, **kwargs) + def list(self, **kwargs: Unpack[EnumerateKwargs]) -> list[RunMetaEntry]: + return super().list(**kwargs) @guard("view") def tabulate( self, - *args, join_run_index: bool = False, _raw: bool | None = False, - **kwargs, + **kwargs: Unpack[EnumerateKwargs], ) -> pd.DataFrame: if _raw: - return super().tabulate(*args, **kwargs) + return super().tabulate(**kwargs) if join_run_index: _exc = self.select_with_run_index() - df = super().tabulate(*args, _exc=_exc, **kwargs) + df = super().tabulate(_exc=_exc, **kwargs) df.drop(columns="run__id", inplace=True) df.rename( columns={"model_name": "model", "scenario_name": "scenario"}, @@ -180,30 +192,31 @@ def tabulate( ) index_columns = ["model", "scenario", "version"] else: - df = super().tabulate(*args, **kwargs) + df = super().tabulate(**kwargs) index_columns = ["run__id"] if df.empty: return pd.DataFrame( - [], columns=index_columns + ["id", "type", "key", "value"] + [], columns=index_columns + ["id", "dtype", "key", "value"] ) - def map_value_column(df: pd.DataFrame): + def map_value_column(df: pd.DataFrame) -> pd.DataFrame: type_str = df.name type_ = RunMetaEntry.Type(type_str) col = RunMetaEntry._column_map[type_] df["value"] = df[col] - df["type"] = type_str + df["dtype"] = type_str return df.drop(columns=RunMetaEntry._column_map.values()) # ensure compatibility with pandas y 2.2 # TODO remove legacy-handling when dropping support for pandas < 2.2 - if pd.__version__[0:3] in ["2.0", "2.1"]: - apply_args = dict() - else: - apply_args = dict(include_groups=False) + apply_args = ( + dict() + if pd.__version__[0:3] in ["2.0", "2.1"] + else dict(include_groups=False) + ) - return df.groupby("type", group_keys=False).apply( + return df.groupby("dtype", group_keys=False).apply( map_value_column, **apply_args ) @@ -214,13 +227,13 @@ def bulk_upsert(self, df: DataFrame[AddRunMetaEntryFrameSchema]) -> None: raise InvalidRunMeta("Illegal meta key(s): " + ", ".join(illegal_keys)) self.check_df_access(df) - df["type"] = df["value"].map(type).map(RunMetaEntry.Type.from_pytype) + df["dtype"] = df["value"].map(type).map(RunMetaEntry.Type.from_pytype) - type_: RunMetaEntry.Type - for type_, type_df in df.groupby("type"): - col = RunMetaEntry._column_map[type_] + for type_, type_df in df.groupby("dtype"): + # This cast should always be a no-op + col = RunMetaEntry._column_map[cast(str, type_)] null_cols = set(RunMetaEntry._column_map.values()) - set([col]) - type_df["type"] = type_df["type"].map(lambda x: x.value) + type_df["dtype"] = type_df["dtype"].map(lambda x: x.value) type_df = type_df.rename(columns={"value": col}) # ensure all other columns are overwritten @@ -233,4 +246,4 @@ def bulk_upsert(self, df: DataFrame[AddRunMetaEntryFrameSchema]) -> None: @guard("edit") def bulk_delete(self, df: DataFrame[RemoveRunMetaEntryFrameSchema]) -> None: self.check_df_access(df) - return super().bulk_delete(df) + super().bulk_delete(df) diff --git a/ixmp4/data/db/mixins.py b/ixmp4/data/db/mixins.py index 5f5ba99e..092eb8e4 100644 --- a/ixmp4/data/db/mixins.py +++ b/ixmp4/data/db/mixins.py @@ -18,14 +18,11 @@ class HasCreationInfo: created_by: types.Username @staticmethod - def get_username(auth_context: "AuthorizationContext | None"): - if auth_context is not None: - return auth_context.user.username - else: - return "@unknown" + def get_username(auth_context: "AuthorizationContext | None") -> str: + return auth_context.user.username if auth_context is not None else "@unknown" @staticmethod - def get_timestamp(): + def get_timestamp() -> datetime: return datetime.now(tz=timezone.utc) def set_creation_info(self, auth_context: "AuthorizationContext | None") -> None: diff --git a/ixmp4/data/db/model/__init__.py b/ixmp4/data/db/model/__init__.py index f0fcf0f0..c5a828b7 100644 --- a/ixmp4/data/db/model/__init__.py +++ b/ixmp4/data/db/model/__init__.py @@ -1,4 +1,3 @@ -# flake8: noqa from .docs import ModelDocsRepository from .model import Model from .repository import ModelRepository diff --git a/ixmp4/data/db/model/docs.py b/ixmp4/data/db/model/docs.py index dd40824d..bb0aba25 100644 --- a/ixmp4/data/db/model/docs.py +++ b/ixmp4/data/db/model/docs.py @@ -1,15 +1,15 @@ -from ixmp4.data import abstract +from typing import Any from .. import base -from ..docs import BaseDocsRepository, docs_model +from ..docs import AbstractDocs, BaseDocsRepository, docs_model from .model import Model ModelDocs = docs_model(Model) -class ModelDocsRepository(BaseDocsRepository, base.BaseRepository): +class ModelDocsRepository(BaseDocsRepository[Any], base.BaseRepository[Model]): model_class = ModelDocs dimension_model_class = Model - def list(self, *, dimension_id: int | None = None) -> list[abstract.Docs]: + def list(self, *, dimension_id: int | None = None) -> list[AbstractDocs]: return super().list(dimension_id=dimension_id) diff --git a/ixmp4/data/db/model/filter.py b/ixmp4/data/db/model/filter.py index 99e06110..f3336734 100644 --- a/ixmp4/data/db/model/filter.py +++ b/ixmp4/data/db/model/filter.py @@ -1,15 +1,20 @@ +from typing import Any + from ixmp4 import db from ixmp4.data.db import filters as base from ixmp4.data.db.iamc.timeseries import TimeSeries from ixmp4.data.db.run.model import Run -from ixmp4.db import filters, utils +from ixmp4.db import Session, filters, typing_column, utils from . import Model class BaseIamcFilter(filters.BaseFilter, metaclass=filters.FilterMeta): - def join_datapoints(self, exc: db.sql.Select, session=None): + def join_datapoints( + self, exc: db.sql.Select[tuple[Model]], session: db.Session | None = None + ) -> db.sql.Select[tuple[Model]]: if not utils.is_joined(exc, Run): + # This looks like any Select coming in here must have Model in it exc = exc.join(Run, onclause=Run.model__id == Model.id) if not utils.is_joined(exc, TimeSeries): @@ -19,30 +24,40 @@ def join_datapoints(self, exc: db.sql.Select, session=None): class RunFilter(base.RunFilter, metaclass=filters.FilterMeta): - def join(self, exc, **kwargs): + def join( + self, exc: db.sql.Select[tuple[Model]], session: Session | None = None + ) -> db.sql.Select[tuple[Model]]: if not utils.is_joined(exc, Run): exc = exc.join(Run, Run.model) return exc - model: base.ModelFilter = filters.Field(default=None, exclude=True) + model: base.ModelFilter | None = filters.Field(default=None, exclude=True) class IamcModelFilter(base.ModelFilter, BaseIamcFilter, metaclass=filters.FilterMeta): - region: base.RegionFilter - variable: base.VariableFilter - unit: base.UnitFilter + region: base.RegionFilter | None = filters.Field(None) + variable: base.VariableFilter | None = filters.Field(None) + unit: base.UnitFilter | None = filters.Field(None) run: RunFilter = filters.Field( default=RunFilter(id=None, version=None, is_default=True) ) - def join(self, exc, session=None): + def join( + self, exc: db.sql.Select[tuple[Model]], session: db.Session | None = None + ) -> db.sql.Select[tuple[Model]]: return super().join_datapoints(exc, session) class ModelFilter(base.ModelFilter, BaseIamcFilter, metaclass=filters.FilterMeta): iamc: IamcModelFilter | filters.Boolean - def filter_iamc(self, exc, c, v, session=None): + def filter_iamc( + self, + exc: db.sql.Select[tuple[Model]], + c: typing_column[Any], + v: bool | None, + session: db.Session | None = None, + ) -> db.sql.Select[tuple[Model]]: if v is None: return exc @@ -53,5 +68,7 @@ def filter_iamc(self, exc, c, v, session=None): exc = exc.where(~Model.id.in_(ids)) return exc - def join(self, exc, **kwargs): + def join( + self, exc: db.sql.Select[tuple[Model]], session: Session | None = None + ) -> db.sql.Select[tuple[Model]]: return exc diff --git a/ixmp4/data/db/model/repository.py b/ixmp4/data/db/model/repository.py index cf4b5bbc..c143dd9e 100644 --- a/ixmp4/data/db/model/repository.py +++ b/ixmp4/data/db/model/repository.py @@ -1,14 +1,32 @@ +from typing import TYPE_CHECKING + import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import TypedDict, Unpack + +if TYPE_CHECKING: + from ixmp4.data.backend.db import SqlAlchemyBackend + + from ixmp4 import db from ixmp4.data import abstract from ixmp4.data.auth.decorators import guard +from ixmp4.db.filters import BaseFilter from .. import base from .docs import ModelDocsRepository from .model import Model +class EnumerateKwargs(abstract.annotations.HasNameFilter, total=False): + _filter: BaseFilter + + +class CreateKwargs(TypedDict, total=False): + name: str + + class ModelRepository( base.Creator[Model], base.Retriever[Model], @@ -18,9 +36,9 @@ class ModelRepository( model_class = Model docs: ModelDocsRepository - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.docs = ModelDocsRepository(*args, **kwargs) + def __init__(self, *args: "SqlAlchemyBackend") -> None: + super().__init__(*args) + self.docs = ModelDocsRepository(*args) from .filter import ModelFilter @@ -35,18 +53,19 @@ def add(self, name: str) -> Model: def get(self, name: str) -> Model: exc = self.select(name=name) try: - return self.session.execute(exc).scalar_one() + model: Model = self.session.execute(exc).scalar_one() + return model except db.NoResultFound: raise Model.NotFound @guard("edit") - def create(self, *args, **kwargs) -> Model: + def create(self, *args: str, **kwargs: Unpack[CreateKwargs]) -> Model: return super().create(*args, **kwargs) @guard("view") - def list(self, *args, **kwargs) -> list[Model]: - return super().list(*args, **kwargs) + def list(self, **kwargs: Unpack[EnumerateKwargs]) -> list[Model]: + return super().list(**kwargs) @guard("view") - def tabulate(self, *args, **kwargs) -> pd.DataFrame: - return super().tabulate(*args, **kwargs) + def tabulate(self, **kwargs: Unpack[EnumerateKwargs]) -> pd.DataFrame: + return super().tabulate(**kwargs) diff --git a/ixmp4/data/db/optimization/base.py b/ixmp4/data/db/optimization/base.py index 7c3cc084..1b6f6629 100644 --- a/ixmp4/data/db/optimization/base.py +++ b/ixmp4/data/db/optimization/base.py @@ -2,6 +2,8 @@ from ixmp4.core.exceptions import IxmpError from ixmp4.data import types +from ixmp4.data.abstract.annotations import HasNameFilter +from ixmp4.db.filters import BaseFilter from .. import mixins from ..base import BaseModel as RootBaseModel @@ -26,3 +28,7 @@ class BaseModel(RootBaseModel, mixins.HasCreationInfo): table_prefix = "optimization_" name: types.Name + + +class EnumerateKwargs(HasNameFilter, total=False): + _filter: BaseFilter diff --git a/ixmp4/data/db/optimization/column/docs.py b/ixmp4/data/db/optimization/column/docs.py index dd263636..83f34951 100644 --- a/ixmp4/data/db/optimization/column/docs.py +++ b/ixmp4/data/db/optimization/column/docs.py @@ -3,6 +3,6 @@ from .model import Column -class ColumnDocsRepository(BaseDocsRepository): +class ColumnDocsRepository(BaseDocsRepository): # type: ignore[type-arg] model_class = docs_model(Column) # ColumnDocs dimension_model_class = Column diff --git a/ixmp4/data/db/optimization/column/filter.py b/ixmp4/data/db/optimization/column/filter.py index 1b98c331..a5806370 100644 --- a/ixmp4/data/db/optimization/column/filter.py +++ b/ixmp4/data/db/optimization/column/filter.py @@ -10,7 +10,7 @@ class OptimizationTableFilter( base.OptimizationTableFilter, metaclass=filters.FilterMeta ): - def join(self, exc, session=None): + def join(self, exc, session=None): # type: ignore[no-untyped-def] if not utils.is_joined(exc, Table): exc = exc.join(Table, onclause=Column.table__id == Table.id) return exc @@ -19,5 +19,5 @@ def join(self, exc, session=None): class OptimizationColumnFilter( base.OptimizationColumnFilter, metaclass=filters.FilterMeta ): - def join(self, exc, session=None): + def join(self, exc, session=None): # type: ignore[no-untyped-def] return exc diff --git a/ixmp4/data/db/optimization/column/repository.py b/ixmp4/data/db/optimization/column/repository.py index eb1d2f72..31d950cd 100644 --- a/ixmp4/data/db/optimization/column/repository.py +++ b/ixmp4/data/db/optimization/column/repository.py @@ -11,7 +11,7 @@ class ColumnRepository( ): model_class = Column - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] super().__init__(*args, **kwargs) self.docs = ColumnDocsRepository(*args, **kwargs) @@ -44,8 +44,9 @@ def add( return column @guard("edit") - def create( + def create( # type: ignore[no-untyped-def] self, + /, name: str, constrained_to_indexset: int, dtype: str, @@ -96,7 +97,7 @@ def create( :class:`ixmp4.data.abstract.optimization.Column`: The created Column. """ - return super().create( + return super().create( # type: ignore[call-arg] name=name, constrained_to_indexset=constrained_to_indexset, dtype=dtype, diff --git a/ixmp4/data/db/optimization/equation/docs.py b/ixmp4/data/db/optimization/equation/docs.py index b3403cb6..0115eb56 100644 --- a/ixmp4/data/db/optimization/equation/docs.py +++ b/ixmp4/data/db/optimization/equation/docs.py @@ -1,8 +1,15 @@ +from typing import Any + from ixmp4.data.db.docs import BaseDocsRepository, docs_model from .model import Equation +EquationDocs = docs_model(Equation) + -class EquationDocsRepository(BaseDocsRepository): - model_class = docs_model(Equation) # EquationDocs +# NOTE Mypy is a static type checker, but we create the class(es) that would need to go +# in BaseDocsRepository[...] dynamically, so I don't know if there's any way to type +# hint them properly. TypeVar, TypeAlias, type(), type[], and NewType all did not work. +class EquationDocsRepository(BaseDocsRepository[Any]): + model_class = EquationDocs dimension_model_class = Equation diff --git a/ixmp4/data/db/optimization/equation/filter.py b/ixmp4/data/db/optimization/equation/filter.py index bd0a0686..1c7838ed 100644 --- a/ixmp4/data/db/optimization/equation/filter.py +++ b/ixmp4/data/db/optimization/equation/filter.py @@ -1,17 +1,21 @@ from ixmp4.data.db import filters as base from ixmp4.data.db.run import Run -from ixmp4.db import filters, utils +from ixmp4.db import Session, filters, sql, utils from .model import Equation class RunFilter(base.RunFilter, metaclass=filters.FilterMeta): - def join(self, exc, **kwargs): + def join( + self, exc: sql.Select[tuple[Equation]], session: Session | None = None + ) -> sql.Select[tuple[Equation]]: if not utils.is_joined(exc, Run): exc = exc.join(Run, onclause=Equation.run__id == Run.id) return exc class EquationFilter(base.OptimizationEquationFilter, metaclass=filters.FilterMeta): - def join(self, exc, session=None): + def join( + self, exc: sql.Select[tuple[Equation]], session: Session | None = None + ) -> sql.Select[tuple[Equation]]: return exc diff --git a/ixmp4/data/db/optimization/equation/model.py b/ixmp4/data/db/optimization/equation/model.py index 2201d2f6..e7588043 100644 --- a/ixmp4/data/db/optimization/equation/model.py +++ b/ixmp4/data/db/optimization/equation/model.py @@ -25,7 +25,7 @@ class Equation(base.BaseModel): data: types.JsonDict = db.Column(db.JsonType, nullable=False, default={}) @validates("data") - def validate_data(self, key, data: dict[str, Any]): + def validate_data(self, key: Any, data: dict[str, Any]) -> dict[str, Any]: if data == {}: return data data_to_validate = copy.deepcopy(data) diff --git a/ixmp4/data/db/optimization/equation/repository.py b/ixmp4/data/db/optimization/equation/repository.py index f9aa91b5..34f9fdcd 100644 --- a/ixmp4/data/db/optimization/equation/repository.py +++ b/ixmp4/data/db/optimization/equation/repository.py @@ -1,4 +1,12 @@ -from typing import Any, Iterable +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any + +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + +if TYPE_CHECKING: + from ixmp4.data.backend.db import SqlAlchemyBackend + import pandas as pd @@ -22,16 +30,16 @@ class EquationRepository( UsageError = OptimizationItemUsageError - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.docs = EquationDocsRepository(*args, **kwargs) - self.columns = ColumnRepository(*args, **kwargs) + def __init__(self, *args: "SqlAlchemyBackend") -> None: + super().__init__(*args) + self.docs = EquationDocsRepository(*args) + self.columns = ColumnRepository(*args) from .filter import EquationFilter self.filter_class = EquationFilter - def _add_column( + def _add_column( # type: ignore[no-untyped-def] self, run_id: int, equation_id: int, @@ -104,15 +112,11 @@ def get_by_id(self, id: int) -> Equation: @guard("edit") def create( self, - run_id: int, name: str, + run_id: int, constrained_to_indexsets: list[str], column_names: list[str] | None = None, - **kwargs, ) -> Equation: - # Convert to list to avoid enumerate() splitting strings to letters - if isinstance(constrained_to_indexsets, str): - constrained_to_indexsets = list(constrained_to_indexsets) if column_names and len(column_names) != len(constrained_to_indexsets): raise OptimizationItemUsageError( f"While processing Equation {name}: \n" @@ -129,11 +133,7 @@ def create( "The given `column_names` are not unique!" ) - equation = super().create( - run_id=run_id, - name=name, - **kwargs, - ) + equation = super().create(run_id=run_id, name=name) for i, name in enumerate(constrained_to_indexsets): self._add_column( run_id=run_id, @@ -145,12 +145,12 @@ def create( return equation @guard("view") - def list(self, *args, **kwargs) -> Iterable[Equation]: - return super().list(*args, **kwargs) + def list(self, **kwargs: Unpack["base.EnumerateKwargs"]) -> Iterable[Equation]: + return super().list(**kwargs) @guard("view") - def tabulate(self, *args, **kwargs) -> pd.DataFrame: - return super().tabulate(*args, **kwargs) + def tabulate(self, **kwargs: Unpack["base.EnumerateKwargs"]) -> pd.DataFrame: + return super().tabulate(**kwargs) @guard("edit") def add_data(self, equation_id: int, data: dict[str, Any] | pd.DataFrame) -> None: @@ -174,7 +174,7 @@ def add_data(self, equation_id: int, data: dict[str, Any] | pd.DataFrame) -> Non existing_data.set_index(index_list, inplace=True) equation.data = ( data.set_index(index_list).combine_first(existing_data).reset_index() - ).to_dict(orient="list") + ).to_dict(orient="list") # type: ignore[assignment] self.session.commit() diff --git a/ixmp4/data/db/optimization/indexset/docs.py b/ixmp4/data/db/optimization/indexset/docs.py index db49e4a8..1496a22c 100644 --- a/ixmp4/data/db/optimization/indexset/docs.py +++ b/ixmp4/data/db/optimization/indexset/docs.py @@ -1,8 +1,12 @@ +from typing import Any + from ixmp4.data.db.docs import BaseDocsRepository, docs_model from .model import IndexSet +IndexSetDocs = docs_model(IndexSet) + -class IndexSetDocsRepository(BaseDocsRepository): - model_class = docs_model(IndexSet) # IndexSetDocs +class IndexSetDocsRepository(BaseDocsRepository[Any]): + model_class = IndexSetDocs dimension_model_class = IndexSet diff --git a/ixmp4/data/db/optimization/indexset/filter.py b/ixmp4/data/db/optimization/indexset/filter.py index 9535955a..b3f05562 100644 --- a/ixmp4/data/db/optimization/indexset/filter.py +++ b/ixmp4/data/db/optimization/indexset/filter.py @@ -1,12 +1,14 @@ from ixmp4.data.db import filters as base from ixmp4.data.db.run import Run -from ixmp4.db import filters, utils +from ixmp4.db import Session, filters, sql, utils from .model import IndexSet class RunFilter(base.RunFilter, metaclass=filters.FilterMeta): - def join(self, exc, **kwargs): + def join( + self, exc: sql.Select[tuple[IndexSet]], session: Session | None = None + ) -> sql.Select[tuple[IndexSet]]: if not utils.is_joined(exc, Run): exc = exc.join(Run, onclause=IndexSet.run__id == Run.id) return exc @@ -15,5 +17,7 @@ def join(self, exc, **kwargs): class OptimizationIndexSetFilter( base.OptimizationIndexSetFilter, metaclass=filters.FilterMeta ): - def join(self, exc, session=None): + def join( + self, exc: sql.Select[tuple[IndexSet]], session: Session | None = None + ) -> sql.Select[tuple[IndexSet]]: return exc diff --git a/ixmp4/data/db/optimization/indexset/model.py b/ixmp4/data/db/optimization/indexset/model.py index 58afe208..e33fad58 100644 --- a/ixmp4/data/db/optimization/indexset/model.py +++ b/ixmp4/data/db/optimization/indexset/model.py @@ -19,11 +19,11 @@ class IndexSet(base.BaseModel): _data_type: types.OptimizationDataType _data: types.Mapped[list["IndexSetData"]] = db.relationship( - back_populates="indexset" + back_populates="indexset", order_by="IndexSetData.id" ) @property - def data(self) -> list[float | int | str]: + def data(self) -> list[float] | list[int] | list[str]: return ( [] if self._data_type is None @@ -31,7 +31,7 @@ def data(self) -> list[float | int | str]: ) @data.setter - def data(self, value: list[float | int | str]) -> None: + def data(self, value: list[float] | list[int] | list[str]) -> None: return None run__id: types.RunId diff --git a/ixmp4/data/db/optimization/indexset/repository.py b/ixmp4/data/db/optimization/indexset/repository.py index d2281a15..32455dd1 100644 --- a/ixmp4/data/db/optimization/indexset/repository.py +++ b/ixmp4/data/db/optimization/indexset/repository.py @@ -1,4 +1,10 @@ -from typing import List +from typing import TYPE_CHECKING, List, Literal, cast + +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + +if TYPE_CHECKING: + from ixmp4.data.backend.db import SqlAlchemyBackend import pandas as pd @@ -19,9 +25,9 @@ class IndexSetRepository( ): model_class = IndexSet - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.docs = IndexSetDocsRepository(*args, **kwargs) + def __init__(self, *args: "SqlAlchemyBackend") -> None: + super().__init__(*args) + self.docs = IndexSetDocsRepository(*args) from .filter import OptimizationIndexSetFilter @@ -52,23 +58,23 @@ def get_by_id(self, id: int) -> IndexSet: return obj @guard("edit") - def create(self, run_id: int, name: str, **kwargs) -> IndexSet: - return super().create(run_id=run_id, name=name, **kwargs) + def create(self, run_id: int, name: str) -> IndexSet: + return super().create(run_id=run_id, name=name) @guard("view") - def list(self, *args, **kwargs) -> list[IndexSet]: - return super().list(*args, **kwargs) + def list(self, **kwargs: Unpack["base.EnumerateKwargs"]) -> list[IndexSet]: + return super().list(**kwargs) @guard("view") - def tabulate(self, *args, include_data: bool = False, **kwargs) -> pd.DataFrame: + def tabulate( + self, *, include_data: bool = False, **kwargs: Unpack["base.EnumerateKwargs"] + ) -> pd.DataFrame: if not include_data: return ( - super() - .tabulate(*args, **kwargs) - .rename(columns={"_data_type": "data_type"}) + super().tabulate(**kwargs).rename(columns={"_data_type": "data_type"}) ) else: - result = super().tabulate(*args, **kwargs).drop(labels="_data_type", axis=1) + result = super().tabulate(**kwargs).drop(labels="_data_type", axis=1) result.insert( loc=0, column="data", @@ -80,14 +86,13 @@ def tabulate(self, *args, include_data: bool = False, **kwargs) -> pd.DataFrame: def add_data( self, indexset_id: int, - data: float | int | List[float | int | str] | str, + data: float | int | str | List[float] | List[int] | List[str], ) -> None: indexset = self.get_by_id(id=indexset_id) - if not isinstance(data, list): - data = [data] + _data = data if isinstance(data, list) else [data] bulk_insert_enabled_data: list[dict[str, str]] = [ - {"value": str(d)} for d in data + {"value": str(d)} for d in _data ] try: self.session.execute( @@ -98,7 +103,10 @@ def add_data( self.session.rollback() raise indexset.DataInvalid from e - indexset._data_type = type(data[0]).__name__ + # Due to _data's limitation above, __name__ will always be that + indexset._data_type = cast( + Literal["float", "int", "str"], type(_data[0]).__name__ + ) self.session.add(indexset) self.session.commit() diff --git a/ixmp4/data/db/optimization/parameter/docs.py b/ixmp4/data/db/optimization/parameter/docs.py index db1cb774..47d25b7f 100644 --- a/ixmp4/data/db/optimization/parameter/docs.py +++ b/ixmp4/data/db/optimization/parameter/docs.py @@ -1,8 +1,12 @@ +from typing import Any + from ixmp4.data.db.docs import BaseDocsRepository, docs_model from .model import Parameter +ParameterDocs = docs_model(Parameter) + -class ParameterDocsRepository(BaseDocsRepository): - model_class = docs_model(Parameter) # ParameterDocs +class ParameterDocsRepository(BaseDocsRepository[Any]): + model_class = ParameterDocs dimension_model_class = Parameter diff --git a/ixmp4/data/db/optimization/parameter/filter.py b/ixmp4/data/db/optimization/parameter/filter.py index cbd913bb..0066e483 100644 --- a/ixmp4/data/db/optimization/parameter/filter.py +++ b/ixmp4/data/db/optimization/parameter/filter.py @@ -1,12 +1,14 @@ from ixmp4.data.db import filters as base from ixmp4.data.db.run import Run -from ixmp4.db import filters, utils +from ixmp4.db import Session, filters, sql, utils from .model import Parameter class RunFilter(base.RunFilter, metaclass=filters.FilterMeta): - def join(self, exc, **kwargs): + def join( + self, exc: sql.Select[tuple[Parameter]], session: Session | None = None + ) -> sql.Select[tuple[Parameter]]: if not utils.is_joined(exc, Run): exc = exc.join(Run, onclause=Parameter.run__id == Run.id) return exc @@ -15,5 +17,7 @@ def join(self, exc, **kwargs): class OptimizationParameterFilter( base.OptimizationParameterFilter, metaclass=filters.FilterMeta ): - def join(self, exc, session=None): + def join( + self, exc: sql.Select[tuple[Parameter]], session: Session | None = None + ) -> sql.Select[tuple[Parameter]]: return exc diff --git a/ixmp4/data/db/optimization/parameter/model.py b/ixmp4/data/db/optimization/parameter/model.py index bb052ea7..3b96809b 100644 --- a/ixmp4/data/db/optimization/parameter/model.py +++ b/ixmp4/data/db/optimization/parameter/model.py @@ -25,7 +25,7 @@ class Parameter(base.BaseModel): data: types.JsonDict = db.Column(db.JsonType, nullable=False, default={}) @validates("data") - def validate_data(self, key, data: dict[str, Any]): + def validate_data(self, key: Any, data: dict[str, Any]) -> dict[str, Any]: data_to_validate = copy.deepcopy(data) del data_to_validate["values"] del data_to_validate["units"] diff --git a/ixmp4/data/db/optimization/parameter/repository.py b/ixmp4/data/db/optimization/parameter/repository.py index 126a51fa..82d50f88 100644 --- a/ixmp4/data/db/optimization/parameter/repository.py +++ b/ixmp4/data/db/optimization/parameter/repository.py @@ -1,4 +1,12 @@ -from typing import Any, Iterable +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any + +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + +if TYPE_CHECKING: + from ixmp4.data.backend.db import SqlAlchemyBackend + import pandas as pd @@ -23,16 +31,16 @@ class ParameterRepository( UsageError = OptimizationItemUsageError - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.docs = ParameterDocsRepository(*args, **kwargs) - self.columns = ColumnRepository(*args, **kwargs) + def __init__(self, *args: "SqlAlchemyBackend") -> None: + super().__init__(*args) + self.docs = ParameterDocsRepository(*args) + self.columns = ColumnRepository(*args) from .filter import OptimizationParameterFilter self.filter_class = OptimizationParameterFilter - def _add_column( + def _add_column( # type: ignore[no-untyped-def] self, run_id: int, parameter_id: int, @@ -72,11 +80,7 @@ def _add_column( **kwargs, ) - def add( - self, - run_id: int, - name: str, - ) -> Parameter: + def add(self, run_id: int, name: str) -> Parameter: parameter = Parameter(name=name, run__id=run_id) parameter.set_creation_info(auth_context=self.backend.auth_context) self.session.add(parameter) @@ -109,11 +113,7 @@ def create( name: str, constrained_to_indexsets: list[str], column_names: list[str] | None = None, - **kwargs, ) -> Parameter: - # Convert to list to avoid enumerate() splitting strings to letters - if isinstance(constrained_to_indexsets, str): - constrained_to_indexsets = list(constrained_to_indexsets) if column_names and len(column_names) != len(constrained_to_indexsets): raise self.UsageError( f"While processing Parameter {name}: \n" @@ -130,11 +130,7 @@ def create( "The given `column_names` are not unique!" ) - parameter = super().create( - run_id=run_id, - name=name, - **kwargs, - ) + parameter = super().create(run_id=run_id, name=name) for i, name in enumerate(constrained_to_indexsets): self._add_column( run_id=run_id, @@ -146,12 +142,12 @@ def create( return parameter @guard("view") - def list(self, *args, **kwargs) -> Iterable[Parameter]: - return super().list(*args, **kwargs) + def list(self, **kwargs: Unpack["base.EnumerateKwargs"]) -> Iterable[Parameter]: + return super().list(**kwargs) @guard("view") - def tabulate(self, *args, **kwargs) -> pd.DataFrame: - return super().tabulate(*args, **kwargs) + def tabulate(self, **kwargs: Unpack["base.EnumerateKwargs"]) -> pd.DataFrame: + return super().tabulate(**kwargs) @guard("edit") def add_data(self, parameter_id: int, data: dict[str, Any] | pd.DataFrame) -> None: @@ -184,8 +180,10 @@ def add_data(self, parameter_id: int, data: dict[str, Any] | pd.DataFrame) -> No existing_data = pd.DataFrame(parameter.data) if not existing_data.empty: existing_data.set_index(index_list, inplace=True) + # TODO Ignoring this for now since I'll likely refactor this soon, anyway + # Same applies to equation, table, and variable. parameter.data = ( data.set_index(index_list).combine_first(existing_data).reset_index() - ).to_dict(orient="list") + ).to_dict(orient="list") # type: ignore[assignment] self.session.commit() diff --git a/ixmp4/data/db/optimization/scalar/docs.py b/ixmp4/data/db/optimization/scalar/docs.py index 561f0241..b0a0b9b2 100644 --- a/ixmp4/data/db/optimization/scalar/docs.py +++ b/ixmp4/data/db/optimization/scalar/docs.py @@ -1,8 +1,12 @@ +from typing import Any + from ixmp4.data.db.docs import BaseDocsRepository, docs_model from .model import Scalar +ScalarDocs = docs_model(Scalar) + -class ScalarDocsRepository(BaseDocsRepository): - model_class = docs_model(Scalar) # ScalarDocs +class ScalarDocsRepository(BaseDocsRepository[Any]): + model_class = ScalarDocs dimension_model_class = Scalar diff --git a/ixmp4/data/db/optimization/scalar/filter.py b/ixmp4/data/db/optimization/scalar/filter.py index 34fcc39e..2303a013 100644 --- a/ixmp4/data/db/optimization/scalar/filter.py +++ b/ixmp4/data/db/optimization/scalar/filter.py @@ -1,20 +1,24 @@ from ixmp4.data.db import filters as base from ixmp4.data.db.run import Run from ixmp4.data.db.unit import Unit -from ixmp4.db import filters, utils +from ixmp4.db import Session, filters, sql, utils from .model import Scalar class OptimizationUnitFilter(base.UnitFilter, metaclass=filters.FilterMeta): - def join(self, exc, **kwawrgs): + def join( + self, exc: sql.Select[tuple[Scalar]], session: Session | None = None + ) -> sql.Select[tuple[Scalar]]: if not utils.is_joined(exc, Unit): exc = exc.join(Unit, onclause=Scalar.unit__id == Unit.id) return exc class RunFilter(base.RunFilter, metaclass=filters.FilterMeta): - def join(self, exc, **kwargs): + def join( + self, exc: sql.Select[tuple[Scalar]], session: Session | None = None + ) -> sql.Select[tuple[Scalar]]: if not utils.is_joined(exc, Run): exc = exc.join(Run, onclause=Scalar.run__id == Run.id) return exc @@ -23,5 +27,7 @@ def join(self, exc, **kwargs): class OptimizationScalarFilter( base.OptimizationScalarFilter, metaclass=filters.FilterMeta ): - def join(self, exc, session=None): + def join( + self, exc: sql.Select[tuple[Scalar]], session: Session | None = None + ) -> sql.Select[tuple[Scalar]]: return exc diff --git a/ixmp4/data/db/optimization/scalar/model.py b/ixmp4/data/db/optimization/scalar/model.py index c364f807..12e7326b 100644 --- a/ixmp4/data/db/optimization/scalar/model.py +++ b/ixmp4/data/db/optimization/scalar/model.py @@ -17,8 +17,8 @@ class Scalar(base.BaseModel): value: types.Float = db.Column(db.Float, nullable=True, unique=False) - unit: types.Mapped[Unit | None] = db.relationship() - unit__id: types.Mapped[int | None] = db.Column( + unit: types.Mapped[Unit] = db.relationship() + unit__id: types.Mapped[int] = db.Column( db.Integer, db.ForeignKey("unit.id"), index=True ) diff --git a/ixmp4/data/db/optimization/scalar/repository.py b/ixmp4/data/db/optimization/scalar/repository.py index 7d7a154c..3ad02f51 100644 --- a/ixmp4/data/db/optimization/scalar/repository.py +++ b/ixmp4/data/db/optimization/scalar/repository.py @@ -1,9 +1,17 @@ -from typing import Iterable +from collections.abc import Iterable +from typing import TYPE_CHECKING + +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + +if TYPE_CHECKING: + from ixmp4.data.backend.db import SqlAlchemyBackend import pandas as pd from ixmp4 import db from ixmp4.data.abstract import optimization as abstract +from ixmp4.data.abstract.annotations import HasUnitIdFilter from ixmp4.data.auth.decorators import guard from .. import base @@ -11,6 +19,9 @@ from .model import Scalar +class EnumerateKwargs(base.EnumerateKwargs, HasUnitIdFilter, total=False): ... + + class ScalarRepository( base.Creator[Scalar], base.Retriever[Scalar], @@ -19,9 +30,9 @@ class ScalarRepository( ): model_class = Scalar - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.docs = ScalarDocsRepository(*args, **kwargs) + def __init__(self, *args: "SqlAlchemyBackend") -> None: + super().__init__(*args) + self.docs = ScalarDocsRepository(*args) from .filter import OptimizationScalarFilter @@ -58,11 +69,9 @@ def get_by_id(self, id: int) -> Scalar: return obj @guard("edit") - def create( - self, name: str, value: float, unit_name: str, run_id: int, **kwargs - ) -> Scalar: + def create(self, name: str, value: float, unit_name: str, run_id: int) -> Scalar: return super().create( - name=name, value=value, unit_name=unit_name, run_id=run_id, **kwargs + name=name, value=value, unit_name=unit_name, run_id=run_id ) @guard("edit") @@ -80,12 +89,13 @@ def update( self.session.execute(exc) self.session.commit() - return self.get_by_id(id) + scalar: Scalar = self.get_by_id(id) + return scalar @guard("view") - def list(self, *args, **kwargs) -> Iterable[Scalar]: - return super().list(*args, **kwargs) + def list(self, **kwargs: Unpack[EnumerateKwargs]) -> Iterable[Scalar]: + return super().list(**kwargs) @guard("view") - def tabulate(self, *args, **kwargs) -> pd.DataFrame: - return super().tabulate(*args, **kwargs) + def tabulate(self, **kwargs: Unpack[EnumerateKwargs]) -> pd.DataFrame: + return super().tabulate(**kwargs) diff --git a/ixmp4/data/db/optimization/table/docs.py b/ixmp4/data/db/optimization/table/docs.py index 23322b1b..c28c3bb1 100644 --- a/ixmp4/data/db/optimization/table/docs.py +++ b/ixmp4/data/db/optimization/table/docs.py @@ -1,8 +1,12 @@ +from typing import Any + from ixmp4.data.db.docs import BaseDocsRepository, docs_model from .model import Table +TableDocs = docs_model(Table) + -class TableDocsRepository(BaseDocsRepository): - model_class = docs_model(Table) # TableDocs +class TableDocsRepository(BaseDocsRepository[Any]): + model_class = TableDocs dimension_model_class = Table diff --git a/ixmp4/data/db/optimization/table/filter.py b/ixmp4/data/db/optimization/table/filter.py index de6c069c..78152b80 100644 --- a/ixmp4/data/db/optimization/table/filter.py +++ b/ixmp4/data/db/optimization/table/filter.py @@ -1,12 +1,14 @@ from ixmp4.data.db import filters as base from ixmp4.data.db.run import Run -from ixmp4.db import filters, utils +from ixmp4.db import Session, filters, sql, utils from .model import Table class RunFilter(base.RunFilter, metaclass=filters.FilterMeta): - def join(self, exc, **kwargs): + def join( + self, exc: sql.Select[tuple[Table]], session: Session | None = None + ) -> sql.Select[tuple[Table]]: if not utils.is_joined(exc, Run): exc = exc.join(Run, onclause=Table.run__id == Run.id) return exc @@ -15,5 +17,7 @@ def join(self, exc, **kwargs): class OptimizationTableFilter( base.OptimizationTableFilter, metaclass=filters.FilterMeta ): - def join(self, exc, session=None): + def join( + self, exc: sql.Select[tuple[Table]], session: Session | None = None + ) -> sql.Select[tuple[Table]]: return exc diff --git a/ixmp4/data/db/optimization/table/model.py b/ixmp4/data/db/optimization/table/model.py index ea99cd11..1143007a 100644 --- a/ixmp4/data/db/optimization/table/model.py +++ b/ixmp4/data/db/optimization/table/model.py @@ -27,7 +27,7 @@ class Table(base.BaseModel): # TODO: should we pass self to validate_data to raise more specific errors? @validates("data") - def validate_data(self, key, data: dict[str, Any]): + def validate_data(self, key: Any, data: dict[str, Any]) -> dict[str, Any]: return utils.validate_data( host=self, data=data, diff --git a/ixmp4/data/db/optimization/table/repository.py b/ixmp4/data/db/optimization/table/repository.py index 717b6c51..27dbef2c 100644 --- a/ixmp4/data/db/optimization/table/repository.py +++ b/ixmp4/data/db/optimization/table/repository.py @@ -1,4 +1,11 @@ -from typing import Any, Iterable +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any + +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + +if TYPE_CHECKING: + from ixmp4.data.backend.db import SqlAlchemyBackend import pandas as pd @@ -22,16 +29,16 @@ class TableRepository( UsageError = OptimizationItemUsageError - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.docs = TableDocsRepository(*args, **kwargs) - self.columns = ColumnRepository(*args, **kwargs) + def __init__(self, *args: "SqlAlchemyBackend") -> None: + super().__init__(*args) + self.docs = TableDocsRepository(*args) + self.columns = ColumnRepository(*args) from .filter import OptimizationTableFilter self.filter_class = OptimizationTableFilter - def _add_column( + def _add_column( # type: ignore[no-untyped-def] self, run_id: int, table_id: int, @@ -71,11 +78,7 @@ def _add_column( **kwargs, ) - def add( - self, - run_id: int, - name: str, - ) -> Table: + def add(self, run_id: int, name: str) -> Table: table = Table(name=name, run__id=run_id) self.session.add(table) @@ -105,11 +108,7 @@ def create( name: str, constrained_to_indexsets: list[str], column_names: list[str] | None = None, - **kwargs, ) -> Table: - # Convert to list to avoid enumerate() splitting strings to letters - if isinstance(constrained_to_indexsets, str): - constrained_to_indexsets = list(constrained_to_indexsets) if column_names and len(column_names) != len(constrained_to_indexsets): raise self.UsageError( f"While processing Table {name}: \n" @@ -126,11 +125,7 @@ def create( "The given `column_names` are not unique!" ) - table = super().create( - run_id=run_id, - name=name, - **kwargs, - ) + table = super().create(run_id=run_id, name=name) for i, name in enumerate(constrained_to_indexsets): self._add_column( run_id=run_id, @@ -142,12 +137,12 @@ def create( return table @guard("view") - def list(self, *args, **kwargs) -> Iterable[Table]: - return super().list(*args, **kwargs) + def list(self, **kwargs: Unpack["base.EnumerateKwargs"]) -> Iterable[Table]: + return super().list(**kwargs) @guard("view") - def tabulate(self, *args, **kwargs) -> pd.DataFrame: - return super().tabulate(*args, **kwargs) + def tabulate(self, **kwargs: Unpack["base.EnumerateKwargs"]) -> pd.DataFrame: + return super().tabulate(**kwargs) @guard("edit") def add_data(self, table_id: int, data: dict[str, Any] | pd.DataFrame) -> None: @@ -157,7 +152,7 @@ def add_data(self, table_id: int, data: dict[str, Any] | pd.DataFrame) -> None: table.data = pd.concat([pd.DataFrame.from_dict(table.data), data]).to_dict( orient="list" - ) + ) # type: ignore[assignment] self.session.add(table) self.session.commit() diff --git a/ixmp4/data/db/optimization/utils.py b/ixmp4/data/db/optimization/utils.py index 08b2dce2..3ff3bd8c 100644 --- a/ixmp4/data/db/optimization/utils.py +++ b/ixmp4/data/db/optimization/utils.py @@ -19,7 +19,9 @@ def collect_indexsets_to_check( return collection -def validate_data(host: base.BaseModel, data: dict[str, Any], columns: list["Column"]): +def validate_data( + host: base.BaseModel, data: dict[str, Any], columns: list["Column"] +) -> dict[str, Any]: data_frame: pd.DataFrame = pd.DataFrame.from_dict(data) # TODO for all of the following, we might want to create unique exceptions # Could me make both more specific by specifiying missing/extra columns? @@ -60,4 +62,6 @@ def validate_data(host: base.BaseModel, data: dict[str, Any], columns: list["Col "and Columns it is constrained to!" ) - return data_frame.to_dict(orient="list") + # we can assume the keys are always str + dict_data: dict[str, Any] = data_frame.to_dict(orient="list") # type: ignore[assignment] + return dict_data diff --git a/ixmp4/data/db/optimization/variable/docs.py b/ixmp4/data/db/optimization/variable/docs.py index dbe34d9c..d9250dce 100644 --- a/ixmp4/data/db/optimization/variable/docs.py +++ b/ixmp4/data/db/optimization/variable/docs.py @@ -1,8 +1,12 @@ +from typing import Any + from ixmp4.data.db.docs import BaseDocsRepository, docs_model from .model import OptimizationVariable as Variable +OptimizationVariableDocs = docs_model(Variable) + -class OptimizationVariableDocsRepository(BaseDocsRepository): - model_class = docs_model(Variable) # VariableDocs +class OptimizationVariableDocsRepository(BaseDocsRepository[Any]): + model_class = OptimizationVariableDocs dimension_model_class = Variable diff --git a/ixmp4/data/db/optimization/variable/filter.py b/ixmp4/data/db/optimization/variable/filter.py index 692bebbc..7ae1ffdb 100644 --- a/ixmp4/data/db/optimization/variable/filter.py +++ b/ixmp4/data/db/optimization/variable/filter.py @@ -1,12 +1,14 @@ from ixmp4.data.db import filters as base from ixmp4.data.db.run import Run -from ixmp4.db import filters, utils +from ixmp4.db import Session, filters, sql, utils from .model import OptimizationVariable as Variable class RunFilter(base.RunFilter, metaclass=filters.FilterMeta): - def join(self, exc, **kwargs): + def join( + self, exc: sql.Select[tuple[Variable]], session: Session | None = None + ) -> sql.Select[tuple[Variable]]: if not utils.is_joined(exc, Run): exc = exc.join(Run, onclause=Variable.run__id == Run.id) return exc @@ -15,5 +17,7 @@ def join(self, exc, **kwargs): class OptimizationVariableFilter( base.OptimizationVariableFilter, metaclass=filters.FilterMeta ): - def join(self, exc, session=None): + def join( + self, exc: sql.Select[tuple[Variable]], session: Session | None = None + ) -> sql.Select[tuple[Variable]]: return exc diff --git a/ixmp4/data/db/optimization/variable/model.py b/ixmp4/data/db/optimization/variable/model.py index 6c634f31..32cfae59 100644 --- a/ixmp4/data/db/optimization/variable/model.py +++ b/ixmp4/data/db/optimization/variable/model.py @@ -26,7 +26,7 @@ class OptimizationVariable(base.BaseModel): data: types.JsonDict = db.Column(db.JsonType, nullable=False, default={}) @validates("data") - def validate_data(self, key, data: dict[str, Any]): + def validate_data(self, key: Any, data: dict[str, Any]) -> dict[str, Any]: if data == {}: return data data_to_validate = copy.deepcopy(data) diff --git a/ixmp4/data/db/optimization/variable/repository.py b/ixmp4/data/db/optimization/variable/repository.py index b8828fd3..76aa13f3 100644 --- a/ixmp4/data/db/optimization/variable/repository.py +++ b/ixmp4/data/db/optimization/variable/repository.py @@ -1,4 +1,11 @@ -from typing import Any, Iterable +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any + +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import Unpack + +if TYPE_CHECKING: + from ixmp4.data.backend.db import SqlAlchemyBackend import pandas as pd @@ -22,16 +29,16 @@ class VariableRepository( UsageError = OptimizationItemUsageError - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.docs = OptimizationVariableDocsRepository(*args, **kwargs) - self.columns = ColumnRepository(*args, **kwargs) + def __init__(self, *args: "SqlAlchemyBackend") -> None: + super().__init__(*args) + self.docs = OptimizationVariableDocsRepository(*args) + self.columns = ColumnRepository(*args) from .filter import OptimizationVariableFilter self.filter_class = OptimizationVariableFilter - def _add_column( + def _add_column( # type: ignore[no-untyped-def] self, run_id: int, variable_id: int, @@ -71,11 +78,7 @@ def _add_column( **kwargs, ) - def add( - self, - run_id: int, - name: str, - ) -> Variable: + def add(self, run_id: int, name: str) -> Variable: variable = Variable(name=name, run__id=run_id) variable.set_creation_info(auth_context=self.backend.auth_context) self.session.add(variable) @@ -84,6 +87,7 @@ def add( @guard("view") def get(self, run_id: int, name: str) -> Variable: + # TODO by avoiding super().select, don't we also miss out on filters and auth? exc = db.select(Variable).where( (Variable.name == name) & (Variable.run__id == run_id) ) @@ -106,13 +110,9 @@ def create( self, run_id: int, name: str, - constrained_to_indexsets: str | list[str] | None = None, + constrained_to_indexsets: list[str] | None = None, column_names: list[str] | None = None, - **kwargs, ) -> Variable: - # Convert to list to avoid enumerate() splitting strings to letters - if isinstance(constrained_to_indexsets, str): - constrained_to_indexsets = list(constrained_to_indexsets) if column_names: # TODO If this is removed, need to check above that constrained_to_indexsets # is not None @@ -140,11 +140,7 @@ def create( "The given `column_names` are not unique!" ) - variable = super().create( - run_id=run_id, - name=name, - **kwargs, - ) + variable = super().create(run_id=run_id, name=name) if constrained_to_indexsets: for i, name in enumerate(constrained_to_indexsets): self._add_column( @@ -157,12 +153,12 @@ def create( return variable @guard("view") - def list(self, *args, **kwargs) -> Iterable[Variable]: - return super().list(*args, **kwargs) + def list(self, **kwargs: Unpack["base.EnumerateKwargs"]) -> Iterable[Variable]: + return super().list(**kwargs) @guard("view") - def tabulate(self, *args, **kwargs) -> pd.DataFrame: - return super().tabulate(*args, **kwargs) + def tabulate(self, **kwargs: Unpack["base.EnumerateKwargs"]) -> pd.DataFrame: + return super().tabulate(**kwargs) @guard("edit") def add_data(self, variable_id: int, data: dict[str, Any] | pd.DataFrame) -> None: @@ -186,7 +182,7 @@ def add_data(self, variable_id: int, data: dict[str, Any] | pd.DataFrame) -> Non existing_data.set_index(index_list, inplace=True) variable.data = ( data.set_index(index_list).combine_first(existing_data).reset_index() - ).to_dict(orient="list") + ).to_dict(orient="list") # type: ignore[assignment] self.session.commit() diff --git a/ixmp4/data/db/region/__init__.py b/ixmp4/data/db/region/__init__.py index 1ae001e3..60f4514e 100644 --- a/ixmp4/data/db/region/__init__.py +++ b/ixmp4/data/db/region/__init__.py @@ -1,4 +1,3 @@ -# flake8: noqa from .docs import RegionDocsRepository from .model import Region from .repository import RegionRepository diff --git a/ixmp4/data/db/region/docs.py b/ixmp4/data/db/region/docs.py index a78b7d31..cacfacdb 100644 --- a/ixmp4/data/db/region/docs.py +++ b/ixmp4/data/db/region/docs.py @@ -1,17 +1,18 @@ -from ixmp4.data import abstract +from typing import Any + from ixmp4.data.auth.decorators import guard from .. import base -from ..docs import BaseDocsRepository, docs_model +from ..docs import AbstractDocs, BaseDocsRepository, docs_model from .model import Region RegionDocs = docs_model(Region) -class RegionDocsRepository(BaseDocsRepository, base.BaseRepository): +class RegionDocsRepository(BaseDocsRepository[Any], base.BaseRepository[Region]): model_class = RegionDocs dimension_model_class = Region @guard("view") - def list(self, *, dimension_id: int | None = None) -> list[abstract.Docs]: + def list(self, *, dimension_id: int | None = None) -> list[AbstractDocs]: return super().list(dimension_id=dimension_id) diff --git a/ixmp4/data/db/region/filter.py b/ixmp4/data/db/region/filter.py index 00e5add6..9905c573 100644 --- a/ixmp4/data/db/region/filter.py +++ b/ixmp4/data/db/region/filter.py @@ -1,13 +1,17 @@ +from typing import Any + from ixmp4 import db from ixmp4.data.db import filters as base from ixmp4.data.db.iamc.timeseries import TimeSeries -from ixmp4.db import filters, utils +from ixmp4.db import Session, filters, typing_column, utils from . import Region class BaseIamcFilter(filters.BaseFilter, metaclass=filters.FilterMeta): - def join_datapoints(self, exc: db.sql.Select, session=None): + def join_datapoints( + self, exc: db.sql.Select[tuple[Region]], session: Session | None = None + ) -> db.sql.Select[tuple[Region]]: if not utils.is_joined(exc, TimeSeries): exc = exc.join(TimeSeries, onclause=TimeSeries.region__id == Region.id) return exc @@ -16,25 +20,35 @@ def join_datapoints(self, exc: db.sql.Select, session=None): class SimpleIamcRegionFilter( base.RegionFilter, BaseIamcFilter, metaclass=filters.FilterMeta ): - def join(self, exc, session=None): + def join( + self, exc: db.sql.Select[tuple[Region]], session: Session | None = None + ) -> db.sql.Select[tuple[Region]]: return super().join_datapoints(exc, session) class IamcRegionFilter(base.RegionFilter, BaseIamcFilter, metaclass=filters.FilterMeta): - variable: base.VariableFilter - unit: base.UnitFilter + variable: base.VariableFilter | None = filters.Field(None) + unit: base.UnitFilter | None = filters.Field(None) run: base.RunFilter = filters.Field( default=base.RunFilter(id=None, version=None, is_default=True) ) - def join(self, exc, session=None): + def join( + self, exc: db.sql.Select[tuple[Region]], session: Session | None = None + ) -> db.sql.Select[tuple[Region]]: return super().join_datapoints(exc, session) class RegionFilter(base.RegionFilter, BaseIamcFilter, metaclass=filters.FilterMeta): - iamc: IamcRegionFilter | filters.Boolean | None + iamc: IamcRegionFilter | filters.Boolean | None = filters.Field(None) - def filter_iamc(self, exc, c, v, session=None): + def filter_iamc( + self, + exc: db.sql.Select[tuple[Region]], + c: typing_column[Any], # Any since it is unused + v: bool | None, + session: Session | None = None, + ) -> db.sql.Select[tuple[Region]]: if v is None: return exc @@ -45,5 +59,7 @@ def filter_iamc(self, exc, c, v, session=None): exc = exc.where(~Region.id.in_(ids)) return exc - def join(self, exc, **kwargs): + def join( + self, exc: db.sql.Select[tuple[Region]], session: Session | None = None + ) -> db.sql.Select[tuple[Region]]: return exc diff --git a/ixmp4/data/db/region/repository.py b/ixmp4/data/db/region/repository.py index cf61b481..37302e8e 100644 --- a/ixmp4/data/db/region/repository.py +++ b/ixmp4/data/db/region/repository.py @@ -1,14 +1,36 @@ +from typing import TYPE_CHECKING + import pandas as pd from sqlalchemy.exc import NoResultFound +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import TypedDict, Unpack + +if TYPE_CHECKING: + from ixmp4.data.backend.db import SqlAlchemyBackend + from ixmp4.data import abstract from ixmp4.data.auth.decorators import guard +from ixmp4.db.filters import BaseFilter from .. import base from .docs import RegionDocsRepository from .model import Region +class EnumerateKwargs( + abstract.annotations.HasNameFilter, + abstract.annotations.HasHierarchyFilter, + total=False, +): + _filter: BaseFilter + + +class CreateKwargs(TypedDict, total=False): + name: str + hierarchy: str + + class RegionRepository( base.Creator[Region], base.Deleter[Region], @@ -18,13 +40,13 @@ class RegionRepository( ): model_class = Region - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, *args: "SqlAlchemyBackend") -> None: + super().__init__(*args) from .filter import RegionFilter self.filter_class = RegionFilter - self.docs = RegionDocsRepository(*args, **kwargs) + self.docs = RegionDocsRepository(*args) def add(self, name: str, hierarchy: str) -> Region: region = Region(name=name, hierarchy=hierarchy) @@ -32,25 +54,26 @@ def add(self, name: str, hierarchy: str) -> Region: return region @guard("manage") - def create(self, *args, **kwargs) -> Region: + def create(self, *args: str, **kwargs: Unpack[CreateKwargs]) -> Region: return super().create(*args, **kwargs) @guard("manage") - def delete(self, *args, **kwargs): - return super().delete(*args, **kwargs) + def delete(self, *args: int) -> None: + super().delete(*args) @guard("view") def get(self, name: str) -> Region: exc = self.select().where(Region.name == name) try: - return self.session.execute(exc).scalar_one() + region: Region = self.session.execute(exc).scalar_one() + return region except NoResultFound: raise Region.NotFound @guard("view") - def list(self, *args, **kwargs) -> list[Region]: - return super().list(*args, **kwargs) + def list(self, **kwargs: Unpack[EnumerateKwargs]) -> list[Region]: + return super().list(**kwargs) @guard("view") - def tabulate(self, *args, **kwargs) -> pd.DataFrame: - return super().tabulate(*args, **kwargs) + def tabulate(self, **kwargs: Unpack[EnumerateKwargs]) -> pd.DataFrame: + return super().tabulate(**kwargs) diff --git a/ixmp4/data/db/run/filter.py b/ixmp4/data/db/run/filter.py index 0af282e1..d1304a49 100644 --- a/ixmp4/data/db/run/filter.py +++ b/ixmp4/data/db/run/filter.py @@ -1,8 +1,10 @@ +from typing import Any + from ixmp4 import db from ixmp4.data.db import filters as base from ixmp4.data.db.iamc.timeseries import TimeSeries from ixmp4.data.db.run.model import Run -from ixmp4.db import filters, utils +from ixmp4.db import filters, typing_column, utils class IamcRunFilter(filters.BaseFilter, metaclass=filters.FilterMeta): @@ -10,7 +12,9 @@ class IamcRunFilter(filters.BaseFilter, metaclass=filters.FilterMeta): variable: base.VariableFilter unit: base.UnitFilter - def join(self, exc, session=None): + def join( + self, exc: db.sql.Select[tuple[Run]], session: db.Session | None = None + ) -> db.sql.Select[tuple[Run]]: if not utils.is_joined(exc, TimeSeries): exc = exc.join(TimeSeries, onclause=TimeSeries.run__id == Run.id) return exc @@ -19,12 +23,20 @@ def join(self, exc, session=None): class RunFilter(base.RunFilter, metaclass=filters.FilterMeta): iamc: IamcRunFilter | filters.Boolean | None = None - def join_datapoints(self, exc: db.sql.Select, session=None): + def join_datapoints( + self, exc: db.sql.Select[tuple[Run]], session: db.Session | None = None + ) -> db.sql.Select[tuple[Run]]: if not utils.is_joined(exc, TimeSeries): exc = exc.join(TimeSeries, onclause=TimeSeries.run__id == Run.id) return exc - def filter_iamc(self, exc, c, v, session=None): + def filter_iamc( + self, + exc: db.sql.Select[tuple[Run]], + c: typing_column[Any], # Any since it is unused + v: bool | None, + session: db.Session | None = None, + ) -> db.sql.Select[tuple[Run]]: if v is None: return exc @@ -35,5 +47,7 @@ def filter_iamc(self, exc, c, v, session=None): exc = exc.where(~Run.id.in_(ids)) return exc - def join(self, exc, **kwargs): + def join( + self, exc: db.sql.Select[tuple[Run]], session: db.Session | None = None + ) -> db.sql.Select[tuple[Run]]: return exc diff --git a/ixmp4/data/db/run/repository.py b/ixmp4/data/db/run/repository.py index 38b4730e..81370410 100644 --- a/ixmp4/data/db/run/repository.py +++ b/ixmp4/data/db/run/repository.py @@ -1,6 +1,14 @@ +from typing import TYPE_CHECKING + import pandas as pd from sqlalchemy.exc import NoResultFound +# TODO Adapt import when dropping Python 3.11 +from typing_extensions import TypedDict, Unpack + +if TYPE_CHECKING: + from ixmp4.data.backend.db import SqlAlchemyBackend + from ixmp4 import db from ixmp4.core.exceptions import Forbidden, IxmpError, NoDefaultRunVersion from ixmp4.data import abstract @@ -13,6 +21,10 @@ from .model import Run +class CreateKwargs(TypedDict, total=False): + scenario_name: str + + class RunRepository( base.Creator[Run], base.Retriever[Run], @@ -24,16 +36,16 @@ class RunRepository( models: ModelRepository scenarios: ScenarioRepository - def __init__(self, *args, **kwargs) -> None: - self.models = ModelRepository(*args, **kwargs) - self.scenarios = ScenarioRepository(*args, **kwargs) + def __init__(self, *args: "SqlAlchemyBackend") -> None: + self.models = ModelRepository(*args) + self.scenarios = ScenarioRepository(*args) from .filter import RunFilter self.filter_class = RunFilter - super().__init__(*args, **kwargs) + super().__init__(*args) - def join_auth(self, exc: db.sql.Select): + def join_auth(self, exc: db.sql.Select[tuple[Run]]) -> db.sql.Select[tuple[Run]]: if not utils.is_joined(exc, Model): exc = exc.join(Model, Run.model) return exc @@ -41,16 +53,16 @@ def join_auth(self, exc: db.sql.Select): def add(self, model_name: str, scenario_name: str) -> Run: # Get or create model try: - exc: db.sql.Select = self.models.select(name=model_name) - model = self.session.execute(exc).scalar_one() + exc_model = self.models.select(name=model_name) + model = self.session.execute(exc_model).scalar_one() except NoResultFound: model = Model(name=model_name) self.session.add(model) # Get or create scenario try: - exc = self.scenarios.select(name=scenario_name) - scenario: Scenario = self.session.execute(exc).scalar_one() + exc_scenario = self.scenarios.select(name=scenario_name) + scenario = self.session.execute(exc_scenario).scalar_one() except NoResultFound: scenario = Scenario(name=scenario_name) self.session.add(scenario) @@ -72,19 +84,16 @@ def add(self, model_name: str, scenario_name: str) -> Run: return run @guard("edit") - def create(self, model_name: str, *args, **kwargs) -> Run: + def create( + self, model_name: str, *args: str, **kwargs: Unpack[CreateKwargs] + ) -> Run: if self.backend.auth_context is not None: if not self.backend.auth_context.check_access("edit", model_name): raise Forbidden(f"Access to model '{model_name}' denied.") return super().create(model_name, *args, **kwargs) @guard("view") - def get( - self, - model_name: str, - scenario_name: str, - version: int, - ) -> Run: + def get(self, model_name: str, scenario_name: str, version: int) -> Run: exc = self.select( model={"name": model_name}, scenario={"name": scenario_name}, @@ -93,7 +102,9 @@ def get( ) try: - return self.session.execute(exc).scalar_one() + # TODO clean up unnecessary cast such as this + run: Run = self.session.execute(exc).scalar_one() + return run except NoResultFound: raise Run.NotFound( model=model_name, @@ -111,28 +122,25 @@ def get_by_id(self, id: int) -> Run: return obj @guard("view") - def get_default_version( - self, - model_name: str, - scenario_name: str, - ) -> Run: + def get_default_version(self, model_name: str, scenario_name: str) -> Run: exc = self.select( model={"name": model_name}, scenario={"name": scenario_name}, ) try: - return self.session.execute(exc).scalar_one() + run: Run = self.session.execute(exc).scalar_one() + return run except NoResultFound: raise NoDefaultRunVersion @guard("view") - def tabulate(self, *args, **kwargs) -> pd.DataFrame: - return super().tabulate(*args, **kwargs) + def tabulate(self, **kwargs: Unpack[abstract.run.EnumerateKwargs]) -> pd.DataFrame: + return super().tabulate(**kwargs) @guard("view") - def list(self, *args, **kwargs) -> list[Run]: - return super().list(*args, **kwargs) + def list(self, **kwargs: Unpack[abstract.run.EnumerateKwargs]) -> list[Run]: + return super().list(**kwargs) @guard("edit") def set_as_default_version(self, id: int) -> None: diff --git a/ixmp4/data/db/scenario/__init__.py b/ixmp4/data/db/scenario/__init__.py index 342b9321..772b9db2 100644 --- a/ixmp4/data/db/scenario/__init__.py +++ b/ixmp4/data/db/scenario/__init__.py @@ -1,4 +1,3 @@ -# flake8: noqa from .docs import ScenarioDocsRepository from .model import Scenario from .repository import ScenarioRepository diff --git a/ixmp4/data/db/scenario/docs.py b/ixmp4/data/db/scenario/docs.py index e1a885eb..e5cd9d95 100644 --- a/ixmp4/data/db/scenario/docs.py +++ b/ixmp4/data/db/scenario/docs.py @@ -1,7 +1,11 @@ +from typing import Any + from ..docs import BaseDocsRepository, docs_model from .model import Scenario +ScenarioDocs = docs_model(Scenario) + -class ScenarioDocsRepository(BaseDocsRepository): - model_class = docs_model(Scenario) # ScenarioDocs +class ScenarioDocsRepository(BaseDocsRepository[Any]): + model_class = ScenarioDocs dimension_model_class = Scenario diff --git a/ixmp4/data/db/scenario/filter.py b/ixmp4/data/db/scenario/filter.py index c77f8656..9c967366 100644 --- a/ixmp4/data/db/scenario/filter.py +++ b/ixmp4/data/db/scenario/filter.py @@ -1,14 +1,18 @@ +from typing import Any + from ixmp4 import db from ixmp4.data.db import filters as base from ixmp4.data.db.iamc.timeseries import TimeSeries from ixmp4.data.db.run.model import Run -from ixmp4.db import filters, utils +from ixmp4.db import filters, typing_column, utils from .model import Scenario class BaseIamcFilter(filters.BaseFilter, metaclass=filters.FilterMeta): - def join_datapoints(self, exc: db.sql.Select, session=None): + def join_datapoints( + self, exc: db.sql.Select[tuple[Scenario]], session: db.Session | None = None + ) -> db.sql.Select[tuple[Scenario]]: if not utils.is_joined(exc, Run): exc = exc.join(Run, onclause=Run.scenario__id == Scenario.id) @@ -19,32 +23,43 @@ def join_datapoints(self, exc: db.sql.Select, session=None): class RunFilter(base.RunFilter, metaclass=filters.FilterMeta): - def join(self, exc, **kwargs): + def join( + self, exc: db.sql.Select[tuple[Scenario]], session: db.Session | None = None + ) -> db.sql.Select[tuple[Scenario]]: if not utils.is_joined(exc, Run): + # TODO should this really be Run.model? exc = exc.join(Run, Run.model) return exc - scenario: base.ScenarioFilter = filters.Field(default=None, exclude=True) + scenario: base.ScenarioFilter | None = filters.Field(default=None, exclude=True) class IamcScenarioFilter( base.ScenarioFilter, BaseIamcFilter, metaclass=filters.FilterMeta ): - region: base.RegionFilter - variable: base.VariableFilter - unit: base.UnitFilter + region: base.RegionFilter | None = filters.Field(None) + variable: base.VariableFilter | None = filters.Field(None) + unit: base.UnitFilter | None = filters.Field(None) run: RunFilter = filters.Field( default=RunFilter(id=None, version=None, is_default=True) ) - def join(self, exc, session=None): + def join( + self, exc: db.sql.Select[tuple[Scenario]], session: db.Session | None = None + ) -> db.sql.Select[tuple[Scenario]]: return super().join_datapoints(exc, session) class ScenarioFilter(base.ScenarioFilter, BaseIamcFilter, metaclass=filters.FilterMeta): iamc: IamcScenarioFilter | filters.Boolean - def filter_iamc(self, exc, c, v, session=None): + def filter_iamc( + self, + exc: db.sql.Select[tuple[Scenario]], + c: typing_column[Any], # Any since it is unused + v: bool | None, + session: db.Session | None = None, + ) -> db.sql.Select[tuple[Scenario]]: if v is None: return exc @@ -55,5 +70,7 @@ def filter_iamc(self, exc, c, v, session=None): exc = exc.where(~Scenario.id.in_(ids)) return exc - def join(self, exc, **kwargs): + def join( + self, exc: db.sql.Select[tuple[Scenario]], session: db.Session | None = None + ) -> db.sql.Select[tuple[Scenario]]: return exc diff --git a/ixmp4/data/db/scenario/repository.py b/ixmp4/data/db/scenario/repository.py index af45d3a0..e77858ce 100644 --- a/ixmp4/data/db/scenario/repository.py +++ b/ixmp4/data/db/scenario/repository.py @@ -1,15 +1,32 @@ +from typing import TYPE_CHECKING + import pandas as pd +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import TypedDict, Unpack + +if TYPE_CHECKING: + from ixmp4.data.backend.db import SqlAlchemyBackend + from ixmp4 import db from ixmp4.data import abstract from ixmp4.data.auth.decorators import guard from ixmp4.data.db.model.model import Model +from ixmp4.db.filters import BaseFilter from .. import base from .docs import ScenarioDocsRepository from .model import Scenario +class EnumerateKwargs(abstract.annotations.HasNameFilter, total=False): + _filter: BaseFilter + + +class CreateKwargs(TypedDict, total=False): + name: str + + class ScenarioRepository( base.Creator[Scenario], base.Retriever[Scenario], @@ -18,15 +35,17 @@ class ScenarioRepository( ): model_class = Scenario - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.docs = ScenarioDocsRepository(*args, **kwargs) + def __init__(self, *args: "SqlAlchemyBackend") -> None: + super().__init__(*args) + self.docs = ScenarioDocsRepository(*args) from .filter import ScenarioFilter self.filter_class = ScenarioFilter - def join_auth(self, exc: db.sql.Select) -> db.sql.Select: + def join_auth( + self, exc: db.sql.Select[tuple[Scenario]] + ) -> db.sql.Select[tuple[Scenario]]: from ixmp4.data.db.run.model import Run if not db.utils.is_joined(exc, Run): @@ -50,13 +69,13 @@ def get(self, name: str) -> Scenario: raise Scenario.NotFound @guard("edit") - def create(self, *args, **kwargs) -> Scenario: + def create(self, *args: str, **kwargs: Unpack[CreateKwargs]) -> Scenario: return super().create(*args, **kwargs) @guard("view") - def list(self, *args, **kwargs) -> list[Scenario]: - return super().list(*args, **kwargs) + def list(self, **kwargs: Unpack[EnumerateKwargs]) -> list[Scenario]: + return super().list(**kwargs) @guard("view") - def tabulate(self, *args, **kwargs) -> pd.DataFrame: - return super().tabulate(*args, **kwargs) + def tabulate(self, **kwargs: Unpack[EnumerateKwargs]) -> pd.DataFrame: + return super().tabulate(**kwargs) diff --git a/ixmp4/data/db/timeseries.py b/ixmp4/data/db/timeseries.py index 5bcf3600..7f62dfca 100644 --- a/ixmp4/data/db/timeseries.py +++ b/ixmp4/data/db/timeseries.py @@ -1,14 +1,19 @@ -from typing import Any, ClassVar, Generic, Mapping, TypeVar +from collections.abc import Mapping +from typing import Any, ClassVar, Generic, TypeVar import pandas as pd from sqlalchemy.ext.declarative import AbstractConcreteBase from sqlalchemy.orm.decl_api import declared_attr from sqlalchemy.orm.exc import NoResultFound +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import TypedDict, Unpack + from ixmp4 import db from ixmp4.data import abstract from ixmp4.data.db.model import Model from ixmp4.data.db.run import Run +from ixmp4.db.filters import BaseFilter from ..auth.decorators import guard from . import base @@ -20,10 +25,10 @@ class TimeSeries(AbstractConcreteBase, base.BaseModel): DeletionPrevented: ClassVar = abstract.TimeSeries.DeletionPrevented __abstract__ = True - parameters: dict = {} + parameters: dict[str, Any] = {} @declared_attr - def run__id(cls): + def run__id(cls) -> db.MappedColumn[int]: return db.Column( "run__id", db.Integer, @@ -33,17 +38,39 @@ def run__id(cls): ) @declared_attr - def run(cls): - return db.relationship("Run", backref="time_series", foreign_keys=[cls.run__id]) + def run(cls) -> db.Relationship["Run"]: + # Mypy doesn't recognize cls.run__id as Mapped[int], even when type hinting as + # such directly + return db.relationship("Run", backref="time_series", foreign_keys=[cls.run__id]) # type: ignore[list-item] @property def run_id(self) -> int: - return self.run__id + run_id: int = self.run__id + return run_id ModelType = TypeVar("ModelType", bound=TimeSeries) +class GetKwargs(TypedDict): + _filter: BaseFilter + + +class SelectKwargs(TypedDict, total=False): + _filter: BaseFilter + run: dict[str, int] + + +class EnumerateKwargs(abstract.annotations.HasNameFilter, total=False): + _filter: BaseFilter + join_parameters: bool | None + + +class CreateKwargs(TypedDict): + run__id: int + parameters: Mapping[str, Any] + + class TimeSeriesRepository( base.Creator[ModelType], base.Retriever[ModelType], @@ -51,7 +78,9 @@ class TimeSeriesRepository( base.BulkUpserter[ModelType], Generic[ModelType], ): - def join_auth(self, exc: db.sql.Select) -> db.sql.Select: + def join_auth( + self, exc: db.sql.Select[tuple[ModelType]] + ) -> db.sql.Select[tuple[ModelType]]: if not db.utils.is_joined(exc, Run): exc = exc.join(Run, onclause=Run.id == self.model_class.run__id) if not db.utils.is_joined(exc, Model): @@ -59,17 +88,17 @@ def join_auth(self, exc: db.sql.Select) -> db.sql.Select: return exc - def add(self, run_id: int, parameters: Mapping) -> ModelType: + def add(self, run_id: int, parameters: Mapping[str, Any]) -> ModelType: time_series = self.model_class(run_id=run_id, **parameters) self.session.add(time_series) return time_series @guard("edit") - def create(self, *args, **kwargs) -> ModelType: - return super().create(*args, **kwargs) + def create(self, **kwargs: Unpack[CreateKwargs]) -> ModelType: + return super().create(**kwargs) @guard("view") - def get(self, run_id: int, **kwargs: Any) -> ModelType: + def get(self, run_id: int, **kwargs: Unpack[GetKwargs]) -> ModelType: exc = self.select(run={"id": run_id}, **kwargs) try: @@ -86,16 +115,16 @@ def get_by_id(self, id: int) -> ModelType: return obj - def select_joined_parameters(self) -> db.sql.Select: + def select_joined_parameters(self) -> db.sql.Select[tuple[Any, ...]]: raise NotImplementedError def select( self, *, - _exc: db.sql.Select | None = None, + _exc: db.sql.Select[tuple[ModelType]] | None = None, join_parameters: bool | None = False, - **kwargs, - ) -> db.sql.Select: + **kwargs: Unpack[SelectKwargs], + ) -> db.sql.Select[tuple[ModelType]]: if _exc is not None: exc = _exc elif join_parameters: @@ -106,12 +135,12 @@ def select( return super().select(_exc=exc, **kwargs) @guard("view") - def list(self, *args, **kwargs) -> list[ModelType]: - return super().list(*args, **kwargs) + def list(self, **kwargs: Unpack[EnumerateKwargs]) -> list[ModelType]: + return super().list(**kwargs) @guard("view") - def tabulate(self, *args, **kwargs) -> pd.DataFrame: - return super().tabulate(*args, **kwargs) + def tabulate(self, **kwargs: Unpack[EnumerateKwargs]) -> pd.DataFrame: + return super().tabulate(**kwargs) @guard("edit") def bulk_upsert(self, df: pd.DataFrame) -> None: diff --git a/ixmp4/data/db/unit/__init__.py b/ixmp4/data/db/unit/__init__.py index da2377c5..afbc0701 100644 --- a/ixmp4/data/db/unit/__init__.py +++ b/ixmp4/data/db/unit/__init__.py @@ -1,5 +1,3 @@ -# flake8: noqa - from .docs import UnitDocs, UnitDocsRepository from .model import Unit from .repository import UnitRepository diff --git a/ixmp4/data/db/unit/docs.py b/ixmp4/data/db/unit/docs.py index 2fcdcf87..084af25d 100644 --- a/ixmp4/data/db/unit/docs.py +++ b/ixmp4/data/db/unit/docs.py @@ -1,20 +1,21 @@ -from ixmp4.data import abstract +from typing import Any + from ixmp4.data.auth.decorators import guard from .. import base -from ..docs import BaseDocsRepository, docs_model +from ..docs import AbstractDocs, BaseDocsRepository, docs_model from .model import Unit UnitDocs = docs_model(Unit) class UnitDocsRepository( - BaseDocsRepository, - base.BaseRepository, + BaseDocsRepository[Any], + base.BaseRepository[Unit], ): model_class = UnitDocs dimension_model_class = Unit @guard("view") - def list(self, *, dimension_id: int | None = None) -> list[abstract.Docs]: + def list(self, *, dimension_id: int | None = None) -> list[AbstractDocs]: return super().list(dimension_id=dimension_id) diff --git a/ixmp4/data/db/unit/filter.py b/ixmp4/data/db/unit/filter.py index 3d2927e7..0189f209 100644 --- a/ixmp4/data/db/unit/filter.py +++ b/ixmp4/data/db/unit/filter.py @@ -1,14 +1,18 @@ +from typing import Any + from ixmp4 import db from ixmp4.data.db import filters as base from ixmp4.data.db.iamc.measurand import Measurand from ixmp4.data.db.iamc.timeseries import TimeSeries -from ixmp4.db import filters, utils +from ixmp4.db import filters, typing_column, utils from . import Unit class BaseIamcFilter(filters.BaseFilter, metaclass=filters.FilterMeta): - def join_datapoints(self, exc, session=None): + def join_datapoints( + self, exc: db.sql.Select[tuple[Unit]], session: db.Session | None = None + ) -> db.sql.Select[tuple[Unit]]: if not utils.is_joined(exc, Measurand): exc = exc.join(Measurand, Measurand.unit__id == Unit.id) @@ -22,25 +26,35 @@ def join_datapoints(self, exc, session=None): class SimpleIamcUnitFilter( base.UnitFilter, BaseIamcFilter, metaclass=filters.FilterMeta ): - def join(self, exc, session=None): + def join( + self, exc: db.sql.Select[tuple[Unit]], session: db.Session | None = None + ) -> db.sql.Select[tuple[Unit]]: return super().join_datapoints(exc, session) class IamcUnitFilter(base.UnitFilter, BaseIamcFilter, metaclass=filters.FilterMeta): - variable: base.VariableFilter - region: base.RegionFilter + variable: base.VariableFilter | None = filters.Field(None) + region: base.RegionFilter | None = filters.Field(None) run: base.RunFilter = filters.Field( default=base.RunFilter(id=None, version=None, is_default=True) ) - def join(self, exc, session=None): + def join( + self, exc: db.sql.Select[tuple[Unit]], session: db.Session | None = None + ) -> db.sql.Select[tuple[Unit]]: return super().join_datapoints(exc, session) class UnitFilter(base.UnitFilter, BaseIamcFilter, metaclass=filters.FilterMeta): iamc: IamcUnitFilter | filters.Boolean - def filter_iamc(self, exc, c, v, session=None): + def filter_iamc( + self, + exc: db.sql.Select[tuple[Unit]], + c: typing_column[Any], # Any since it is unused + v: bool | None, + session: db.Session | None = None, + ) -> db.sql.Select[tuple[Unit]]: if v is None: return exc @@ -51,5 +65,7 @@ def filter_iamc(self, exc, c, v, session=None): exc = exc.where(~Unit.id.in_(ids)) return exc - def join(self, exc, **kwargs): + def join( + self, exc: db.sql.Select[tuple[Unit]], session: db.Session | None = None + ) -> db.sql.Select[tuple[Unit]]: return exc diff --git a/ixmp4/data/db/unit/repository.py b/ixmp4/data/db/unit/repository.py index 0c19d655..7f02c008 100644 --- a/ixmp4/data/db/unit/repository.py +++ b/ixmp4/data/db/unit/repository.py @@ -1,15 +1,32 @@ +from typing import TYPE_CHECKING + import pandas as pd from sqlalchemy.exc import NoResultFound +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import TypedDict, Unpack + +if TYPE_CHECKING: + from ixmp4.data.backend.db import SqlAlchemyBackend + from ixmp4 import db from ixmp4.data import abstract from ixmp4.data.auth.decorators import guard +from ixmp4.db.filters import BaseFilter from .. import base from .docs import UnitDocsRepository from .model import Unit +class EnumerateKwargs(abstract.annotations.HasNameFilter, total=False): + _filter: BaseFilter + + +class CreateKwargs(TypedDict, total=False): + name: str + + class UnitRepository( base.Creator[Unit], base.Deleter[Unit], @@ -19,13 +36,13 @@ class UnitRepository( ): model_class = Unit - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self, *args: "SqlAlchemyBackend") -> None: + super().__init__(*args) from .filter import UnitFilter self.filter_class = UnitFilter - self.docs = UnitDocsRepository(*args, **kwargs) + self.docs = UnitDocsRepository(*args) def add(self, name: str) -> Unit: unit = Unit(name=name) @@ -33,16 +50,16 @@ def add(self, name: str) -> Unit: return unit @guard("manage") - def create(self, *args, **kwargs) -> Unit: + def create(self, /, *args: str, **kwargs: Unpack[CreateKwargs]) -> Unit: return super().create(*args, **kwargs) @guard("manage") - def delete(self, *args, **kwargs): - return super().delete(*args, **kwargs) + def delete(self, /, *args: Unpack[tuple[int]]) -> None: + return super().delete(*args) @guard("view") def get(self, name: str) -> Unit: - exc: db.sql.Select = db.select(Unit).where(Unit.name == name) + exc = db.select(Unit).where(Unit.name == name) try: return self.session.execute(exc).scalar_one() except NoResultFound: @@ -58,9 +75,9 @@ def get_by_id(self, id: int) -> Unit: return obj @guard("view") - def list(self, *args, **kwargs) -> list[Unit]: - return super().list(*args, **kwargs) + def list(self, /, **kwargs: Unpack[EnumerateKwargs]) -> list[Unit]: + return super().list(**kwargs) @guard("view") - def tabulate(self, *args, **kwargs) -> pd.DataFrame: - return super().tabulate(*args, **kwargs) + def tabulate(self, **kwargs: Unpack[EnumerateKwargs]) -> pd.DataFrame: + return super().tabulate(**kwargs) diff --git a/ixmp4/data/db/utils.py b/ixmp4/data/db/utils.py index 662e67ab..92527224 100644 --- a/ixmp4/data/db/utils.py +++ b/ixmp4/data/db/utils.py @@ -1,3 +1,6 @@ +from typing import Any + +import numpy as np import pandas as pd @@ -6,7 +9,7 @@ def map_existing( existing_df: pd.DataFrame, join_on: tuple[str, str], map: tuple[str, str], -): +) -> tuple[pd.DataFrame, np.ndarray[Any, np.dtype[np.str_]]]: _, join_to = join_on _, map_to = map existing_df = existing_df.rename(columns=dict([join_on, map]))[[join_to, map_to]] diff --git a/ixmp4/data/generator.py b/ixmp4/data/generator.py index b118e335..33adb668 100644 --- a/ixmp4/data/generator.py +++ b/ixmp4/data/generator.py @@ -1,8 +1,9 @@ import random import sys +from collections.abc import Generator, Iterator from datetime import datetime, timedelta from itertools import cycle -from typing import Generator +from typing import Any import numpy as np import pandas as pd @@ -37,11 +38,11 @@ def __init__( self.num_units = num_units self.num_datapoints = num_datapoints - def yield_model_names(self): + def yield_model_names(self) -> Generator[str, Any, None]: for i in range(self.num_models): yield f"Model {i}" - def yield_runs(self, model_names: Generator[str, None, None]): + def yield_runs(self, model_names: Iterator[str]) -> Generator[Run, Any, None]: scen_per_model = self.num_runs // self.num_models if scen_per_model == 0: scen_per_model = 1 @@ -58,7 +59,7 @@ def yield_runs(self, model_names: Generator[str, None, None]): model_name = next(model_names) scenario_index = 0 - def yield_regions(self): + def yield_regions(self) -> Generator[Region, Any, None]: for i in range(self.num_regions): name = f"Region {i}" try: @@ -66,7 +67,7 @@ def yield_regions(self): except Region.NotUnique: yield self.platform.regions.get(name) - def yield_units(self): + def yield_units(self) -> Generator[Unit, Any, None]: for i in range(self.num_units): name = f"Unit {i}" try: @@ -74,17 +75,17 @@ def yield_units(self): except Unit.NotUnique: yield self.platform.units.get(name) - def yield_variable_names(self): + def yield_variable_names(self) -> Generator[str, Any, None]: for i in range(self.num_variables): yield f"Variable {i}" def yield_datapoints( self, - runs: Generator[Run, None, None], - variable_names: Generator[str, None, None], - units: Generator[Unit, None, None], - regions: Generator[Region, None, None], - ): + runs: Iterator[Run], + variable_names: Iterator[str], + units: Iterator[Unit], + regions: Iterator[Region], + ) -> Generator[pd.DataFrame, Any, None]: dp_count = 0 for run in runs: region_name = next(regions).name @@ -110,7 +111,9 @@ def yield_datapoints( if self.num_datapoints == dp_count: break - def get_datapoints(self, type: DataPoint.Type, max: int = sys.maxsize): + def get_datapoints( + self, type: DataPoint.Type, max: int = sys.maxsize + ) -> pd.DataFrame: df = pd.DataFrame( columns=[ "region", @@ -153,7 +156,7 @@ def get_datapoints(self, type: DataPoint.Type, max: int = sys.maxsize): df["value"] = values return df - def generate(self): + def generate(self) -> None: model_names = cycle([n for n in self.yield_model_names()]) runs = cycle([r for r in self.yield_runs(model_names=model_names)]) regions = cycle([r for r in self.yield_regions()]) diff --git a/ixmp4/data/types.py b/ixmp4/data/types.py index 55549c4d..449657f9 100644 --- a/ixmp4/data/types.py +++ b/ixmp4/data/types.py @@ -10,6 +10,7 @@ Float = Mapped[float] IndexSetId = Mapped[db.IndexSetIdType] Integer = Mapped[int] +# NOTE only one type will ever be in list, but not sure if we can map a union of lists OptimizationDataList = Mapped[list[float | int | str]] JsonDict = Mapped[dict[str, Any]] OptimizationDataType = Mapped[Literal["float", "int", "str"] | None] diff --git a/ixmp4/db/__init__.py b/ixmp4/db/__init__.py index 90a3d67d..7834e9a7 100644 --- a/ixmp4/db/__init__.py +++ b/ixmp4/db/__init__.py @@ -36,12 +36,17 @@ from typing import Annotated from sqlalchemy import ( + BinaryExpression, + BindParameter, + ColumnExpressionArgument, ForeignKey, Index, + Label, Sequence, UniqueConstraint, delete, exists, + false, func, insert, or_, @@ -49,10 +54,13 @@ sql, update, ) +from sqlalchemy import Column as typing_column from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.exc import IntegrityError, MultipleResultsFound from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import ( + Bundle, + MappedColumn, Relationship, Session, aliased, @@ -71,7 +79,9 @@ Column(Integer, ForeignKey("optimization_indexset.id"), nullable=False, index=True), ] JsonType = JSON() -JsonType = JsonType.with_variant(JSONB(), "postgresql") +# NOTE sqlalchemy's JSON is untyped, but we may not need it if we redesign the opt DB +# model +JsonType = JsonType.with_variant(JSONB(), "postgresql") # type:ignore[no-untyped-call] NameType = Annotated[str, Column(String(255), nullable=False, unique=False)] RunIdType = Annotated[ int, diff --git a/ixmp4/db/filters.py b/ixmp4/db/filters.py index 074c0990..275e2356 100644 --- a/ixmp4/db/filters.py +++ b/ixmp4/db/filters.py @@ -1,35 +1,43 @@ import operator -from types import UnionType -from typing import Any, ClassVar, Optional, Union, get_args, get_origin +from collections.abc import Callable, Iterable +from types import GenericAlias, UnionType +from typing import Any, ClassVar, Optional, TypeVar, Union, cast, get_args, get_origin from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator from pydantic.fields import FieldInfo +# TODO Import this from typing when dropping support for 3.10 +from typing_extensions import Self + from ixmp4 import db from ixmp4.core.exceptions import BadFilterArguments, ProgrammingError +in_Type = TypeVar("in_Type") + -def in_(c, v): +def in_( + c: db.typing_column[in_Type], v: Iterable[in_Type] | db.BindParameter[in_Type] +) -> db.BinaryExpression[bool]: return c.in_(v) -def like(c, v): +def like(c: db.typing_column[str], v: str) -> db.BinaryExpression[bool]: return c.like(escape_wildcard(v), escape="\\") -def ilike(c, v): +def ilike(c: db.typing_column[str], v: str) -> db.BinaryExpression[bool]: return c.ilike(escape_wildcard(v), escape="\\") -def notlike(c, v): +def notlike(c: db.typing_column[str], v: str) -> db.BinaryExpression[bool]: return c.notlike(escape_wildcard(v), escape="\\") -def notilike(c, v): +def notilike(c: db.typing_column[str], v: str) -> db.BinaryExpression[bool]: return c.notilike(escape_wildcard(v), escape="\\") -def escape_wildcard(v): +def escape_wildcard(v: str) -> str: return v.replace("%", "\\%").replace("*", "%") @@ -61,7 +69,7 @@ class String(str): argument_seperator = "__" filter_func_prefix = "filter_" -lookup_map: dict[object, dict] = { +lookup_map: dict[type, dict[str, tuple[type, Callable[..., Any]]]] = { Id: { "__root__": (int, operator.eq), "in": (list[int], in_), @@ -100,33 +108,53 @@ def get_filter_func_name(n: str) -> str: return filter_func_prefix + n.strip() +def _ensure_str_list(any_list: list[Any]) -> list[str]: + str_list: list[str] = [] + for item in any_list: + if isinstance(item, str): + str_list.append(item) + else: + raise ProgrammingError("Field argument `lookups` must be `list` of `str`.") + + return str_list + + PydanticMeta: type = type(BaseModel) -class FilterMeta(PydanticMeta): - def __new__(cls, name: str, bases: tuple, namespace: dict, **kwargs): +# NOTE mypy seems to say PydanticMeta has type Any, don't see how we could change that +class FilterMeta(PydanticMeta): # type: ignore[misc] + def __new__( + cls, + name: str, + bases: tuple[type, ...], + namespace: dict[str, Any], + **kwargs: Any, + ) -> type["BaseFilter"]: annots = namespace.get("__annotations__", {}).copy() for name, annot in annots.items(): if get_origin(annot) == ClassVar: continue cls.process_field(namespace, name, annot) - return super().__new__(cls, name, bases, namespace, **kwargs) + return cast(FilterMeta, super().__new__(cls, name, bases, namespace, **kwargs)) @classmethod - def build_lookups(cls, field_type: type) -> dict: + def build_lookups( + cls, field_type: type + ) -> dict[str, tuple[type, Callable[..., Any]]]: global lookup_map if field_type not in lookup_map.keys(): if get_origin(field_type) in [Union, UnionType]: - unified_types = get_args(field_type) + unified_types: tuple[type, ...] = get_args(field_type) ut_lookup_map = { type_: cls.build_lookups(type_) for type_ in unified_types } - all_lookup_names = set() + all_lookup_names: set[str] = set() for tl in ut_lookup_map.values(): all_lookup_names |= set(tl.keys()) - lookups = {} + lookups: dict[str, tuple[type, Callable[..., Any]]] = {} for lookup_name in all_lookup_names: tuples = [ ut_lookup_map[type_][lookup_name] @@ -135,8 +163,21 @@ def build_lookups(cls, field_type: type) -> dict: ] types, _ = zip(*tuples) - def lookup_func(c, v, tuples=tuples): + def lookup_func( + c: db.typing_column[Any], + v: Integer + | Float + | Id + | String + | Boolean + | list[Integer | Float | Id | String | Boolean], + tuples: list[tuple[type, Callable[..., Any]]] = tuples, + ) -> Any: for t, lf in tuples: + # NOTE can't check isinstance(..., list[...]) directly, but + # all these cases call the same lf() anyway; skipping + if isinstance(t, GenericAlias): + return lf(c, v) if isinstance(v, t): return lf(c, v) raise ProgrammingError @@ -144,7 +185,7 @@ def lookup_func(c, v, tuples=tuples): lookups[lookup_name] = ( # dynamic union types can't # be done according to type checkers - Union[tuple(types)], # type:ignore + Union[tuple(types)], # type: ignore[assignment] lookup_func, ) return lookups @@ -154,7 +195,9 @@ def lookup_func(c, v, tuples=tuples): return lookup_map[field_type] @classmethod - def process_field(cls, namespace: dict, field_name: str, field_type: type): + def process_field( + cls, namespace: dict[str, Any], field_name: str, field_type: type + ) -> None: lookups = cls.build_lookups(field_type) field: FieldInfo | None = namespace.get(field_name, Field(default=None)) @@ -162,23 +205,19 @@ def process_field(cls, namespace: dict, field_name: str, field_type: type): return namespace.setdefault(field_name, field) - override_lookups: list | None = None + override_lookups: list[str] | None = None if isinstance(field.json_schema_extra, dict): jschema_lookups = field.json_schema_extra.get("lookups", None) + # NOTE We can't `isinstance` parametrized generics. Nothing seems to utilize + # `lookups`, though, so this should not worsen performance. if isinstance(jschema_lookups, list): - override_lookups = jschema_lookups - else: - raise ProgrammingError( - "Field argument `lookups` must be `list` of `str`." - ) + override_lookups = _ensure_str_list(jschema_lookups) else: override_lookups = None if isinstance(override_lookups, list): lookups = {k: v for k, v in lookups.items() if k in override_lookups} - elif override_lookups is None: + else: # override_lookups is None pass - else: - lookups = {} base_field_alias = str(field.alias) if field.alias else field_name cls.expand_lookups( @@ -192,21 +231,34 @@ def process_field(cls, namespace: dict, field_name: str, field_type: type): def expand_lookups( cls, name: str, - lookups: dict, - namespace: dict, + lookups: dict[str, tuple[type, Callable[..., Any]]], + namespace: dict[str, Any], base_field_alias: str | None = None, - ): + ) -> None: global argument_seperator for lookup_alias, (type_, func) in lookups.items(): - if lookup_alias == "__root__": - filter_name = name - else: - filter_name = name + argument_seperator + lookup_alias + filter_name = ( + name + if lookup_alias == "__root__" + else name + argument_seperator + lookup_alias + ) namespace["__annotations__"][filter_name] = Optional[type_] func_name = get_filter_func_name(filter_name) - def filter_func(self, exc, f, v, func=func, session=None): + FilterType = TypeVar("FilterType") + + def filter_func( + self: Self, + exc: db.sql.Select[tuple[FilterType, ...]], + f: str, + v: Integer | Float | Id | String | Boolean, + func: Callable[ + [str, Integer | Float | Id | String | Boolean], + db.ColumnExpressionArgument[bool], + ] = func, + session: db.Session | None = None, + ) -> db.sql.Select[tuple[FilterType, ...]]: return exc.where(func(f, v)) namespace.setdefault(func_name, filter_func) @@ -228,6 +280,10 @@ def filter_func(self, exc, f, v, func=func, session=None): namespace[filter_name] = field +ExpandType = TypeVar("ExpandType", str, list[str]) +FilterType = TypeVar("FilterType") + + class BaseFilter(BaseModel, metaclass=FilterMeta): model_config = ConfigDict( arbitrary_types_allowed=True, @@ -238,7 +294,9 @@ class BaseFilter(BaseModel, metaclass=FilterMeta): @model_validator(mode="before") @classmethod - def expand_simple_filters(cls, v): + def expand_simple_filters( + cls, v: ExpandType | dict[str, ExpandType] + ) -> dict[str, ExpandType]: return expand_simple_filter(v) def __init__(self, **data: Any) -> None: @@ -248,10 +306,19 @@ def __init__(self, **data: Any) -> None: except ValidationError as e: raise BadFilterArguments(model=e.title, errors=e.errors()) - def join(self, exc: db.sql.Select, session=None) -> db.sql.Select: + def join( + self, + exc: db.sql.Select[tuple[FilterType]], + session: db.Session | None = None, + ) -> db.sql.Select[tuple[FilterType]]: return exc - def apply(self, exc: db.sql.Select, model, session) -> db.sql.Select: + def apply( + self, + exc: db.sql.Select[tuple[FilterType]], + model: object, + session: db.Session, + ) -> db.sql.Select[tuple[FilterType]]: dict_model = dict(self) for name, field_info in self.model_fields.items(): value = dict_model.get(name, field_info.get_default()) @@ -279,20 +346,18 @@ def apply(self, exc: db.sql.Select, model, session) -> db.sql.Select: sqla_column = jschema_col else: sqla_column = None - if sqla_column is None: - column = None - else: - column = getattr(model, sqla_column, None) + column = ( + None if sqla_column is None else getattr(model, sqla_column, None) + ) exc = filter_func(exc, column, value, session=session) return exc.distinct() -def expand_simple_filter(value): +def expand_simple_filter( + value: ExpandType | dict[str, ExpandType], +) -> dict[str, ExpandType]: if isinstance(value, str): - if "*" in value: - return dict(name__like=value) - else: - return dict(name=value) + return dict(name__like=value) if "*" in value else dict(name=value) elif isinstance(value, list): if any(["*" in v for v in value]): raise NotImplementedError("Filter by list with wildcard is not implemented") diff --git a/ixmp4/db/migrations/versions/d66a8276ba0a_make_scalar_unit_non_optional.py b/ixmp4/db/migrations/versions/d66a8276ba0a_make_scalar_unit_non_optional.py new file mode 100644 index 00000000..c055928e --- /dev/null +++ b/ixmp4/db/migrations/versions/d66a8276ba0a_make_scalar_unit_non_optional.py @@ -0,0 +1,41 @@ +# type: ignore +"""Make Scalar.unit non-optional + +Revision ID: d66a8276ba0a +Revises: 914991d09f59 +Create Date: 2024-11-29 14:14:42.857695 + +""" + +import sqlalchemy as sa +from alembic import op + +# Revision identifiers, used by Alembic. +revision = "d66a8276ba0a" +down_revision = "914991d09f59" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("optimization_scalar", schema=None) as batch_op: + batch_op.alter_column("unit__id", existing_type=sa.INTEGER(), nullable=False) + + with op.batch_alter_table("runmetaentry", schema=None) as batch_op: + batch_op.add_column(sa.Column("dtype", sa.String(length=20), nullable=False)) + batch_op.drop_column("type") + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("runmetaentry", schema=None) as batch_op: + batch_op.add_column(sa.Column("type", sa.VARCHAR(length=20), nullable=False)) + batch_op.drop_column("dtype") + + with op.batch_alter_table("optimization_scalar", schema=None) as batch_op: + batch_op.alter_column("unit__id", existing_type=sa.INTEGER(), nullable=True) + + # ### end Alembic commands ### diff --git a/ixmp4/db/utils/__init__.py b/ixmp4/db/utils/__init__.py index c8636ffe..3ef8a97d 100644 --- a/ixmp4/db/utils/__init__.py +++ b/ixmp4/db/utils/__init__.py @@ -1,13 +1,23 @@ from contextlib import suppress +from typing import TYPE_CHECKING, Any, TypeVar from sqlalchemy import inspect, sql from sqlalchemy.orm import Mapper -from sqlalchemy.sql import ColumnCollection +from sqlalchemy.sql import ColumnCollection, ColumnElement +from sqlalchemy.sql.base import ReadOnlyColumnCollection from ixmp4.core.exceptions import ProgrammingError +if TYPE_CHECKING: + from ixmp4.data.db.base import BaseModel -def is_joined(exc: sql.Select, model): +JoinType = TypeVar("JoinType", bound=sql.Select[tuple["BaseModel", ...]]) + + +# This should not need to be Any, I think, but if I put "BaseModel" instead, various +# things like RunMetaEntry are not recognized -> this sounds like covarianve again, +# but covariant typevars are not allowed as type hints. +def is_joined(exc: sql.Select[tuple[Any, ...]], model: type["BaseModel"]) -> bool: """Returns `True` if `model` has been joined in `exc`.""" for visitor in sql.visitors.iterate(exc): # Checking for `.join(Child)` clauses @@ -15,21 +25,23 @@ def is_joined(exc: sql.Select, model): # Visitor might be of ColumnCollection or so, # which cannot be compared to model with suppress(TypeError): - if model == visitor.entity_namespace: # type: ignore + if model == visitor.entity_namespace: # type: ignore[attr-defined] return True return False -def get_columns(model_class: type) -> ColumnCollection: - mapper: Mapper | None = inspect(model_class) +def get_columns(model_class: type) -> ColumnCollection[str, ColumnElement[Any]]: + mapper: Mapper[Any] | None = inspect(model_class) if mapper is not None: return mapper.selectable.columns else: raise ProgrammingError(f"Model class `{model_class.__name__}` is not mapped.") -def get_pk_columns(model_class: type) -> ColumnCollection: - columns: ColumnCollection = ColumnCollection() +def get_pk_columns( + model_class: type, +) -> ReadOnlyColumnCollection[str, ColumnElement[int]]: + columns: ColumnCollection[str, ColumnElement[int]] = ColumnCollection() for col in get_columns(model_class): if col.primary_key: columns.add(col) @@ -37,8 +49,10 @@ def get_pk_columns(model_class: type) -> ColumnCollection: return columns.as_readonly() -def get_foreign_columns(model_class: type) -> ColumnCollection: - columns: ColumnCollection = ColumnCollection() +def get_foreign_columns( + model_class: type, +) -> ReadOnlyColumnCollection[str, ColumnElement[int]]: + columns: ColumnCollection[str, ColumnElement[int]] = ColumnCollection() for col in get_columns(model_class): if len(col.foreign_keys) > 0: columns.add(col) diff --git a/ixmp4/db/utils/sqlite.py b/ixmp4/db/utils/sqlite.py index 12b93773..e7db043e 100644 --- a/ixmp4/db/utils/sqlite.py +++ b/ixmp4/db/utils/sqlite.py @@ -1,5 +1,5 @@ +from collections.abc import Generator from pathlib import Path -from typing import Generator from ixmp4.conf import settings @@ -31,7 +31,4 @@ def search_databases(name: str) -> str | None: """Returns a database URI if the desired database exists, otherwise `None`.""" database_path = get_database_path(name) - if database_path.exists(): - return get_dsn(database_path) - else: - return None + return get_dsn(database_path) if database_path.exists() else None diff --git a/ixmp4/server/rest/__init__.py b/ixmp4/server/rest/__init__.py index cb49e0cb..762ffc52 100644 --- a/ixmp4/server/rest/__init__.py +++ b/ixmp4/server/rest/__init__.py @@ -74,10 +74,7 @@ class APIInfo(BaseModel): @v1.get("/", response_model=APIInfo) -def root( - platform: str = Path(), - version: str = Depends(deps.get_version), -): +def root(platform: str = Path(), version: str = Depends(deps.get_version)) -> APIInfo: return APIInfo( name=platform, version=version, @@ -88,7 +85,7 @@ def root( @v1.exception_handler(IxmpError) -async def http_exception_handler(request: Request, exc: IxmpError): +async def http_exception_handler(request: Request, exc: IxmpError) -> JSONResponse: return JSONResponse( content=jsonable_encoder( { diff --git a/ixmp4/server/rest/base.py b/ixmp4/server/rest/base.py index 17ad863a..3fe73325 100644 --- a/ixmp4/server/rest/base.py +++ b/ixmp4/server/rest/base.py @@ -4,6 +4,9 @@ from pydantic import BaseModel as PydanticBaseModel from pydantic import ConfigDict, Field +# TODO Import this from typing when dropping Python 3.11 +from typing_extensions import TypedDict, Unpack + from ixmp4.conf import settings from ixmp4.data import api @@ -24,6 +27,11 @@ class Pagination(BaseModel): offset: int = Field(default=0, ge=0) +class InitKwargs(TypedDict): + total: int + pagination: Pagination + + class EnumerationOutput(BaseModel, Generic[EnumeratedT]): pagination: Pagination total: int @@ -31,13 +39,13 @@ class EnumerationOutput(BaseModel, Generic[EnumeratedT]): def __init__( __pydantic_self__, - *args, results: pd.DataFrame | api.DataFrame | list[EnumeratedT], - **kwargs, - ): - kwargs["results"] = ( + **kwargs: Unpack[InitKwargs], + ) -> None: + _kwargs = {"results": results, **kwargs} + _kwargs["results"] = ( api.DataFrame.model_validate(results) if isinstance(results, pd.DataFrame) else results ) - super().__init__(**kwargs) + super().__init__(**_kwargs) diff --git a/ixmp4/server/rest/decorators.py b/ixmp4/server/rest/decorators.py index c5e9ec1f..20b6070e 100644 --- a/ixmp4/server/rest/decorators.py +++ b/ixmp4/server/rest/decorators.py @@ -1,6 +1,7 @@ -def autodoc(f): +from collections.abc import Callable +from typing import Any + + +def autodoc(f: Callable[..., Any]) -> None: funcname = f""":func:`{f.__module__}.{f.__qualname__}`\n\n""" - if f.__doc__ is not None: - f.__doc__ = funcname + f.__doc__ - else: - f.__doc__ = funcname + f.__doc__ = funcname + f.__doc__ if f.__doc__ is not None else funcname diff --git a/ixmp4/server/rest/deps.py b/ixmp4/server/rest/deps.py index 55a6e93c..abe74a28 100644 --- a/ixmp4/server/rest/deps.py +++ b/ixmp4/server/rest/deps.py @@ -1,12 +1,13 @@ import logging -from typing import AsyncGenerator, Callable +from collections.abc import AsyncGenerator, Callable +from typing import Any import jwt from fastapi import Depends, Header, Path from ixmp4.conf import settings from ixmp4.conf.auth import SelfSignedAuth -from ixmp4.conf.manager import ManagerConfig +from ixmp4.conf.manager import ManagerConfig, ManagerPlatformInfo from ixmp4.conf.user import User, anonymous_user, local_user from ixmp4.core.exceptions import Forbidden, InvalidToken, PlatformNotFound from ixmp4.data.backend.db import SqlAlchemyBackend @@ -17,7 +18,9 @@ ) -async def validate_token(authorization: str = Header(None)) -> dict | None: +async def validate_token( + authorization: str | None = Header(None), +) -> dict[str, Any] | None: """Validates a JSON Web Token with the secret supplied in the `IXMP4_SECRET_HS256` environment variable.""" @@ -26,19 +29,23 @@ async def validate_token(authorization: str = Header(None)) -> dict | None: encoded_jwt = authorization.split(" ")[1] try: - return jwt.decode( + # jwt.decode just returns Any, so this is already assuming + decoded_jwt: dict[str, Any] = jwt.decode( encoded_jwt, settings.secret_hs256, leeway=300, algorithms=["HS256"] ) + return decoded_jwt except jwt.InvalidTokenError as e: raise InvalidToken("The supplied token is expired or invalid.") from e -async def do_not_validate_token(authorization: str = Header(None)) -> dict | None: +async def do_not_validate_token( + authorization: str = Header(None), +) -> dict[str, dict[str, Any]] | None: """Override dependency used for skipping authentication while testing.""" return {"user": local_user.model_dump()} -async def get_user(token: dict | None = Depends(validate_token)) -> User: +async def get_user(token: dict[str, Any] | None = Depends(validate_token)) -> User: """Returns a user object for permission checks.""" if token is None: return anonymous_user @@ -52,7 +59,7 @@ async def get_user(token: dict | None = Depends(validate_token)) -> User: return User(**user_dict) -async def get_version(): +async def get_version() -> str: from ixmp4 import __version__ return __version__ @@ -63,7 +70,8 @@ async def get_managed_backend( ) -> AsyncGenerator[SqlAlchemyBackend, None]: """Returns a platform backend for a platform name as a path parameter. Also checks user access permissions if in managed mode.""" - info = manager.get_platform(platform, jti=manager.auth.get_user().jti) + jti = manager.auth.get_user().jti if manager.auth else None + info = manager.get_platform(platform, jti=jti) if info.dsn.startswith("http"): raise PlatformNotFound(f"Platform '{platform}' was not found.") @@ -93,13 +101,13 @@ async def get_toml_backend( backend.close() -if settings.managed: - get_backend = get_managed_backend -else: - get_backend = get_toml_backend +get_backend = get_managed_backend if settings.managed else get_toml_backend -def get_test_backend_dependency(backend, auth_params) -> Callable: +def get_test_backend_dependency( + backend: SqlAlchemyBackend, + auth_params: tuple[User, ManagerConfig, ManagerPlatformInfo], +) -> Callable[[str, User], AsyncGenerator[SqlAlchemyBackend, None]]: async def get_memory_backend( platform: str = Path(), user: User = Depends(get_user) ) -> AsyncGenerator[SqlAlchemyBackend, None]: diff --git a/ixmp4/server/rest/docs.py b/ixmp4/server/rest/docs.py index 5423bef7..f8be7cc5 100644 --- a/ixmp4/server/rest/docs.py +++ b/ixmp4/server/rest/docs.py @@ -1,7 +1,10 @@ +from typing import cast + from fastapi import APIRouter, Depends, Path, Query from ixmp4.data import api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend +from ixmp4.data.db.docs import AbstractDocs from . import deps from .base import BaseModel, EnumerationOutput, Pagination @@ -22,7 +25,7 @@ def list_models( dimension_id: int | None = Query(None), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[AbstractDocs]: return EnumerationOutput( results=backend.models.docs.list(dimension_id=dimension_id), total=backend.models.docs.count(dimension_id=dimension_id), @@ -34,16 +37,16 @@ def list_models( def set_models( docs: DocsInput, backend: Backend = Depends(deps.get_backend), -): - return backend.models.docs.set(**docs.model_dump()) +) -> api.Docs: + return cast(api.Docs, backend.models.docs.set(**docs.model_dump())) @router.delete("/models/{dimension_id}/") def delete_models( dimension_id: int = Path(), backend: Backend = Depends(deps.get_backend), -): - return backend.models.docs.delete(dimension_id) +) -> None: + backend.models.docs.delete(dimension_id) @router.get("/regions/", response_model=EnumerationOutput[api.Docs]) @@ -51,7 +54,7 @@ def list_regions( dimension_id: int | None = Query(None), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[AbstractDocs]: return EnumerationOutput( results=backend.regions.docs.list(dimension_id=dimension_id), total=backend.regions.docs.count(dimension_id=dimension_id), @@ -63,16 +66,16 @@ def list_regions( def set_regions( docs: DocsInput, backend: Backend = Depends(deps.get_backend), -): - return backend.regions.docs.set(**docs.model_dump()) +) -> api.Docs: + return cast(api.Docs, backend.regions.docs.set(**docs.model_dump())) @router.delete("/regions/{dimension_id}/") def delete_regions( dimension_id: int = Path(), backend: Backend = Depends(deps.get_backend), -): - return backend.regions.docs.delete(dimension_id) +) -> None: + backend.regions.docs.delete(dimension_id) @router.get("/scenarios/", response_model=EnumerationOutput[api.Docs]) @@ -80,7 +83,7 @@ def list_scenarios( dimension_id: int | None = Query(None), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[AbstractDocs]: return EnumerationOutput( results=backend.scenarios.docs.list(dimension_id=dimension_id), total=backend.scenarios.docs.count(dimension_id=dimension_id), @@ -92,16 +95,16 @@ def list_scenarios( def set_scenarios( docs: DocsInput, backend: Backend = Depends(deps.get_backend), -): - return backend.scenarios.docs.set(**docs.model_dump()) +) -> api.Docs: + return cast(api.Docs, backend.scenarios.docs.set(**docs.model_dump())) @router.delete("/scenarios/{dimension_id}/") def delete_scenarios( dimension_id: int = Path(), backend: Backend = Depends(deps.get_backend), -): - return backend.scenarios.docs.delete(dimension_id) +) -> None: + backend.scenarios.docs.delete(dimension_id) @router.get("/units/", response_model=EnumerationOutput[api.Docs]) @@ -109,7 +112,7 @@ def list_units( dimension_id: int | None = Query(None), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[AbstractDocs]: return EnumerationOutput( results=backend.units.docs.list(dimension_id=dimension_id), total=backend.units.docs.count(dimension_id=dimension_id), @@ -121,16 +124,16 @@ def list_units( def set_units( docs: DocsInput, backend: Backend = Depends(deps.get_backend), -): - return backend.units.docs.set(**docs.model_dump()) +) -> api.Docs: + return cast(api.Docs, backend.units.docs.set(**docs.model_dump())) @router.delete("/units/{dimension_id}/") def delete_units( dimension_id: int = Path(), backend: Backend = Depends(deps.get_backend), -): - return backend.units.docs.delete(dimension_id) +) -> None: + backend.units.docs.delete(dimension_id) @router.get("/iamc/variables/", response_model=EnumerationOutput[api.Docs]) @@ -138,7 +141,7 @@ def list_variables( dimension_id: int | None = Query(None), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[AbstractDocs]: return EnumerationOutput( results=backend.iamc.variables.docs.list(dimension_id=dimension_id), total=backend.iamc.variables.docs.count(dimension_id=dimension_id), @@ -150,16 +153,16 @@ def list_variables( def set_variables( docs: DocsInput, backend: Backend = Depends(deps.get_backend), -): - return backend.iamc.variables.docs.set(**docs.model_dump()) +) -> api.Docs: + return cast(api.Docs, backend.iamc.variables.docs.set(**docs.model_dump())) @router.delete("/iamc/variables/{dimension_id}/") def delete_variables( dimension_id: int = Path(), backend: Backend = Depends(deps.get_backend), -): - return backend.iamc.variables.docs.delete(dimension_id) +) -> None: + backend.iamc.variables.docs.delete(dimension_id) @router.get("/optimization/indexsets/", response_model=EnumerationOutput[api.Docs]) @@ -167,7 +170,7 @@ def list_indexsets( dimension_id: int | None = Query(None), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[AbstractDocs]: return EnumerationOutput( results=backend.optimization.indexsets.docs.list(dimension_id=dimension_id), total=backend.optimization.indexsets.docs.count(dimension_id=dimension_id), @@ -179,133 +182,158 @@ def list_indexsets( def set_indexsets( docs: DocsInput, backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.indexsets.docs.set(**docs.model_dump()) +) -> api.Docs: + return cast(api.Docs, backend.optimization.indexsets.docs.set(**docs.model_dump())) @router.delete("/optimization/indexsets/{dimension_id}/") def delete_indexsets( dimension_id: int = Path(), backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.indexsets.docs.delete(dimension_id) +) -> None: + backend.optimization.indexsets.docs.delete(dimension_id) -@router.get("/optimization/scalars/", response_model=list[api.Docs]) +@router.get("/optimization/scalars/", response_model=EnumerationOutput[api.Docs]) def list_scalars( dimension_id: int | None = Query(None), + pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.scalars.docs.list(dimension_id=dimension_id) +) -> EnumerationOutput[AbstractDocs]: + return EnumerationOutput( + results=backend.optimization.scalars.docs.list(dimension_id=dimension_id), + total=backend.optimization.scalars.docs.count(dimension_id=dimension_id), + pagination=pagination, + ) @router.post("/optimization/scalars/", response_model=api.Docs) def set_scalars( docs: DocsInput, backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.scalars.docs.set(**docs.model_dump()) +) -> api.Docs: + return cast(api.Docs, backend.optimization.scalars.docs.set(**docs.model_dump())) @router.delete("/optimization/scalars/{dimension_id}/") def delete_scalars( dimension_id: int = Path(), backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.scalars.docs.delete(dimension_id) +) -> None: + backend.optimization.scalars.docs.delete(dimension_id) -@router.get("/optimization/tables/", response_model=list[api.Docs]) +@router.get("/optimization/tables/", response_model=EnumerationOutput[api.Docs]) def list_tables( dimension_id: int | None = Query(None), + pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.tables.docs.list(dimension_id=dimension_id) +) -> EnumerationOutput[AbstractDocs]: + return EnumerationOutput( + results=backend.optimization.tables.docs.list(dimension_id=dimension_id), + total=backend.optimization.tables.docs.count(dimension_id=dimension_id), + pagination=pagination, + ) @router.post("/optimization/tables/", response_model=api.Docs) def set_tables( docs: DocsInput, backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.tables.docs.set(**docs.model_dump()) +) -> api.Docs: + return cast(api.Docs, backend.optimization.tables.docs.set(**docs.model_dump())) @router.delete("/optimization/tables/{dimension_id}/") def delete_tables( dimension_id: int = Path(), backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.tables.docs.delete(dimension_id) +) -> None: + backend.optimization.tables.docs.delete(dimension_id) -@router.get("/optimization/parameters/", response_model=list[api.Docs]) +@router.get("/optimization/parameters/", response_model=EnumerationOutput[api.Docs]) def list_parameters( dimension_id: int | None = Query(None), + pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.parameters.docs.list(dimension_id=dimension_id) +) -> EnumerationOutput[AbstractDocs]: + return EnumerationOutput( + results=backend.optimization.parameters.docs.list(dimension_id=dimension_id), + total=backend.optimization.parameters.docs.count(dimension_id=dimension_id), + pagination=pagination, + ) @router.post("/optimization/parameters/", response_model=api.Docs) def set_parameters( docs: DocsInput, backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.parameters.docs.set(**docs.model_dump()) +) -> api.Docs: + return cast(api.Docs, backend.optimization.parameters.docs.set(**docs.model_dump())) @router.delete("/optimization/parameters/{dimension_id}/") def delete_parameters( dimension_id: int = Path(), backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.parameters.docs.delete(dimension_id) +) -> None: + backend.optimization.parameters.docs.delete(dimension_id) -@router.get("/optimization/variables/", response_model=list[api.Docs]) +@router.get("/optimization/variables/", response_model=EnumerationOutput[api.Docs]) def list_optimization_variables( dimension_id: int | None = Query(None), + pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.variables.docs.list(dimension_id=dimension_id) +) -> EnumerationOutput[AbstractDocs]: + return EnumerationOutput( + results=backend.optimization.variables.docs.list(dimension_id=dimension_id), + total=backend.optimization.variables.docs.count(dimension_id=dimension_id), + pagination=pagination, + ) @router.post("/optimization/variables/", response_model=api.Docs) def set_optimization_variables( docs: DocsInput, backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.variables.docs.set(**docs.model_dump()) +) -> api.Docs: + return cast(api.Docs, backend.optimization.variables.docs.set(**docs.model_dump())) @router.delete("/optimization/variables/{dimension_id}/") def delete_optimization_variables( dimension_id: int = Path(), backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.variables.docs.delete(dimension_id) +) -> None: + backend.optimization.variables.docs.delete(dimension_id) -@router.get("/optimization/equations/", response_model=list[api.Docs]) +@router.get("/optimization/equations/", response_model=EnumerationOutput[api.Docs]) def list_equations( dimension_id: int | None = Query(None), + pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.equations.docs.list(dimension_id=dimension_id) +) -> EnumerationOutput[AbstractDocs]: + return EnumerationOutput( + results=backend.optimization.equations.docs.list(dimension_id=dimension_id), + total=backend.optimization.equations.docs.count(dimension_id=dimension_id), + pagination=pagination, + ) @router.post("/optimization/equations/", response_model=api.Docs) def set_equations( docs: DocsInput, backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.equations.docs.set(**docs.model_dump()) +) -> api.Docs: + return cast(api.Docs, backend.optimization.equations.docs.set(**docs.model_dump())) @router.delete("/optimization/equations/{dimension_id}/") def delete_equations( dimension_id: int = Path(), backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.equations.docs.delete(dimension_id) +) -> None: + backend.optimization.equations.docs.delete(dimension_id) diff --git a/ixmp4/server/rest/iamc/datapoint.py b/ixmp4/server/rest/iamc/datapoint.py index 90f2e949..1f6e4cb5 100644 --- a/ixmp4/server/rest/iamc/datapoint.py +++ b/ixmp4/server/rest/iamc/datapoint.py @@ -4,6 +4,7 @@ from ixmp4.data import api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend from ixmp4.data.db.iamc.datapoint.filter import DataPointFilter +from ixmp4.data.db.iamc.datapoint.model import DataPoint from .. import deps from ..base import EnumerationOutput, Pagination @@ -24,7 +25,7 @@ def query( table: bool | None = Query(False), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[DataPoint]: """This endpoint is used to retrieve and optionally filter data.add() Filter parameters are provided as keyword arguments. @@ -83,8 +84,11 @@ def query( def bulk_upsert( df: api.DataFrame, backend: Backend = Depends(deps.get_backend), -): - return backend.iamc.datapoints.bulk_upsert(df.to_pandas()) +) -> None: + # A pandera.DataFrame is a subclass of pd.DataFrame, so this is fine. Mypy likely + # complains because our decorators change the type hint in some incompatible way. + # Might be about covariance again. + backend.iamc.datapoints.bulk_upsert(df.to_pandas()) # type: ignore[arg-type] @autodoc @@ -92,5 +96,5 @@ def bulk_upsert( def bulk_delete( df: api.DataFrame, backend: Backend = Depends(deps.get_backend), -): - return backend.iamc.datapoints.bulk_delete(df.to_pandas()) +) -> None: + backend.iamc.datapoints.bulk_delete(df.to_pandas()) # type: ignore[arg-type] diff --git a/ixmp4/server/rest/iamc/model.py b/ixmp4/server/rest/iamc/model.py index 31095aa4..e8e119ab 100644 --- a/ixmp4/server/rest/iamc/model.py +++ b/ixmp4/server/rest/iamc/model.py @@ -3,6 +3,7 @@ from ixmp4.data import api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend from ixmp4.data.db.model.filter import IamcModelFilter +from ixmp4.data.db.model.model import Model from .. import deps from ..base import EnumerationOutput, Pagination @@ -19,7 +20,7 @@ def query( table: bool | None = Query(False), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[Model]: return EnumerationOutput( results=backend.models.paginate( _filter=filter, diff --git a/ixmp4/server/rest/iamc/region.py b/ixmp4/server/rest/iamc/region.py index a90b34da..6981229f 100644 --- a/ixmp4/server/rest/iamc/region.py +++ b/ixmp4/server/rest/iamc/region.py @@ -3,6 +3,7 @@ from ixmp4.data import api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend from ixmp4.data.db.region.filter import IamcRegionFilter +from ixmp4.data.db.region.model import Region from .. import deps from ..base import EnumerationOutput, Pagination @@ -19,7 +20,7 @@ def query( table: bool | None = Query(False), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[Region]: return EnumerationOutput( results=backend.regions.paginate( _filter=filter, diff --git a/ixmp4/server/rest/iamc/scenario.py b/ixmp4/server/rest/iamc/scenario.py index 044dd14b..fe1050d7 100644 --- a/ixmp4/server/rest/iamc/scenario.py +++ b/ixmp4/server/rest/iamc/scenario.py @@ -3,6 +3,7 @@ from ixmp4.data import api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend from ixmp4.data.db.scenario.filter import IamcScenarioFilter +from ixmp4.data.db.scenario.model import Scenario from .. import deps from ..base import EnumerationOutput, Pagination @@ -19,7 +20,7 @@ def query( table: bool | None = Query(False), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[Scenario]: return EnumerationOutput( results=backend.scenarios.paginate( _filter=filter, diff --git a/ixmp4/server/rest/iamc/timeseries.py b/ixmp4/server/rest/iamc/timeseries.py index c50df3fb..cf9f197a 100644 --- a/ixmp4/server/rest/iamc/timeseries.py +++ b/ixmp4/server/rest/iamc/timeseries.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from typing import Any from fastapi import APIRouter, Body, Depends, Query, Response @@ -5,6 +6,7 @@ from ixmp4.data import api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend from ixmp4.data.db.iamc.timeseries.filter import TimeSeriesFilter +from ixmp4.data.db.iamc.timeseries.model import TimeSeries from .. import deps from ..base import BaseModel, EnumerationOutput, Pagination @@ -18,7 +20,7 @@ class TimeSeriesInput(BaseModel): run__id: int - parameters: dict[str, Any] + parameters: Mapping[str, Any] @autodoc @@ -29,7 +31,7 @@ def query( table: bool | None = Query(False), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[TimeSeries]: return EnumerationOutput( results=backend.iamc.timeseries.paginate( _filter=filter, @@ -43,11 +45,11 @@ def query( ) -@router.post("/", response_model=api.Run) +@router.post("/", response_model=api.TimeSeries) def create( timeseries: TimeSeriesInput, backend: Backend = Depends(deps.get_backend), -): +) -> TimeSeries: return backend.iamc.timeseries.create(**timeseries.model_dump()) @@ -55,15 +57,15 @@ def create( def get_by_id( id: int, backend: Backend = Depends(deps.get_backend), -): +) -> TimeSeries: return backend.iamc.timeseries.get_by_id(id) @router.post("/bulk/") def bulk_upsert( df: api.DataFrame, - create_related: bool | None = Query(False), + create_related: bool = Query(False), backend: Backend = Depends(deps.get_backend), -): +) -> Response: backend.iamc.timeseries.bulk_upsert(df.to_pandas(), create_related=create_related) return Response(status_code=201) diff --git a/ixmp4/server/rest/iamc/unit.py b/ixmp4/server/rest/iamc/unit.py index 1085d01e..12bd1ac6 100644 --- a/ixmp4/server/rest/iamc/unit.py +++ b/ixmp4/server/rest/iamc/unit.py @@ -3,6 +3,7 @@ from ixmp4.data import api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend from ixmp4.data.db.unit.filter import IamcUnitFilter +from ixmp4.data.db.unit.model import Unit from .. import deps from ..base import EnumerationOutput, Pagination @@ -19,7 +20,7 @@ def query( table: bool | None = Query(False), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[Unit]: return EnumerationOutput( results=backend.units.paginate( _filter=filter, diff --git a/ixmp4/server/rest/iamc/variable.py b/ixmp4/server/rest/iamc/variable.py index 7262ee7b..e80a7d53 100644 --- a/ixmp4/server/rest/iamc/variable.py +++ b/ixmp4/server/rest/iamc/variable.py @@ -3,6 +3,7 @@ from ixmp4.data import api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend from ixmp4.data.db.iamc.variable.filter import VariableFilter +from ixmp4.data.db.iamc.variable.model import Variable from .. import deps from ..base import BaseModel, EnumerationOutput, Pagination @@ -25,7 +26,7 @@ def query( table: bool = Query(False), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[Variable]: return EnumerationOutput( results=backend.iamc.variables.paginate( _filter=filter, @@ -43,5 +44,5 @@ def query( def create( variable: VariableInput, backend: Backend = Depends(deps.get_backend), -): +) -> Variable: return backend.iamc.variables.create(**variable.model_dump()) diff --git a/ixmp4/server/rest/meta.py b/ixmp4/server/rest/meta.py index 5f06e734..4251a295 100644 --- a/ixmp4/server/rest/meta.py +++ b/ixmp4/server/rest/meta.py @@ -4,6 +4,7 @@ from ixmp4.data import abstract, api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend from ixmp4.data.db.meta.filter import RunMetaEntryFilter +from ixmp4.data.db.meta.model import RunMetaEntry from . import deps from .base import BaseModel, EnumerationOutput, Pagination @@ -27,7 +28,7 @@ def query( table: bool | None = Query(False), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[RunMetaEntry]: if join_run_index and not table: raise BadRequest("`join_run_index` can only be used with `table=true`.") @@ -48,7 +49,7 @@ def query( def create( runmeta: RunMetaEntryInput, backend: Backend = Depends(deps.get_backend), -): +) -> RunMetaEntry: return backend.meta.create(**runmeta.model_dump()) @@ -56,7 +57,7 @@ def create( def delete( id: int = Path(), backend: Backend = Depends(deps.get_backend), -): +) -> None: backend.meta.delete(id) @@ -64,13 +65,16 @@ def delete( def bulk_upsert( df: api.DataFrame, backend: Backend = Depends(deps.get_backend), -): - return backend.meta.bulk_upsert(df.to_pandas()) +) -> None: + # A pandera.DataFrame is a subclass of pd.DataFrame, so this is fine. Mypy likely + # complains because our decorators change the type hint in some incompatible way. + # Might be about covariance again. + backend.meta.bulk_upsert(df.to_pandas()) # type: ignore[arg-type] @router.patch("/bulk/") def bulk_delete( df: api.DataFrame, backend: Backend = Depends(deps.get_backend), -): - return backend.meta.bulk_delete(df.to_pandas()) +) -> None: + backend.meta.bulk_delete(df.to_pandas()) # type: ignore[arg-type] diff --git a/ixmp4/server/rest/middleware.py b/ixmp4/server/rest/middleware.py index aa91f33a..664f97ec 100644 --- a/ixmp4/server/rest/middleware.py +++ b/ixmp4/server/rest/middleware.py @@ -1,13 +1,17 @@ import logging import time +from collections.abc import Awaitable, Callable +from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware logger = logging.getLogger(__name__) class RequestTimeLoggerMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): + async def dispatch( + self, request: Request, call_next: Callable[[Request], Awaitable[Response]] + ) -> Response: start_time = time.time() response = await call_next(request) process_time = time.time() - start_time @@ -16,7 +20,9 @@ async def dispatch(self, request, call_next): class RequestSizeLoggerMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): + async def dispatch( + self, request: Request, call_next: Callable[[Request], Awaitable[Response]] + ) -> Response: body = await request.body() logger.debug(f"Request body size: {len(body)} bytes.") return await call_next(request) diff --git a/ixmp4/server/rest/model.py b/ixmp4/server/rest/model.py index 12fe83e0..fe370d23 100644 --- a/ixmp4/server/rest/model.py +++ b/ixmp4/server/rest/model.py @@ -3,6 +3,7 @@ from ixmp4.data import api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend from ixmp4.data.db.model.filter import ModelFilter +from ixmp4.data.db.model.model import Model from . import deps from .base import BaseModel, EnumerationOutput, Pagination @@ -25,7 +26,7 @@ def query( table: bool | None = Query(False), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[Model]: return EnumerationOutput( results=backend.models.paginate( _filter=filter, @@ -43,5 +44,5 @@ def query( def create( model: ModelInput, backend: Backend = Depends(deps.get_backend), -): +) -> Model: return backend.models.create(**model.model_dump()) diff --git a/ixmp4/server/rest/optimization/equation.py b/ixmp4/server/rest/optimization/equation.py index 79dbdb88..d35d9e31 100644 --- a/ixmp4/server/rest/optimization/equation.py +++ b/ixmp4/server/rest/optimization/equation.py @@ -5,6 +5,7 @@ from ixmp4.data import api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend from ixmp4.data.db.optimization.equation.filter import EquationFilter +from ixmp4.data.db.optimization.equation.model import Equation from .. import deps from ..base import BaseModel, EnumerationOutput, Pagination @@ -17,8 +18,8 @@ class EquationCreateInput(BaseModel): - run_id: int name: str + run_id: int constrained_to_indexsets: list[str] column_names: list[str] | None @@ -32,7 +33,7 @@ class DataInput(BaseModel): def get_by_id( id: int, backend: Backend = Depends(deps.get_backend), -): +) -> Equation: return backend.optimization.equations.get_by_id(id) @@ -43,7 +44,7 @@ def query( table: bool = Query(False), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[Equation]: return EnumerationOutput( results=backend.optimization.equations.paginate( _filter=filter, @@ -62,8 +63,8 @@ def add_data( equation_id: int, data: DataInput, backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.equations.add_data( +) -> None: + backend.optimization.equations.add_data( equation_id=equation_id, **data.model_dump() ) @@ -73,7 +74,7 @@ def add_data( def remove_data( equation_id: int, backend: Backend = Depends(deps.get_backend), -): +) -> None: backend.optimization.equations.remove_data(equation_id == equation_id) @@ -82,5 +83,5 @@ def remove_data( def create( equation: EquationCreateInput, backend: Backend = Depends(deps.get_backend), -): +) -> Equation: return backend.optimization.equations.create(**equation.model_dump()) diff --git a/ixmp4/server/rest/optimization/indexset.py b/ixmp4/server/rest/optimization/indexset.py index f6471586..97cda612 100644 --- a/ixmp4/server/rest/optimization/indexset.py +++ b/ixmp4/server/rest/optimization/indexset.py @@ -3,6 +3,7 @@ from ixmp4.data import api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend from ixmp4.data.db.optimization.indexset.filter import OptimizationIndexSetFilter +from ixmp4.data.db.optimization.indexset.model import IndexSet from .. import deps from ..base import BaseModel, EnumerationOutput, Pagination @@ -20,7 +21,7 @@ class IndexSetInput(BaseModel): class DataInput(BaseModel): - data: float | int | str | list[float | int | str] + data: float | int | str | list[float] | list[int] | list[str] @autodoc @@ -31,7 +32,7 @@ def query( include_data: bool = Query(False), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[IndexSet]: return EnumerationOutput( results=backend.optimization.indexsets.paginate( _filter=filter, @@ -50,7 +51,7 @@ def query( def create( indexset: IndexSetInput, backend: Backend = Depends(deps.get_backend), -): +) -> IndexSet: return backend.optimization.indexsets.create(**indexset.model_dump()) @@ -60,7 +61,7 @@ def add_data( indexset_id: int, data: DataInput, backend: Backend = Depends(deps.get_backend), -): +) -> None: backend.optimization.indexsets.add_data( indexset_id=indexset_id, **data.model_dump() ) diff --git a/ixmp4/server/rest/optimization/parameter.py b/ixmp4/server/rest/optimization/parameter.py index 3f8993b3..c4818320 100644 --- a/ixmp4/server/rest/optimization/parameter.py +++ b/ixmp4/server/rest/optimization/parameter.py @@ -5,6 +5,7 @@ from ixmp4.data import api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend from ixmp4.data.db.optimization.parameter.filter import OptimizationParameterFilter +from ixmp4.data.db.optimization.parameter.model import Parameter from .. import deps from ..base import BaseModel, EnumerationOutput, Pagination @@ -17,8 +18,8 @@ class ParameterCreateInput(BaseModel): - run_id: int name: str + run_id: int constrained_to_indexsets: list[str] column_names: list[str] | None @@ -32,7 +33,7 @@ class DataInput(BaseModel): def get_by_id( id: int, backend: Backend = Depends(deps.get_backend), -): +) -> Parameter: return backend.optimization.parameters.get_by_id(id) @@ -45,7 +46,7 @@ def query( table: bool = Query(False), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[Parameter]: return EnumerationOutput( results=backend.optimization.parameters.paginate( _filter=filter, @@ -64,8 +65,8 @@ def add_data( parameter_id: int, data: DataInput, backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.parameters.add_data( +) -> None: + backend.optimization.parameters.add_data( parameter_id=parameter_id, **data.model_dump() ) @@ -75,5 +76,5 @@ def add_data( def create( parameter: ParameterCreateInput, backend: Backend = Depends(deps.get_backend), -): +) -> Parameter: return backend.optimization.parameters.create(**parameter.model_dump()) diff --git a/ixmp4/server/rest/optimization/scalar.py b/ixmp4/server/rest/optimization/scalar.py index 55a1e657..fad8126d 100644 --- a/ixmp4/server/rest/optimization/scalar.py +++ b/ixmp4/server/rest/optimization/scalar.py @@ -3,6 +3,7 @@ from ixmp4.data import api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend from ixmp4.data.db.optimization.scalar.filter import OptimizationScalarFilter +from ixmp4.data.db.optimization.scalar.model import Scalar from .. import deps from ..base import BaseModel, EnumerationOutput, Pagination @@ -31,7 +32,7 @@ class ScalarUpdateInput(BaseModel): def get_by_id( id: int, backend: Backend = Depends(deps.get_backend), -): +) -> Scalar: return backend.optimization.scalars.get_by_id(id) @@ -44,7 +45,7 @@ def query( table: bool = Query(False), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[Scalar]: return EnumerationOutput( results=backend.optimization.scalars.paginate( _filter=filter, @@ -63,7 +64,7 @@ def update( id: int, scalar: ScalarUpdateInput, backend: Backend = Depends(deps.get_backend), -): +) -> Scalar: return backend.optimization.scalars.update(id, **scalar.model_dump()) @@ -72,5 +73,5 @@ def update( def create( scalar: ScalarCreateInput, backend: Backend = Depends(deps.get_backend), -): +) -> Scalar: return backend.optimization.scalars.create(**scalar.model_dump()) diff --git a/ixmp4/server/rest/optimization/table.py b/ixmp4/server/rest/optimization/table.py index 1f3bcb11..9b551cb6 100644 --- a/ixmp4/server/rest/optimization/table.py +++ b/ixmp4/server/rest/optimization/table.py @@ -5,6 +5,7 @@ from ixmp4.data import api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend from ixmp4.data.db.optimization.table.filter import OptimizationTableFilter +from ixmp4.data.db.optimization.table.model import Table from .. import deps from ..base import BaseModel, EnumerationOutput, Pagination @@ -17,8 +18,8 @@ class TableCreateInput(BaseModel): - run_id: int name: str + run_id: int constrained_to_indexsets: list[str] column_names: list[str] | None @@ -32,7 +33,7 @@ class DataInput(BaseModel): def get_by_id( id: int, backend: Backend = Depends(deps.get_backend), -): +) -> Table: return backend.optimization.tables.get_by_id(id) @@ -43,7 +44,7 @@ def query( table: bool = Query(False), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[Table]: return EnumerationOutput( results=backend.optimization.tables.paginate( _filter=filter, @@ -62,8 +63,8 @@ def add_data( table_id: int, data: DataInput, backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.tables.add_data(table_id=table_id, **data.model_dump()) +) -> None: + backend.optimization.tables.add_data(table_id=table_id, **data.model_dump()) @autodoc @@ -71,5 +72,5 @@ def add_data( def create( table: TableCreateInput, backend: Backend = Depends(deps.get_backend), -): +) -> Table: return backend.optimization.tables.create(**table.model_dump()) diff --git a/ixmp4/server/rest/optimization/variable.py b/ixmp4/server/rest/optimization/variable.py index 6a4a09f5..ed4a89fb 100644 --- a/ixmp4/server/rest/optimization/variable.py +++ b/ixmp4/server/rest/optimization/variable.py @@ -5,6 +5,7 @@ from ixmp4.data import api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend from ixmp4.data.db.optimization.variable.filter import OptimizationVariableFilter +from ixmp4.data.db.optimization.variable.model import OptimizationVariable from .. import deps from ..base import BaseModel, EnumerationOutput, Pagination @@ -17,8 +18,8 @@ class VariableCreateInput(BaseModel): - run_id: int name: str + run_id: int constrained_to_indexsets: str | list[str] | None column_names: list[str] | None @@ -32,7 +33,7 @@ class DataInput(BaseModel): def get_by_id( id: int, backend: Backend = Depends(deps.get_backend), -): +) -> OptimizationVariable: return backend.optimization.variables.get_by_id(id) @@ -45,7 +46,7 @@ def query( table: bool = Query(False), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[OptimizationVariable]: return EnumerationOutput( results=backend.optimization.variables.paginate( _filter=filter, @@ -64,8 +65,8 @@ def add_data( variable_id: int, data: DataInput, backend: Backend = Depends(deps.get_backend), -): - return backend.optimization.variables.add_data( +) -> None: + backend.optimization.variables.add_data( variable_id=variable_id, **data.model_dump() ) @@ -75,7 +76,7 @@ def add_data( def remove_data( variable_id: int, backend: Backend = Depends(deps.get_backend), -): +) -> None: backend.optimization.variables.remove_data(variable_id=variable_id) @@ -84,5 +85,5 @@ def remove_data( def create( variable: VariableCreateInput, backend: Backend = Depends(deps.get_backend), -): +) -> OptimizationVariable: return backend.optimization.variables.create(**variable.model_dump()) diff --git a/ixmp4/server/rest/region.py b/ixmp4/server/rest/region.py index 730b4b16..82ec9f97 100644 --- a/ixmp4/server/rest/region.py +++ b/ixmp4/server/rest/region.py @@ -3,6 +3,7 @@ from ixmp4.data import api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend from ixmp4.data.db.region.filter import RegionFilter +from ixmp4.data.db.region.model import Region from . import deps from .base import BaseModel, EnumerationOutput, Pagination @@ -26,7 +27,7 @@ def query( table: bool = Query(False), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[Region]: return EnumerationOutput( results=backend.regions.paginate( _filter=filter, @@ -44,7 +45,7 @@ def query( def create( region: RegionInput, backend: Backend = Depends(deps.get_backend), -): +) -> Region: return backend.regions.create(**region.model_dump()) @@ -52,5 +53,5 @@ def create( def delete( id: int = Path(), backend: Backend = Depends(deps.get_backend), -): +) -> None: backend.regions.delete(id) diff --git a/ixmp4/server/rest/run.py b/ixmp4/server/rest/run.py index 19af8057..1457db83 100644 --- a/ixmp4/server/rest/run.py +++ b/ixmp4/server/rest/run.py @@ -4,6 +4,7 @@ from ixmp4.data import api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend from ixmp4.data.db.run.filter import RunFilter +from ixmp4.data.db.run.model import Run from . import deps from .base import BaseModel, EnumerationOutput, Pagination @@ -25,7 +26,7 @@ def query( table: bool | None = Query(False), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[Run]: return EnumerationOutput( results=backend.runs.paginate( _filter=filter, @@ -42,7 +43,7 @@ def query( def create( run: RunInput, backend: Backend = Depends(deps.get_backend), -): +) -> Run: return backend.runs.create(**run.model_dump(by_alias=True)) @@ -50,7 +51,7 @@ def create( def set_as_default_version( id: int = Path(), backend: Backend = Depends(deps.get_backend), -): +) -> None: backend.runs.set_as_default_version(id) @@ -58,5 +59,5 @@ def set_as_default_version( def unset_as_default_version( id: int = Path(), backend: Backend = Depends(deps.get_backend), -): +) -> None: backend.runs.unset_as_default_version(id) diff --git a/ixmp4/server/rest/scenario.py b/ixmp4/server/rest/scenario.py index af3f85c6..c0d4e027 100644 --- a/ixmp4/server/rest/scenario.py +++ b/ixmp4/server/rest/scenario.py @@ -3,6 +3,7 @@ from ixmp4.data import api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend from ixmp4.data.db.scenario.filter import ScenarioFilter +from ixmp4.data.db.scenario.model import Scenario from . import deps from .base import BaseModel, EnumerationOutput, Pagination @@ -25,7 +26,7 @@ def query( table: bool | None = Query(False), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[Scenario]: return EnumerationOutput( results=backend.scenarios.paginate( _filter=filter, @@ -43,5 +44,5 @@ def query( def create( scenario: ScenarioInput, backend: Backend = Depends(deps.get_backend), -): +) -> Scenario: return backend.scenarios.create(**scenario.model_dump()) diff --git a/ixmp4/server/rest/unit.py b/ixmp4/server/rest/unit.py index 5d0cc21c..a8d75616 100644 --- a/ixmp4/server/rest/unit.py +++ b/ixmp4/server/rest/unit.py @@ -3,6 +3,7 @@ from ixmp4.data import api from ixmp4.data.backend.db import SqlAlchemyBackend as Backend from ixmp4.data.db.unit.filter import UnitFilter +from ixmp4.data.db.unit.model import Unit from . import deps from .base import BaseModel, EnumerationOutput, Pagination @@ -25,7 +26,7 @@ def query( table: bool = Query(False), pagination: Pagination = Depends(), backend: Backend = Depends(deps.get_backend), -): +) -> EnumerationOutput[Unit]: return EnumerationOutput( results=backend.units.paginate( _filter=filter, @@ -42,7 +43,7 @@ def query( def create( unit: UnitInput, backend: Backend = Depends(deps.get_backend), -): +) -> Unit: return backend.units.create(**unit.model_dump()) @@ -50,5 +51,5 @@ def create( def delete( id: int = Path(), backend: Backend = Depends(deps.get_backend), -): +) -> None: backend.units.delete(id) diff --git a/poetry.lock b/poetry.lock index f4123acf..218c1e20 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1236,47 +1236,53 @@ files = [ [[package]] name = "mypy" -version = "1.10.1" +version = "1.13.0" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" files = [ - {file = "mypy-1.10.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e36f229acfe250dc660790840916eb49726c928e8ce10fbdf90715090fe4ae02"}, - {file = "mypy-1.10.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:51a46974340baaa4145363b9e051812a2446cf583dfaeba124af966fa44593f7"}, - {file = "mypy-1.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:901c89c2d67bba57aaaca91ccdb659aa3a312de67f23b9dfb059727cce2e2e0a"}, - {file = "mypy-1.10.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0cd62192a4a32b77ceb31272d9e74d23cd88c8060c34d1d3622db3267679a5d9"}, - {file = "mypy-1.10.1-cp310-cp310-win_amd64.whl", hash = "sha256:a2cbc68cb9e943ac0814c13e2452d2046c2f2b23ff0278e26599224cf164e78d"}, - {file = "mypy-1.10.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:bd6f629b67bb43dc0d9211ee98b96d8dabc97b1ad38b9b25f5e4c4d7569a0c6a"}, - {file = "mypy-1.10.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a1bbb3a6f5ff319d2b9d40b4080d46cd639abe3516d5a62c070cf0114a457d84"}, - {file = "mypy-1.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8edd4e9bbbc9d7b79502eb9592cab808585516ae1bcc1446eb9122656c6066f"}, - {file = "mypy-1.10.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6166a88b15f1759f94a46fa474c7b1b05d134b1b61fca627dd7335454cc9aa6b"}, - {file = "mypy-1.10.1-cp311-cp311-win_amd64.whl", hash = "sha256:5bb9cd11c01c8606a9d0b83ffa91d0b236a0e91bc4126d9ba9ce62906ada868e"}, - {file = "mypy-1.10.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d8681909f7b44d0b7b86e653ca152d6dff0eb5eb41694e163c6092124f8246d7"}, - {file = "mypy-1.10.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:378c03f53f10bbdd55ca94e46ec3ba255279706a6aacaecac52ad248f98205d3"}, - {file = "mypy-1.10.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6bacf8f3a3d7d849f40ca6caea5c055122efe70e81480c8328ad29c55c69e93e"}, - {file = "mypy-1.10.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:701b5f71413f1e9855566a34d6e9d12624e9e0a8818a5704d74d6b0402e66c04"}, - {file = "mypy-1.10.1-cp312-cp312-win_amd64.whl", hash = "sha256:3c4c2992f6ea46ff7fce0072642cfb62af7a2484efe69017ed8b095f7b39ef31"}, - {file = "mypy-1.10.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:604282c886497645ffb87b8f35a57ec773a4a2721161e709a4422c1636ddde5c"}, - {file = "mypy-1.10.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:37fd87cab83f09842653f08de066ee68f1182b9b5282e4634cdb4b407266bade"}, - {file = "mypy-1.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8addf6313777dbb92e9564c5d32ec122bf2c6c39d683ea64de6a1fd98b90fe37"}, - {file = "mypy-1.10.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5cc3ca0a244eb9a5249c7c583ad9a7e881aa5d7b73c35652296ddcdb33b2b9c7"}, - {file = "mypy-1.10.1-cp38-cp38-win_amd64.whl", hash = "sha256:1b3a2ffce52cc4dbaeee4df762f20a2905aa171ef157b82192f2e2f368eec05d"}, - {file = "mypy-1.10.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fe85ed6836165d52ae8b88f99527d3d1b2362e0cb90b005409b8bed90e9059b3"}, - {file = "mypy-1.10.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c2ae450d60d7d020d67ab440c6e3fae375809988119817214440033f26ddf7bf"}, - {file = "mypy-1.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6be84c06e6abd72f960ba9a71561c14137a583093ffcf9bbfaf5e613d63fa531"}, - {file = "mypy-1.10.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2189ff1e39db399f08205e22a797383613ce1cb0cb3b13d8bcf0170e45b96cc3"}, - {file = "mypy-1.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:97a131ee36ac37ce9581f4220311247ab6cba896b4395b9c87af0675a13a755f"}, - {file = "mypy-1.10.1-py3-none-any.whl", hash = "sha256:71d8ac0b906354ebda8ef1673e5fde785936ac1f29ff6987c7483cfbd5a4235a"}, - {file = "mypy-1.10.1.tar.gz", hash = "sha256:1f8f492d7db9e3593ef42d4f115f04e556130f2819ad33ab84551403e97dd4c0"}, + {file = "mypy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6607e0f1dd1fb7f0aca14d936d13fd19eba5e17e1cd2a14f808fa5f8f6d8f60a"}, + {file = "mypy-1.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a21be69bd26fa81b1f80a61ee7ab05b076c674d9b18fb56239d72e21d9f4c80"}, + {file = "mypy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b2353a44d2179846a096e25691d54d59904559f4232519d420d64da6828a3a7"}, + {file = "mypy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0730d1c6a2739d4511dc4253f8274cdd140c55c32dfb0a4cf8b7a43f40abfa6f"}, + {file = "mypy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:c5fc54dbb712ff5e5a0fca797e6e0aa25726c7e72c6a5850cfd2adbc1eb0a372"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:581665e6f3a8a9078f28d5502f4c334c0c8d802ef55ea0e7276a6e409bc0d82d"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3ddb5b9bf82e05cc9a627e84707b528e5c7caaa1c55c69e175abb15a761cec2d"}, + {file = "mypy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:20c7ee0bc0d5a9595c46f38beb04201f2620065a93755704e141fcac9f59db2b"}, + {file = "mypy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3790ded76f0b34bc9c8ba4def8f919dd6a46db0f5a6610fb994fe8efdd447f73"}, + {file = "mypy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:51f869f4b6b538229c1d1bcc1dd7d119817206e2bc54e8e374b3dfa202defcca"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5c7051a3461ae84dfb5dd15eff5094640c61c5f22257c8b766794e6dd85e72d5"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39bb21c69a5d6342f4ce526e4584bc5c197fd20a60d14a8624d8743fffb9472e"}, + {file = "mypy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:164f28cb9d6367439031f4c81e84d3ccaa1e19232d9d05d37cb0bd880d3f93c2"}, + {file = "mypy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a4c1bfcdbce96ff5d96fc9b08e3831acb30dc44ab02671eca5953eadad07d6d0"}, + {file = "mypy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0affb3a79a256b4183ba09811e3577c5163ed06685e4d4b46429a271ba174d2"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a7b44178c9760ce1a43f544e595d35ed61ac2c3de306599fa59b38a6048e1aa7"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5d5092efb8516d08440e36626f0153b5006d4088c1d663d88bf79625af3d1d62"}, + {file = "mypy-1.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de2904956dac40ced10931ac967ae63c5089bd498542194b436eb097a9f77bc8"}, + {file = "mypy-1.13.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:7bfd8836970d33c2105562650656b6846149374dc8ed77d98424b40b09340ba7"}, + {file = "mypy-1.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:9f73dba9ec77acb86457a8fc04b5239822df0c14a082564737833d2963677dbc"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:100fac22ce82925f676a734af0db922ecfea991e1d7ec0ceb1e115ebe501301a"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7bcb0bb7f42a978bb323a7c88f1081d1b5dee77ca86f4100735a6f541299d8fb"}, + {file = "mypy-1.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bde31fc887c213e223bbfc34328070996061b0833b0a4cfec53745ed61f3519b"}, + {file = "mypy-1.13.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:07de989f89786f62b937851295ed62e51774722e5444a27cecca993fc3f9cd74"}, + {file = "mypy-1.13.0-cp38-cp38-win_amd64.whl", hash = "sha256:4bde84334fbe19bad704b3f5b78c4abd35ff1026f8ba72b29de70dda0916beb6"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0246bcb1b5de7f08f2826451abd947bf656945209b140d16ed317f65a17dc7dc"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f5b7deae912cf8b77e990b9280f170381fdfbddf61b4ef80927edd813163732"}, + {file = "mypy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7029881ec6ffb8bc233a4fa364736789582c738217b133f1b55967115288a2bc"}, + {file = "mypy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3e38b980e5681f28f033f3be86b099a247b13c491f14bb8b1e1e134d23bb599d"}, + {file = "mypy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:a6789be98a2017c912ae6ccb77ea553bbaf13d27605d2ca20a76dfbced631b24"}, + {file = "mypy-1.13.0-py3-none-any.whl", hash = "sha256:9c250883f9fd81d212e0952c92dbfcc96fc237f4b7c92f56ac81fd48460b3e5a"}, + {file = "mypy-1.13.0.tar.gz", hash = "sha256:0291a61b6fbf3e6673e3405cfcc0e7650bebc7939659fdca2702958038bd835e"}, ] [package.dependencies] mypy-extensions = ">=1.0.0" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typing-extensions = ">=4.1.0" +typing-extensions = ">=4.6.0" [package.extras] dmypy = ["psutil (>=4.0)"] +faster-cache = ["orjson"] install-types = ["pip"] mypyc = ["setuptools (>=50)"] reports = ["lxml"] @@ -1395,40 +1401,53 @@ files = [ [[package]] name = "pandas" -version = "2.2.2" +version = "2.2.3" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.9" files = [ - {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, - {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, - {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, - {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, - {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, - {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8e5a0b00e1e56a842f922e7fae8ae4077aee4af0acb5ae3622bd4b4c30aedf99"}, - {file = "pandas-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:ddf818e4e6c7c6f4f7c8a12709696d193976b591cc7dc50588d3d1a6b5dc8772"}, - {file = "pandas-2.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:696039430f7a562b74fa45f540aca068ea85fa34c244d0deee539cb6d70aa288"}, - {file = "pandas-2.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8e90497254aacacbc4ea6ae5e7a8cd75629d6ad2b30025a4a8b09aa4faf55151"}, - {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58b84b91b0b9f4bafac2a0ac55002280c094dfc6402402332c0913a59654ab2b"}, - {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2123dc9ad6a814bcdea0f099885276b31b24f7edf40f6cdbc0912672e22eee"}, - {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2925720037f06e89af896c70bca73459d7e6a4be96f9de79e2d440bd499fe0db"}, - {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0cace394b6ea70c01ca1595f839cf193df35d1575986e484ad35c4aeae7266c1"}, - {file = "pandas-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:873d13d177501a28b2756375d59816c365e42ed8417b41665f346289adc68d24"}, - {file = "pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef"}, - {file = "pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce"}, - {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cb51fe389360f3b5a4d57dbd2848a5f033350336ca3b340d1c53a1fad33bcad"}, - {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad"}, - {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e374f59e440d4ab45ca2fffde54b81ac3834cf5ae2cdfa69c90bc03bde04d76"}, - {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, - {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, - {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, - {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, - {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, - {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, - {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, - {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92fd6b027924a7e178ac202cfbe25e53368db90d56872d20ffae94b96c7acc57"}, - {file = "pandas-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:640cef9aa381b60e296db324337a554aeeb883ead99dc8f6c18e81a93942f5f4"}, - {file = "pandas-2.2.2.tar.gz", hash = "sha256:9e79019aba43cb4fda9e4d983f8e88ca0373adbb697ae9c6c43093218de28b54"}, + {file = "pandas-2.2.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1948ddde24197a0f7add2bdc4ca83bf2b1ef84a1bc8ccffd95eda17fd836ecb5"}, + {file = "pandas-2.2.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:381175499d3802cde0eabbaf6324cce0c4f5d52ca6f8c377c29ad442f50f6348"}, + {file = "pandas-2.2.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d9c45366def9a3dd85a6454c0e7908f2b3b8e9c138f5dc38fed7ce720d8453ed"}, + {file = "pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86976a1c5b25ae3f8ccae3a5306e443569ee3c3faf444dfd0f41cda24667ad57"}, + {file = "pandas-2.2.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b8661b0238a69d7aafe156b7fa86c44b881387509653fdf857bebc5e4008ad42"}, + {file = "pandas-2.2.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:37e0aced3e8f539eccf2e099f65cdb9c8aa85109b0be6e93e2baff94264bdc6f"}, + {file = "pandas-2.2.3-cp310-cp310-win_amd64.whl", hash = "sha256:56534ce0746a58afaf7942ba4863e0ef81c9c50d3f0ae93e9497d6a41a057645"}, + {file = "pandas-2.2.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:66108071e1b935240e74525006034333f98bcdb87ea116de573a6a0dccb6c039"}, + {file = "pandas-2.2.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7c2875855b0ff77b2a64a0365e24455d9990730d6431b9e0ee18ad8acee13dbd"}, + {file = "pandas-2.2.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd8d0c3be0515c12fed0bdbae072551c8b54b7192c7b1fda0ba56059a0179698"}, + {file = "pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c124333816c3a9b03fbeef3a9f230ba9a737e9e5bb4060aa2107a86cc0a497fc"}, + {file = "pandas-2.2.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:63cc132e40a2e084cf01adf0775b15ac515ba905d7dcca47e9a251819c575ef3"}, + {file = "pandas-2.2.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:29401dbfa9ad77319367d36940cd8a0b3a11aba16063e39632d98b0e931ddf32"}, + {file = "pandas-2.2.3-cp311-cp311-win_amd64.whl", hash = "sha256:3fc6873a41186404dad67245896a6e440baacc92f5b716ccd1bc9ed2995ab2c5"}, + {file = "pandas-2.2.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b1d432e8d08679a40e2a6d8b2f9770a5c21793a6f9f47fdd52c5ce1948a5a8a9"}, + {file = "pandas-2.2.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a5a1595fe639f5988ba6a8e5bc9649af3baf26df3998a0abe56c02609392e0a4"}, + {file = "pandas-2.2.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5de54125a92bb4d1c051c0659e6fcb75256bf799a732a87184e5ea503965bce3"}, + {file = "pandas-2.2.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fffb8ae78d8af97f849404f21411c95062db1496aeb3e56f146f0355c9989319"}, + {file = "pandas-2.2.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6dfcb5ee8d4d50c06a51c2fffa6cff6272098ad6540aed1a76d15fb9318194d8"}, + {file = "pandas-2.2.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:062309c1b9ea12a50e8ce661145c6aab431b1e99530d3cd60640e255778bd43a"}, + {file = "pandas-2.2.3-cp312-cp312-win_amd64.whl", hash = "sha256:59ef3764d0fe818125a5097d2ae867ca3fa64df032331b7e0917cf5d7bf66b13"}, + {file = "pandas-2.2.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f00d1345d84d8c86a63e476bb4955e46458b304b9575dcf71102b5c705320015"}, + {file = "pandas-2.2.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3508d914817e153ad359d7e069d752cdd736a247c322d932eb89e6bc84217f28"}, + {file = "pandas-2.2.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:22a9d949bfc9a502d320aa04e5d02feab689d61da4e7764b62c30b991c42c5f0"}, + {file = "pandas-2.2.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3a255b2c19987fbbe62a9dfd6cff7ff2aa9ccab3fc75218fd4b7530f01efa24"}, + {file = "pandas-2.2.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:800250ecdadb6d9c78eae4990da62743b857b470883fa27f652db8bdde7f6659"}, + {file = "pandas-2.2.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6374c452ff3ec675a8f46fd9ab25c4ad0ba590b71cf0656f8b6daa5202bca3fb"}, + {file = "pandas-2.2.3-cp313-cp313-win_amd64.whl", hash = "sha256:61c5ad4043f791b61dd4752191d9f07f0ae412515d59ba8f005832a532f8736d"}, + {file = "pandas-2.2.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:3b71f27954685ee685317063bf13c7709a7ba74fc996b84fc6821c59b0f06468"}, + {file = "pandas-2.2.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:38cf8125c40dae9d5acc10fa66af8ea6fdf760b2714ee482ca691fc66e6fcb18"}, + {file = "pandas-2.2.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ba96630bc17c875161df3818780af30e43be9b166ce51c9a18c1feae342906c2"}, + {file = "pandas-2.2.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1db71525a1538b30142094edb9adc10be3f3e176748cd7acc2240c2f2e5aa3a4"}, + {file = "pandas-2.2.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:15c0e1e02e93116177d29ff83e8b1619c93ddc9c49083f237d4312337a61165d"}, + {file = "pandas-2.2.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ad5b65698ab28ed8d7f18790a0dc58005c7629f227be9ecc1072aa74c0c1d43a"}, + {file = "pandas-2.2.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc6b93f9b966093cb0fd62ff1a7e4c09e6d546ad7c1de191767baffc57628f39"}, + {file = "pandas-2.2.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5dbca4c1acd72e8eeef4753eeca07de9b1db4f398669d5994086f788a5d7cc30"}, + {file = "pandas-2.2.3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8cd6d7cc958a3910f934ea8dbdf17b2364827bb4dafc38ce6eef6bb3d65ff09c"}, + {file = "pandas-2.2.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99df71520d25fade9db7c1076ac94eb994f4d2673ef2aa2e86ee039b6746d20c"}, + {file = "pandas-2.2.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:31d0ced62d4ea3e231a9f228366919a5ea0b07440d9d4dac345376fd8e1477ea"}, + {file = "pandas-2.2.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7eee9e7cea6adf3e3d24e304ac6b8300646e2a5d1cd3a3c2abed9101b0846761"}, + {file = "pandas-2.2.3-cp39-cp39-win_amd64.whl", hash = "sha256:4850ba03528b6dd51d6c5d273c46f183f39a9baf3f0143e566b89450965b105e"}, + {file = "pandas-2.2.3.tar.gz", hash = "sha256:4f18ba62b61d7e192368b84517265a99b4d7ee8912f8708660fb4a366cc82667"}, ] [package.dependencies] @@ -3542,4 +3561,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10, <3.14" -content-hash = "2f176246cdb32b28ac4308370e2f439b74cb8d87be964f9c5238cb6906fbacdb" +content-hash = "87c22be1bda398513036f2bcf4e3d9defcbbcfde34b601a8d9cf0a2d03516610" diff --git a/pyproject.toml b/pyproject.toml index 6fb24f33..6d729319 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ optional = true [tool.poetry.group.dev.dependencies] build = ">=1.0.3" -mypy = ">=1.0.0" +mypy = ">=1.13.0" pandas-stubs = ">=2.0.0.230412" pre-commit = ">=3.3.3" ptvsd = ">=4.3.2" @@ -83,20 +83,44 @@ requires = ["poetry-core>=1.2.0", "poetry-dynamic-versioning"] [tool.mypy] exclude = [ - '^example\.py$', - '^import\.py$', - '^tests\/', - '^doc\/', '^ixmp4\/db\/migrations\/', ] disable_error_code = ['override'] -implicit_reexport = true -plugins = ['sqlalchemy.ext.mypy.plugin'] +show_error_codes = true +plugins = ['numpy.typing.mypy_plugin', 'pandera.mypy', 'pydantic.mypy', 'sqlalchemy.ext.mypy.plugin'] +# The following are equivalent to --strict mypy as seen in +# https://mypy.readthedocs.io/en/stable/existing_code.html#introduce-stricter-options +warn_unused_configs = true +warn_redundant_casts = true +warn_unused_ignores = true +strict_equality = true +extra_checks = true +check_untyped_defs = true +disallow_subclassing_any = true +disallow_untyped_decorators = true +disallow_any_generics = true +disallow_untyped_calls = true +disallow_incomplete_defs = true +disallow_untyped_defs = true +#TODO enable this in a follow-up PR to satisfy --strict +# no_implicit_reexport = true +warn_return_any = true +# These are bonus, it seems: +disallow_any_unimported = true +no_implicit_optional = true +warn_unreachable = true [[tool.mypy.overrides]] -module = ["pandas", "uvicorn.workers", "sqlalchemy_utils"] +# Removing this introduces several errors +module = ["uvicorn.workers"] +# Without this, mypy is still fine, but pyproject.toml complains ignore_missing_imports = true +[tool.pydantic-mypy] +init_forbid_extra = true +init_typed = true +warn_required_dynamic_aliases = true + [tool.ruff] exclude = [ ".git", @@ -122,12 +146,19 @@ ignore = ["B008"] # Ignore unused imports: "__init__.py" = ["F401"] "ixmp4/data/db/optimization/base.py" = ["F401"] +"ixmp4/data/db/iamc/base.py"= ["F401"] # Ignore importing * and resulting possibly missing imports: "ixmp4/db/__init__.py" = ["F403", "F405"] [tool.ruff.lint.mccabe] max-complexity = 10 +[tool.coverage.report] +exclude_also = [ + # Imports only used by type checkers + "if TYPE_CHECKING:", +] + [tool.poetry-dynamic-versioning] enable = true style = "pep440" diff --git a/tests/conf/test_toml.py b/tests/conf/test_toml.py index 550978db..1eae26dc 100644 --- a/tests/conf/test_toml.py +++ b/tests/conf/test_toml.py @@ -23,13 +23,13 @@ class HasPath(Protocol): class TomlTest: - def assert_toml_file(self, toml_config: HasPath, expected_toml: str): + def assert_toml_file(self, toml_config: HasPath, expected_toml: str) -> None: with toml_config.path.open() as f: assert f.read() == expected_toml class TestTomlPlatforms(TomlTest): - def test_add_platform(self, toml_config: TomlConfig): + def test_add_platform(self, toml_config: TomlConfig) -> None: toml_config.add_platform("test", "test://test/") expected_toml = '[test]\ndsn = "test://test/"\n' @@ -42,11 +42,11 @@ def test_add_platform(self, toml_config: TomlConfig): ) self.assert_toml_file(toml_config, expected_toml) - def test_platform_unique(self, toml_config: TomlConfig): + def test_platform_unique(self, toml_config: TomlConfig) -> None: with pytest.raises(PlatformNotUnique): toml_config.add_platform("test", "test://test/") - def test_remove_platform(self, toml_config: TomlConfig): + def test_remove_platform(self, toml_config: TomlConfig) -> None: toml_config.remove_platform("test") expected_toml = '[test2]\ndsn = "test2://test2/"\n' @@ -59,7 +59,7 @@ def test_remove_platform(self, toml_config: TomlConfig): with toml_config.path.open() as f: assert f.read() == expected_toml - def test_remove_missing_platform(self, toml_config: TomlConfig): + def test_remove_missing_platform(self, toml_config: TomlConfig) -> None: with pytest.raises(PlatformNotFound): toml_config.remove_platform("test") @@ -73,16 +73,16 @@ def credentials() -> Credentials: class TestTomlCredentials(TomlTest): - def test_set_credentials(self, credentials): + def test_set_credentials(self, credentials: Credentials) -> None: credentials.set("test", "user", "password") expected_toml = '[test]\nusername = "user"\npassword = "password"\n' self.assert_toml_file(credentials, expected_toml) - def test_get_credentials(self, credentials): + def test_get_credentials(self, credentials: Credentials) -> None: ret = credentials.get("test") assert ret == ("user", "password") - def test_clear_credentials(self, credentials): + def test_clear_credentials(self, credentials: Credentials) -> None: credentials.clear("test") expected_toml = "" self.assert_toml_file(credentials, expected_toml) @@ -90,7 +90,7 @@ def test_clear_credentials(self, credentials): # clearing non-exsistent credentials is fine credentials.clear("test") - def test_add_credentials(self, credentials): + def test_add_credentials(self, credentials: Credentials) -> None: credentials.set("test", "user", "password") expected_toml = '[test]\nusername = "user"\npassword = "password"\n' self.assert_toml_file(credentials, expected_toml) diff --git a/tests/conftest.py b/tests/conftest.py index 6249ff1d..34ec09c8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,9 @@ import cProfile import pstats -from contextlib import contextmanager +from collections.abc import Callable, Generator +from contextlib import _GeneratorContextManager, contextmanager from pathlib import Path +from typing import Any, TypeAlias import pytest @@ -26,7 +28,7 @@ } -def pytest_addoption(parser): +def pytest_addoption(parser: pytest.Parser) -> None: """Called to set up the pytest command line parser. We can add our own options here.""" @@ -51,7 +53,7 @@ def __init__(self, postgres_dsn: str) -> None: self.postgres_dsn = postgres_dsn @contextmanager - def rest_sqlite(self): + def rest_sqlite(self) -> Generator[RestTestBackend, Any, None]: with self.sqlite() as backend: rest = RestTestBackend(backend) rest.setup() @@ -60,7 +62,7 @@ def rest_sqlite(self): rest.teardown() @contextmanager - def rest_postgresql(self): + def rest_postgresql(self) -> Generator[RestTestBackend, Any, None]: with self.postgresql() as backend: rest = RestTestBackend(backend) rest.setup() @@ -69,7 +71,7 @@ def rest_postgresql(self): rest.teardown() @contextmanager - def postgresql(self): + def postgresql(self) -> Generator[PostgresTestBackend, Any, None]: pgsql = PostgresTestBackend( PlatformInfo( name="postgres-test", @@ -82,7 +84,7 @@ def postgresql(self): pgsql.teardown() @contextmanager - def sqlite(self): + def sqlite(self) -> Generator[SqliteTestBackend, Any, None]: sqlite = SqliteTestBackend( PlatformInfo(name="sqlite-test", dsn="sqlite:///:memory:") ) @@ -92,9 +94,20 @@ def sqlite(self): sqlite.teardown() -def get_backend_context(type, postgres_dsn): +def get_backend_context( + type: str, postgres_dsn: str +) -> ( + _GeneratorContextManager[RestTestBackend] + | _GeneratorContextManager[PostgresTestBackend] + | _GeneratorContextManager[SqliteTestBackend] +): backends = Backends(postgres_dsn) + bctx: ( + _GeneratorContextManager[RestTestBackend] + | _GeneratorContextManager[PostgresTestBackend] + | _GeneratorContextManager[SqliteTestBackend] + ) if type == "rest-sqlite": bctx = backends.rest_sqlite() elif type == "rest-postgres": @@ -106,7 +119,7 @@ def get_backend_context(type, postgres_dsn): return bctx -def platform_fixture(request): +def platform_fixture(request: pytest.FixtureRequest) -> Generator[Platform, Any, None]: type = request.param postgres_dsn = request.config.option.postgres_dsn bctx = get_backend_context(type, postgres_dsn) @@ -125,8 +138,12 @@ def platform_fixture(request): medium = MediumIamcDataset() -def td_platform_fixture(td): - def platform_with_td(request): +def td_platform_fixture( + td: BigIamcDataset | MediumIamcDataset, +) -> Callable[[pytest.FixtureRequest], Generator[Platform, Any, None]]: + def platform_with_td( + request: pytest.FixtureRequest, + ) -> Generator[Platform, Any, None]: type = request.param postgres_dsn = request.config.option.postgres_dsn bctx = get_backend_context(type, postgres_dsn) @@ -153,7 +170,7 @@ def platform_with_td(request): ) -def pytest_generate_tests(metafunc): +def pytest_generate_tests(metafunc: pytest.Metafunc) -> Any: # This is called for every test. Only get/set command line arguments # if the argument is specified in the list of test "fixturenames". @@ -173,7 +190,9 @@ def pytest_generate_tests(metafunc): @pytest.fixture(scope="function") -def profiled(request): +def profiled( + request: pytest.FixtureRequest, +) -> Generator[Callable[[], _GeneratorContextManager[None]]]: """Use this fixture for profiling tests: ``` def test(profiled): @@ -189,7 +208,7 @@ def test(profiled): pr = cProfile.Profile() @contextmanager - def profiled(): + def profiled() -> Generator[None, Any, None]: pr.enable() yield pr.disable() @@ -198,3 +217,6 @@ def profiled(): ps = pstats.Stats(pr) Path(".profiles").mkdir(parents=True, exist_ok=True) ps.dump_stats(f".profiles/{testname}.prof") + + +Profiled: TypeAlias = Callable[[], _GeneratorContextManager[None]] diff --git a/tests/core/test_iamc.py b/tests/core/test_iamc.py index 581cf356..9bea5854 100644 --- a/tests/core/test_iamc.py +++ b/tests/core/test_iamc.py @@ -1,5 +1,7 @@ import asyncio +from collections.abc import Iterable +import pandas as pd import pytest import ixmp4 @@ -17,12 +19,12 @@ class TestCoreIamc: small = SmallIamcDataset() filter = FilterIamcDataset() - def test_run_annual_datapoints_raw(self, platform: ixmp4.Platform): + def test_run_annual_datapoints_raw(self, platform: ixmp4.Platform) -> None: self.do_run_datapoints( platform, self.small.annual.copy(), True, DataPoint.Type.ANNUAL ) - def test_run_annual_datapoints_iamc(self, platform: ixmp4.Platform): + def test_run_annual_datapoints_iamc(self, platform: ixmp4.Platform) -> None: # convert to test data to standard IAMC format df = self.small.annual.copy().rename(columns={"step_year": "year"}) self.do_run_datapoints(platform, df, False) @@ -31,14 +33,14 @@ def test_run_annual_datapoints_iamc(self, platform: ixmp4.Platform): "invalid_type", (DataPoint.Type.CATEGORICAL, DataPoint.Type.DATETIME) ) def test_run_inconsistent_annual_raises( - self, platform: ixmp4.Platform, invalid_type - ): + self, platform: ixmp4.Platform, invalid_type: DataPoint.Type + ) -> None: with pytest.raises(SchemaError): self.do_run_datapoints( platform, self.small.annual.copy(), True, invalid_type ) - def test_run_categorical_datapoints_raw(self, platform: ixmp4.Platform): + def test_run_categorical_datapoints_raw(self, platform: ixmp4.Platform) -> None: self.do_run_datapoints( platform, self.small.categorical.copy(), True, DataPoint.Type.CATEGORICAL ) @@ -47,14 +49,14 @@ def test_run_categorical_datapoints_raw(self, platform: ixmp4.Platform): "invalid_type", (DataPoint.Type.ANNUAL, DataPoint.Type.DATETIME) ) def test_run_inconsistent_categorical_raises( - self, platform: ixmp4.Platform, invalid_type - ): + self, platform: ixmp4.Platform, invalid_type: DataPoint.Type + ) -> None: with pytest.raises(SchemaError): self.do_run_datapoints( platform, self.small.categorical.copy(), True, invalid_type ) - def test_run_datetime_datapoints_raw(self, platform: ixmp4.Platform): + def test_run_datetime_datapoints_raw(self, platform: ixmp4.Platform) -> None: self.do_run_datapoints( platform, self.small.datetime.copy(), True, DataPoint.Type.DATETIME ) @@ -63,20 +65,26 @@ def test_run_datetime_datapoints_raw(self, platform: ixmp4.Platform): "invalid_type", (DataPoint.Type.ANNUAL, DataPoint.Type.CATEGORICAL) ) def test_run_inconsistent_datetime_type_raises( - self, platform: ixmp4.Platform, invalid_type - ): + self, platform: ixmp4.Platform, invalid_type: DataPoint.Type + ) -> None: with pytest.raises(SchemaError): self.do_run_datapoints( platform, self.small.datetime.copy(), True, invalid_type ) - def test_unit_dimensionless_raw(self, platform: ixmp4.Platform): + def test_unit_dimensionless_raw(self, platform: ixmp4.Platform) -> None: test_data = self.small.annual.copy() test_data.loc[0, "unit"] = "" platform.units.create("") self.do_run_datapoints(platform, test_data, True, DataPoint.Type.ANNUAL) - def do_run_datapoints(self, platform: ixmp4.Platform, data, raw=True, _type=None): + def do_run_datapoints( + self, + platform: ixmp4.Platform, + data: pd.DataFrame, + raw: bool = True, + _type: DataPoint.Type | None = None, + ) -> None: # Test adding, updating, removing data to a run # either as ixmp4-database format (columns `step_[year/datetime/categorical]`) # or as standard iamc format (column names 'year' or 'time') @@ -203,24 +211,28 @@ def do_run_datapoints(self, platform: ixmp4.Platform, data, raw=True, _type=None ), ) def test_run_tabulate_with_filter_raw( - self, platform: ixmp4.Platform, filters, run, exp_len - ): + self, + platform: ixmp4.Platform, + filters: dict[str, dict[str, str | Iterable[str]]], + run: tuple[str, str, int], + exp_len: int, + ) -> None: self.filter.load_dataset(platform) - run = platform.runs.get(*run) - obs = run.iamc.tabulate(raw=True, **filters) + _run = platform.runs.get(*run) + obs = _run.iamc.tabulate(raw=True, **filters) assert len(obs) == exp_len class TestCoreIamcReadOnly: - def test_mp_tabulate_big_async(self, platform_med: ixmp4.Platform): + def test_mp_tabulate_big_async(self, platform_med: ixmp4.Platform) -> None: """Tests if big tabulations work in async contexts.""" - async def tabulate(): + async def tabulate() -> pd.DataFrame: return platform_med.iamc.tabulate(raw=True, run={"default_only": False}) res = asyncio.run(tabulate()) assert len(res) > settings.default_page_size - def test_mp_tabulate_big(self, platform_med: ixmp4.Platform): + def test_mp_tabulate_big(self, platform_med: ixmp4.Platform) -> None: res = platform_med.iamc.tabulate(raw=True, run={"default_only": False}) assert len(res) > settings.default_page_size diff --git a/tests/core/test_meta.py b/tests/core/test_meta.py index da038a6a..7eaf9226 100644 --- a/tests/core/test_meta.py +++ b/tests/core/test_meta.py @@ -8,12 +8,14 @@ EXP_META_COLS = ["model", "scenario", "version", "key", "value"] -def test_run_meta(platform: ixmp4.Platform): +def test_run_meta(platform: ixmp4.Platform) -> None: run1 = platform.runs.create("Model 1", "Scenario 1") run1.set_as_default() # set and update different types of meta indicators - run1.meta = {"mint": 13, "mfloat": 0.0, "mstr": "foo"} + # NOTE mypy doesn't support setters taking a different type than their property + # https://github.com/python/mypy/issues/3004 + run1.meta = {"mint": 13, "mfloat": 0.0, "mstr": "foo"} # type: ignore[assignment] run1.meta["mfloat"] = -1.9 run2 = platform.runs.get("Model 1", "Scenario 1") @@ -34,7 +36,7 @@ def test_run_meta(platform: ixmp4.Platform): pdt.assert_frame_equal(platform.meta.tabulate(run_id=1), exp) # remove all meta indicators and set a new indicator - run1.meta = {"mnew": "bar"} + run1.meta = {"mnew": "bar"} # type: ignore[assignment] run2 = platform.runs.get("Model 1", "Scenario 1") @@ -60,8 +62,8 @@ def test_run_meta(platform: ixmp4.Platform): pdt.assert_frame_equal(platform.meta.tabulate(run_id=1), exp, check_dtype=False) run2 = platform.runs.create("Model 2", "Scenario 2") - run1.meta = {"mstr": "baz"} - run2.meta = {"mfloat": 3.1415926535897} + run1.meta = {"mstr": "baz"} # type: ignore[assignment] + run2.meta = {"mfloat": 3.1415926535897} # type: ignore[assignment] # test default_only run filter exp = pd.DataFrame( @@ -94,7 +96,7 @@ def test_run_meta(platform: ixmp4.Platform): ) # test filter by key - run1.meta = {"mstr": "baz", "mfloat": 3.1415926535897} + run1.meta = {"mstr": "baz", "mfloat": 3.1415926535897} # type: ignore[assignment] exp = pd.DataFrame( [["Model 1", "Scenario 1", 1, "mstr", "baz"]], columns=EXP_META_COLS ) @@ -109,18 +111,22 @@ def test_run_meta(platform: ixmp4.Platform): ], ) def test_run_meta_numpy( - platform: ixmp4.Platform, npvalue1, pyvalue1, npvalue2, pyvalue2 -): + platform: ixmp4.Platform, + npvalue1: np.int64 | np.float64, + pyvalue1: int | float, + npvalue2: np.int64 | np.float64, + pyvalue2: int | float, +) -> None: """Test that numpy types are cast to simple types""" run1 = platform.runs.create("Model", "Scenario") run1.set_as_default() # set multiple meta indicators of same type ("value"-column of numpy-type) - run1.meta = {"key": npvalue1, "other key": npvalue1} + run1.meta = {"key": npvalue1, "other key": npvalue1} # type: ignore[assignment] assert run1.meta["key"] == pyvalue1 # set meta indicators of different types ("value"-column of type `object`) - run1.meta = {"key": npvalue1, "other key": "some value"} + run1.meta = {"key": npvalue1, "other key": "some value"} # type: ignore[assignment] assert run1.meta["key"] == pyvalue1 # set meta via setter @@ -133,13 +139,13 @@ def test_run_meta_numpy( @pytest.mark.parametrize("nonevalue", (None, np.nan)) -def test_run_meta_none(platform, nonevalue): +def test_run_meta_none(platform: ixmp4.Platform, nonevalue: float | None) -> None: """Test that None-values are handled correctly""" run1 = platform.runs.create("Model", "Scenario") run1.set_as_default() # set multiple indicators where one value is None - run1.meta = {"mint": 13, "mnone": nonevalue} + run1.meta = {"mint": 13, "mnone": nonevalue} # type: ignore[assignment] assert run1.meta["mint"] == 13 with pytest.raises(KeyError, match="'mnone'"): run1.meta["mnone"] @@ -154,7 +160,7 @@ def test_run_meta_none(platform, nonevalue): assert not dict(platform.runs.get("Model", "Scenario").meta) -def test_platform_meta_empty(platform: ixmp4.Platform): +def test_platform_meta_empty(platform: ixmp4.Platform) -> None: """Test that an empty dataframe is returned if there are no scenarios""" exp = pd.DataFrame([], columns=["model", "scenario", "version", "key", "value"]) pdt.assert_frame_equal(platform.meta.tabulate(), exp) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 7fcb6407..fefd61e0 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -1,3 +1,5 @@ +from collections.abc import Iterable + import pandas as pd import pytest @@ -7,13 +9,13 @@ from ..utils import assert_unordered_equality -def create_testcase_models(test_mp): +def create_testcase_models(test_mp: ixmp4.Platform) -> tuple[Model, Model]: model = test_mp.models.create("Model") model2 = test_mp.models.create("Model 2") return model, model2 -def df_from_list(models): +def df_from_list(models: Iterable[Model]) -> pd.DataFrame: return pd.DataFrame( [[m.id, m.name, m.created_at, m.created_by] for m in models], columns=["id", "name", "created_at", "created_by"], @@ -21,19 +23,19 @@ def df_from_list(models): class TestCoreModel: - def test_retrieve_model(self, platform: ixmp4.Platform): + def test_retrieve_model(self, platform: ixmp4.Platform) -> None: model1 = platform.models.create("Model") model2 = platform.models.get("Model") assert model1.id == model2.id - def test_model_unqiue(self, platform: ixmp4.Platform): + def test_model_unqiue(self, platform: ixmp4.Platform) -> None: platform.models.create("Model") with pytest.raises(Model.NotUnique): platform.models.create("Model") - def test_list_model(self, platform: ixmp4.Platform): + def test_list_model(self, platform: ixmp4.Platform) -> None: models = create_testcase_models(platform) model, _ = models @@ -45,7 +47,7 @@ def test_list_model(self, platform: ixmp4.Platform): b = [m.id for m in platform.models.list(name="Model")] assert not (set(a) ^ set(b)) - def test_tabulate_model(self, platform: ixmp4.Platform): + def test_tabulate_model(self, platform: ixmp4.Platform) -> None: models = create_testcase_models(platform) model, _ = models @@ -57,7 +59,7 @@ def test_tabulate_model(self, platform: ixmp4.Platform): b = platform.models.tabulate(name="Model") assert_unordered_equality(a, b, check_dtype=False) - def test_retrieve_docs(self, platform: ixmp4.Platform): + def test_retrieve_docs(self, platform: ixmp4.Platform) -> None: platform.models.create("Model") docs_model1 = platform.models.set_docs("Model", "Description of test Model") docs_model2 = platform.models.get_docs("Model") @@ -71,7 +73,7 @@ def test_retrieve_docs(self, platform: ixmp4.Platform): assert platform.models.get_docs("Model2") == model2.docs - def test_delete_docs(self, platform: ixmp4.Platform): + def test_delete_docs(self, platform: ixmp4.Platform) -> None: model = platform.models.create("Model") model.docs = "Description of test Model" model.docs = None @@ -83,7 +85,8 @@ def test_delete_docs(self, platform: ixmp4.Platform): assert model.docs is None - model.docs = "Third description of test Model" + # Mypy doesn't recognize del properly, it seems + model.docs = "Third description of test Model" # type: ignore[unreachable] platform.models.delete_docs("Model") assert model.docs is None diff --git a/tests/core/test_optimization_equation.py b/tests/core/test_optimization_equation.py index 67b2fc1c..b2ff381d 100644 --- a/tests/core/test_optimization_equation.py +++ b/tests/core/test_optimization_equation.py @@ -11,7 +11,7 @@ from ..utils import assert_unordered_equality, create_indexsets_for_run -def df_from_list(equations: list): +def df_from_list(equations: list[Equation]) -> pd.DataFrame: return pd.DataFrame( [ [ @@ -36,7 +36,7 @@ def df_from_list(equations: list): class TestCoreEquation: - def test_create_equation(self, platform: ixmp4.Platform): + def test_create_equation(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") # Test normal creation @@ -99,7 +99,7 @@ def test_create_equation(self, platform: ixmp4.Platform): assert equation_3.columns[0].dtype == "object" assert equation_3.columns[1].dtype == "int64" - def test_get_equation(self, platform: ixmp4.Platform): + def test_get_equation(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") (indexset,) = create_indexsets_for_run( platform=platform, run_id=run.id, amount=1 @@ -120,7 +120,7 @@ def test_get_equation(self, platform: ixmp4.Platform): with pytest.raises(Equation.NotFound): _ = run.optimization.equations.get("Equation 2") - def test_equation_add_data(self, platform: ixmp4.Platform): + def test_equation_add_data(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") indexset, indexset_2 = tuple( IndexSet(_backend=platform.backend, _model=model) @@ -248,7 +248,7 @@ def test_equation_add_data(self, platform: ixmp4.Platform): ) assert_unordered_equality(expected, pd.DataFrame(equation_4.data)) - def test_equation_remove_data(self, platform: ixmp4.Platform): + def test_equation_remove_data(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") indexset = run.optimization.indexsets.create("Indexset") indexset.add(data=["foo", "bar"]) @@ -267,7 +267,7 @@ def test_equation_remove_data(self, platform: ixmp4.Platform): equation.remove_data() assert equation.data == {} - def test_list_equation(self, platform: ixmp4.Platform): + def test_list_equation(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") # Per default, list() lists scalars for `default` version runs: run.set_as_default() @@ -299,7 +299,7 @@ def test_list_equation(self, platform: ixmp4.Platform): ] assert not (set(expected_id) ^ set(list_id)) - def test_tabulate_equation(self, platform: ixmp4.Platform): + def test_tabulate_equation(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") indexset, indexset_2 = tuple( IndexSet(_backend=platform.backend, _model=model) @@ -348,7 +348,7 @@ def test_tabulate_equation(self, platform: ixmp4.Platform): run.optimization.equations.tabulate(), ) - def test_equation_docs(self, platform: ixmp4.Platform): + def test_equation_docs(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") (indexset,) = tuple( IndexSet(_backend=platform.backend, _model=model) diff --git a/tests/core/test_optimization_indexset.py b/tests/core/test_optimization_indexset.py index a1487e01..a3a80a35 100644 --- a/tests/core/test_optimization_indexset.py +++ b/tests/core/test_optimization_indexset.py @@ -48,7 +48,7 @@ def df_from_list(indexsets: list[IndexSet], include_data: bool = False) -> pd.Da class TestCoreIndexset: - def test_create_indexset(self, platform: ixmp4.Platform): + def test_create_indexset(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") indexset_1 = run.optimization.indexsets.create("Indexset 1") assert indexset_1.id == 1 @@ -60,7 +60,7 @@ def test_create_indexset(self, platform: ixmp4.Platform): with pytest.raises(IndexSet.NotUnique): _ = run.optimization.indexsets.create("Indexset 1") - def test_get_indexset(self, platform: ixmp4.Platform): + def test_get_indexset(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") create_indexsets_for_run(platform=platform, run_id=run.id, amount=1) indexset = run.optimization.indexsets.get("Indexset 1") @@ -70,12 +70,12 @@ def test_get_indexset(self, platform: ixmp4.Platform): with pytest.raises(IndexSet.NotFound): _ = run.optimization.indexsets.get("Foo") - def test_add_data(self, platform: ixmp4.Platform): + def test_add_elements(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") test_data = ["foo", "bar"] indexset_1 = run.optimization.indexsets.create("Indexset 1") - indexset_1.add(test_data) # type: ignore - run.optimization.indexsets.create("Indexset 2").add(test_data) # type: ignore + indexset_1.add(test_data) + run.optimization.indexsets.create("Indexset 2").add(test_data) indexset_2 = run.optimization.indexsets.get("Indexset 2") assert indexset_1.data == indexset_2.data @@ -88,20 +88,20 @@ def test_add_data(self, platform: ixmp4.Platform): # Test data types are conserved indexset_3 = run.optimization.indexsets.create("Indexset 3") - test_data_2: list[float | int | str] = [1.2, 3.4, 5.6] + test_data_2 = [1.2, 3.4, 5.6] indexset_3.add(data=test_data_2) assert indexset_3.data == test_data_2 assert type(indexset_3.data[0]).__name__ == "float" indexset_4 = run.optimization.indexsets.create("Indexset 4") - test_data_3: list[float | int | str] = [0, 1, 2] + test_data_3 = [0, 1, 2] indexset_4.add(data=test_data_3) assert indexset_4.data == test_data_3 assert type(indexset_4.data[0]).__name__ == "int" - def test_list_indexsets(self, platform: ixmp4.Platform): + def test_list_indexsets(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") indexset_1, indexset_2 = create_indexsets_for_run( platform=platform, run_id=run.id @@ -122,7 +122,7 @@ def test_list_indexsets(self, platform: ixmp4.Platform): ] assert not (set(expected_id) ^ set(list_id)) - def test_tabulate_indexsets(self, platform: ixmp4.Platform): + def test_tabulate_indexsets(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") indexset_1, indexset_2 = tuple( IndexSet(_backend=platform.backend, _model=model) @@ -152,7 +152,7 @@ def test_tabulate_indexsets(self, platform: ixmp4.Platform): ), ) - def test_indexset_docs(self, platform: ixmp4.Platform): + def test_indexset_docs(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") (indexset_1,) = tuple( IndexSet(_backend=platform.backend, _model=model) diff --git a/tests/core/test_optimization_parameter.py b/tests/core/test_optimization_parameter.py index f6baed9d..4daf49b0 100644 --- a/tests/core/test_optimization_parameter.py +++ b/tests/core/test_optimization_parameter.py @@ -11,7 +11,7 @@ from ..utils import assert_unordered_equality, create_indexsets_for_run -def df_from_list(parameters: list): +def df_from_list(parameters: list[Parameter]) -> pd.DataFrame: return pd.DataFrame( [ [ @@ -36,7 +36,7 @@ def df_from_list(parameters: list): class TestCoreParameter: - def test_create_parameter(self, platform: ixmp4.Platform): + def test_create_parameter(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") # Test normal creation @@ -99,7 +99,7 @@ def test_create_parameter(self, platform: ixmp4.Platform): assert parameter_3.columns[0].dtype == "object" assert parameter_3.columns[1].dtype == "int64" - def test_get_parameter(self, platform: ixmp4.Platform): + def test_get_parameter(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") (indexset,) = create_indexsets_for_run( platform=platform, run_id=run.id, amount=1 @@ -120,7 +120,7 @@ def test_get_parameter(self, platform: ixmp4.Platform): with pytest.raises(Parameter.NotFound): _ = run.optimization.parameters.get("Parameter 2") - def test_parameter_add_data(self, platform: ixmp4.Platform): + def test_parameter_add_data(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") unit = platform.units.create("Unit") indexset, indexset_2 = tuple( @@ -250,7 +250,7 @@ def test_parameter_add_data(self, platform: ixmp4.Platform): ) assert_unordered_equality(expected, pd.DataFrame(parameter_4.data)) - def test_list_parameter(self, platform: ixmp4.Platform): + def test_list_parameter(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") create_indexsets_for_run(platform=platform, run_id=run.id) parameter = run.optimization.parameters.create( @@ -279,7 +279,7 @@ def test_list_parameter(self, platform: ixmp4.Platform): ] assert not (set(expected_id) ^ set(list_id)) - def test_tabulate_parameter(self, platform: ixmp4.Platform): + def test_tabulate_parameter(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") indexset, indexset_2 = tuple( IndexSet(_backend=platform.backend, _model=model) @@ -330,7 +330,7 @@ def test_tabulate_parameter(self, platform: ixmp4.Platform): run.optimization.parameters.tabulate(), ) - def test_parameter_docs(self, platform: ixmp4.Platform): + def test_parameter_docs(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") (indexset,) = create_indexsets_for_run( platform=platform, run_id=run.id, amount=1 diff --git a/tests/core/test_optimization_scalar.py b/tests/core/test_optimization_scalar.py index 470aa078..0a04e51e 100644 --- a/tests/core/test_optimization_scalar.py +++ b/tests/core/test_optimization_scalar.py @@ -7,7 +7,7 @@ from ..utils import assert_unordered_equality -def df_from_list(scalars: list[Scalar]): +def df_from_list(scalars: list[Scalar]) -> pd.DataFrame: return pd.DataFrame( [ [ @@ -34,7 +34,7 @@ def df_from_list(scalars: list[Scalar]): class TestCoreScalar: - def test_create_scalar(self, platform: ixmp4.Platform): + def test_create_scalar(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") unit = platform.units.create("Test Unit") scalar_1 = run.optimization.scalars.create( @@ -53,7 +53,8 @@ def test_create_scalar(self, platform: ixmp4.Platform): ) with pytest.raises(TypeError): - _ = run.optimization.scalars.create("Scalar 2") # type: ignore + # Testing a missing parameter on purpose + _ = run.optimization.scalars.create("Scalar 2") # type: ignore[call-arg] scalar_2 = run.optimization.scalars.create("Scalar 2", value=20, unit=unit) assert scalar_1.id != scalar_2.id @@ -61,7 +62,7 @@ def test_create_scalar(self, platform: ixmp4.Platform): scalar_3 = run.optimization.scalars.create("Scalar 3", value=1) assert scalar_3.unit.name == "" - def test_get_scalar(self, platform: ixmp4.Platform): + def test_get_scalar(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") unit = platform.units.create("Test Unit") scalar = run.optimization.scalars.create("Scalar", value=10, unit=unit.name) @@ -74,7 +75,7 @@ def test_get_scalar(self, platform: ixmp4.Platform): with pytest.raises(Scalar.NotFound): _ = run.optimization.scalars.get("Foo") - def test_update_scalar(self, platform: ixmp4.Platform): + def test_update_scalar(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") unit = platform.units.create("Test Unit") unit2 = platform.units.create("Test Unit 2") @@ -86,7 +87,9 @@ def test_update_scalar(self, platform: ixmp4.Platform): _ = run.optimization.scalars.create("Scalar", value=20, unit=unit2.name) scalar.value = 30 - scalar.unit = "Test Unit" + # NOTE mypy doesn't support setters taking a different type than their property + # https://github.com/python/mypy/issues/3004 + scalar.unit = "Test Unit" # type: ignore[assignment] # NOTE: doesn't work for some reason (but doesn't either for e.g. model.get()) # assert scalar == run.optimization.scalars.get("Scalar") result = run.optimization.scalars.get("Scalar") @@ -94,9 +97,9 @@ def test_update_scalar(self, platform: ixmp4.Platform): assert scalar.id == result.id assert scalar.name == result.name assert scalar.value == result.value == 30 - assert scalar.unit.id == result.unit.id == 1 # type: ignore + assert scalar.unit.id == result.unit.id == 1 - def test_list_scalars(self, platform: ixmp4.Platform): + def test_list_scalars(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") unit = platform.units.create("Test Unit") scalar_1 = run.optimization.scalars.create( @@ -119,7 +122,7 @@ def test_list_scalars(self, platform: ixmp4.Platform): ] assert not (set(expected_id) ^ set(list_id)) - def test_tabulate_scalars(self, platform: ixmp4.Platform): + def test_tabulate_scalars(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") unit = platform.units.create("Test Unit") scalar_1 = run.optimization.scalars.create("Scalar 1", value=1, unit=unit.name) @@ -137,7 +140,7 @@ def test_tabulate_scalars(self, platform: ixmp4.Platform): result = run.optimization.scalars.tabulate(name="Scalar 2") assert_unordered_equality(expected, result, check_dtype=False) - def test_scalar_docs(self, platform: ixmp4.Platform): + def test_scalar_docs(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") unit = platform.units.create("Test Unit") scalar = run.optimization.scalars.create("Scalar 1", value=4, unit=unit.name) diff --git a/tests/core/test_optimization_table.py b/tests/core/test_optimization_table.py index 93f0c03e..abe14b6c 100644 --- a/tests/core/test_optimization_table.py +++ b/tests/core/test_optimization_table.py @@ -11,7 +11,7 @@ from ..utils import create_indexsets_for_run -def df_from_list(tables: list[Table]): +def df_from_list(tables: list[Table]) -> pd.DataFrame: return pd.DataFrame( # Order is important here to avoid utils.assert_unordered_equality, # which doesn't like lists @@ -38,12 +38,12 @@ def df_from_list(tables: list[Table]): class TestCoreTable: - def test_create_table(self, platform: ixmp4.Platform): + def test_create_table(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") # Test normal creation indexset, indexset_2 = tuple( - IndexSet(_backend=platform.backend, _model=model) # type: ignore + IndexSet(_backend=platform.backend, _model=model) for model in create_indexsets_for_run(platform=platform, run_id=run.id) ) table = run.optimization.tables.create( @@ -99,7 +99,7 @@ def test_create_table(self, platform: ixmp4.Platform): assert table_3.columns[0].dtype == "object" assert table_3.columns[1].dtype == "int64" - def test_get_table(self, platform: ixmp4.Platform): + def test_get_table(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") (indexset,) = create_indexsets_for_run( platform=platform, run_id=run.id, amount=1 @@ -118,10 +118,10 @@ def test_get_table(self, platform: ixmp4.Platform): with pytest.raises(Table.NotFound): _ = run.optimization.tables.get(name="Table 2") - def test_table_add_data(self, platform: ixmp4.Platform): + def test_table_add_data(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") indexset, indexset_2 = tuple( - IndexSet(_backend=platform.backend, _model=model) # type: ignore + IndexSet(_backend=platform.backend, _model=model) for model in create_indexsets_for_run(platform=platform, run_id=run.id) ) indexset.add(data=["foo", "bar", ""]) @@ -238,7 +238,7 @@ def test_table_add_data(self, platform: ixmp4.Platform): table_5.add(data={}) assert table_5.data == test_data_5 - def test_list_tables(self, platform: ixmp4.Platform): + def test_list_tables(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") create_indexsets_for_run(platform=platform, run_id=run.id) table = run.optimization.tables.create( @@ -263,10 +263,10 @@ def test_list_tables(self, platform: ixmp4.Platform): list_id = [table.id for table in run.optimization.tables.list(name="Table")] assert not (set(expected_id) ^ set(list_id)) - def test_tabulate_table(self, platform: ixmp4.Platform): + def test_tabulate_table(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") indexset, indexset_2 = tuple( - IndexSet(_backend=platform.backend, _model=model) # type: ignore + IndexSet(_backend=platform.backend, _model=model) for model in create_indexsets_for_run(platform=platform, run_id=run.id) ) table = run.optimization.tables.create( @@ -300,7 +300,7 @@ def test_tabulate_table(self, platform: ixmp4.Platform): run.optimization.tables.tabulate(), ) - def test_table_docs(self, platform: ixmp4.Platform): + def test_table_docs(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") (indexset,) = create_indexsets_for_run( platform=platform, run_id=run.id, amount=1 diff --git a/tests/core/test_optimization_variable.py b/tests/core/test_optimization_variable.py index 560b4783..183e65f5 100644 --- a/tests/core/test_optimization_variable.py +++ b/tests/core/test_optimization_variable.py @@ -11,7 +11,7 @@ from ..utils import assert_unordered_equality, create_indexsets_for_run -def df_from_list(variables: list): +def df_from_list(variables: list[OptimizationVariable]) -> pd.DataFrame: return pd.DataFrame( [ [ @@ -36,7 +36,7 @@ def df_from_list(variables: list): class TestCoreVariable: - def test_create_variable(self, platform: ixmp4.Platform): + def test_create_variable(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") # Test creation without indexset @@ -123,7 +123,7 @@ def test_create_variable(self, platform: ixmp4.Platform): assert variable_4.columns[0].dtype == "object" assert variable_4.columns[1].dtype == "int64" - def test_get_variable(self, platform: ixmp4.Platform): + def test_get_variable(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") (indexset,) = create_indexsets_for_run( platform=platform, run_id=run.id, amount=1 @@ -145,7 +145,7 @@ def test_get_variable(self, platform: ixmp4.Platform): with pytest.raises(OptimizationVariable.NotFound): _ = run.optimization.variables.get("Variable 2") - def test_variable_add_data(self, platform: ixmp4.Platform): + def test_variable_add_data(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") indexset, indexset_2 = tuple( IndexSet(_backend=platform.backend, _model=model) @@ -273,7 +273,7 @@ def test_variable_add_data(self, platform: ixmp4.Platform): ) assert_unordered_equality(expected, pd.DataFrame(variable_4.data)) - def test_variable_remove_data(self, platform: ixmp4.Platform): + def test_variable_remove_data(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") indexset = run.optimization.indexsets.create("Indexset") indexset.add(data=["foo", "bar"]) @@ -292,7 +292,7 @@ def test_variable_remove_data(self, platform: ixmp4.Platform): variable.remove_data() assert variable.data == {} - def test_list_variable(self, platform: ixmp4.Platform): + def test_list_variable(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") indexset, indexset_2 = create_indexsets_for_run( platform=platform, run_id=run.id @@ -322,7 +322,7 @@ def test_list_variable(self, platform: ixmp4.Platform): ] assert not (set(expected_id) ^ set(list_id)) - def test_tabulate_variable(self, platform: ixmp4.Platform): + def test_tabulate_variable(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") indexset, indexset_2 = tuple( IndexSet(_backend=platform.backend, _model=model) @@ -371,7 +371,7 @@ def test_tabulate_variable(self, platform: ixmp4.Platform): run.optimization.variables.tabulate(), ) - def test_variable_docs(self, platform: ixmp4.Platform): + def test_variable_docs(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") (indexset,) = create_indexsets_for_run( platform=platform, run_id=run.id, amount=1 diff --git a/tests/core/test_region.py b/tests/core/test_region.py index beb20b27..148c3888 100644 --- a/tests/core/test_region.py +++ b/tests/core/test_region.py @@ -1,3 +1,5 @@ +from collections.abc import Iterable + import pandas as pd import pytest @@ -9,13 +11,13 @@ from ..utils import assert_unordered_equality -def create_testcase_regions(platform): +def create_testcase_regions(platform: ixmp4.Platform) -> tuple[Region, Region]: reg = platform.regions.create("Test", hierarchy="default") other = platform.regions.create("Test Other", hierarchy="other") return reg, other -def df_from_list(regions): +def df_from_list(regions: Iterable[Region]) -> pd.DataFrame: return pd.DataFrame( [[r.id, r.name, r.hierarchy, r.created_at, r.created_by] for r in regions], columns=["id", "name", "hierarchy", "created_at", "created_by"], @@ -25,7 +27,7 @@ def df_from_list(regions): class TestCoreRegion: small = SmallIamcDataset - def test_delete_region(self, platform: ixmp4.Platform): + def test_delete_region(self, platform: ixmp4.Platform) -> None: reg1 = platform.regions.create("Test 1", hierarchy="default") reg2 = platform.regions.create("Test 2", hierarchy="default") reg3 = platform.regions.create("Test 3", hierarchy="default") @@ -48,16 +50,17 @@ def test_delete_region(self, platform: ixmp4.Platform): with pytest.raises(Region.DeletionPrevented): platform.regions.delete("Region 1") - def test_region_has_hierarchy(self, platform: ixmp4.Platform): + def test_region_has_hierarchy(self, platform: ixmp4.Platform) -> None: with pytest.raises(TypeError): - platform.regions.create("Test Region") # type:ignore + # We are testing exactly this: raising with missing argument. + platform.regions.create("Test Region") # type: ignore[call-arg] reg1 = platform.regions.create("Test", hierarchy="default") reg2 = platform.regions.get("Test") assert reg1.id == reg2.id - def test_get_region(self, platform: ixmp4.Platform): + def test_get_region(self, platform: ixmp4.Platform) -> None: reg1 = platform.regions.create("Test", hierarchy="default") reg2 = platform.regions.get("Test") @@ -66,13 +69,13 @@ def test_get_region(self, platform: ixmp4.Platform): with pytest.raises(Region.NotFound): platform.regions.get("Does not exist") - def test_region_unique(self, platform: ixmp4.Platform): + def test_region_unique(self, platform: ixmp4.Platform) -> None: platform.regions.create("Test", hierarchy="default") with pytest.raises(Region.NotUnique): platform.regions.create("Test", hierarchy="other") - def test_region_unknown(self, platform): + def test_region_unknown(self, platform: ixmp4.Platform) -> None: self.small.load_regions(platform) self.small.load_units(platform) @@ -83,7 +86,7 @@ def test_region_unknown(self, platform): with pytest.raises(Region.NotFound): run.iamc.add(invalid_data, type=DataPoint.Type.ANNUAL) - def test_list_region(self, platform: ixmp4.Platform): + def test_list_region(self, platform: ixmp4.Platform) -> None: regions = create_testcase_regions(platform) reg, other = regions @@ -95,7 +98,7 @@ def test_list_region(self, platform: ixmp4.Platform): b = [r.id for r in platform.regions.list(hierarchy="other")] assert not (set(a) ^ set(b)) - def test_tabulate_region(self, platform: ixmp4.Platform): + def test_tabulate_region(self, platform: ixmp4.Platform) -> None: regions = create_testcase_regions(platform) _, other = regions @@ -107,7 +110,7 @@ def test_tabulate_region(self, platform: ixmp4.Platform): b = platform.regions.tabulate(hierarchy="other") assert_unordered_equality(a, b, check_dtype=False) - def test_retrieve_docs(self, platform: ixmp4.Platform): + def test_retrieve_docs(self, platform: ixmp4.Platform) -> None: platform.regions.create("Test Region", "Test Hierarchy") docs_region1 = platform.regions.set_docs( "Test Region", "Description of test Region" @@ -124,7 +127,7 @@ def test_retrieve_docs(self, platform: ixmp4.Platform): assert platform.regions.get_docs("Test Region 2") == region2.docs - def test_delete_docs(self, platform: ixmp4.Platform): + def test_delete_docs(self, platform: ixmp4.Platform) -> None: region = platform.regions.create("Test Region", "Hierarchy") region.docs = "Description of test region" region.docs = None @@ -136,7 +139,8 @@ def test_delete_docs(self, platform: ixmp4.Platform): assert region.docs is None - region.docs = "Third description of test region" + # Mypy doesn't recognize del properly, it seems + region.docs = "Third description of test region" # type: ignore[unreachable] platform.regions.delete_docs("Test Region") assert region.docs is None diff --git a/tests/core/test_run.py b/tests/core/test_run.py index b7cf0f3c..e75aa4f5 100644 --- a/tests/core/test_run.py +++ b/tests/core/test_run.py @@ -2,6 +2,9 @@ import pandas.testing as pdt import pytest +# Import this from typing when dropping 3.11 +from typing_extensions import Unpack + import ixmp4 from ixmp4.core import Run from ixmp4.core.exceptions import IxmpError @@ -9,7 +12,7 @@ from ..fixtures import FilterIamcDataset -def _expected_runs_table(*row_default): +def _expected_runs_table(*row_default: Unpack[tuple[bool | None, ...]]) -> pd.DataFrame: rows = [] for i, default in enumerate(row_default, start=1): if default is not None: @@ -21,12 +24,12 @@ def _expected_runs_table(*row_default): class TestCoreRun: filter = FilterIamcDataset() - def test_run_notfound(self, platform: ixmp4.Platform): + def test_run_notfound(self, platform: ixmp4.Platform) -> None: # no Run with that model and scenario name exists with pytest.raises(Run.NotFound): _ = platform.runs.get("Unknown Model", "Unknown Scenario", version=1) - def test_run_versions(self, platform: ixmp4.Platform): + def test_run_versions(self, platform: ixmp4.Platform) -> None: run1 = platform.runs.create("Model", "Scenario") run2 = platform.runs.create("Model", "Scenario") @@ -159,7 +162,7 @@ def test_run_versions(self, platform: ixmp4.Platform): assert sorted(res["model"].tolist()) == [] assert sorted(res["scenario"].tolist()) == [] - def delete_all_datapoints(self, run: ixmp4.Run): + def delete_all_datapoints(self, run: ixmp4.Run) -> None: remove_data = run.iamc.tabulate(raw=True) annual = remove_data[remove_data["type"] == "ANNUAL"].dropna( how="all", axis="columns" diff --git a/tests/core/test_scenario.py b/tests/core/test_scenario.py index f526ed36..609c4d8c 100644 --- a/tests/core/test_scenario.py +++ b/tests/core/test_scenario.py @@ -1,3 +1,5 @@ +from collections.abc import Iterable + import pandas as pd import pytest @@ -7,13 +9,13 @@ from ..utils import assert_unordered_equality -def create_testcase_scenarios(platform): +def create_testcase_scenarios(platform: ixmp4.Platform) -> tuple[Scenario, Scenario]: scenario = platform.scenarios.create("Scenario") scenario2 = platform.scenarios.create("Scenario 2") return scenario, scenario2 -def df_from_list(scenarios): +def df_from_list(scenarios: Iterable[Scenario]) -> pd.DataFrame: return pd.DataFrame( [[s.id, s.name, s.created_at, s.created_by] for s in scenarios], columns=["id", "name", "created_at", "created_by"], @@ -21,19 +23,19 @@ def df_from_list(scenarios): class TestCoreScenario: - def test_retrieve_scenario(self, platform: ixmp4.Platform): + def test_retrieve_scenario(self, platform: ixmp4.Platform) -> None: scenario1 = platform.scenarios.create("Scenario") scenario2 = platform.scenarios.get("Scenario") assert scenario1.id == scenario2.id - def test_scenario_unqiue(self, platform: ixmp4.Platform): + def test_scenario_unqiue(self, platform: ixmp4.Platform) -> None: platform.scenarios.create("Scenario") with pytest.raises(Scenario.NotUnique): platform.scenarios.create("Scenario") - def test_list_scenario(self, platform: ixmp4.Platform): + def test_list_scenario(self, platform: ixmp4.Platform) -> None: scenarios = create_testcase_scenarios(platform) scenario, _ = scenarios @@ -45,7 +47,7 @@ def test_list_scenario(self, platform: ixmp4.Platform): b = [s.id for s in platform.scenarios.list(name="Scenario")] assert not (set(a) ^ set(b)) - def test_tabulate_scenario(self, platform: ixmp4.Platform): + def test_tabulate_scenario(self, platform: ixmp4.Platform) -> None: scenarios = create_testcase_scenarios(platform) scenario, _ = scenarios @@ -57,7 +59,7 @@ def test_tabulate_scenario(self, platform: ixmp4.Platform): b = platform.scenarios.tabulate(name="Scenario") assert_unordered_equality(a, b, check_dtype=False) - def test_retrieve_docs(self, platform: ixmp4.Platform): + def test_retrieve_docs(self, platform: ixmp4.Platform) -> None: platform.scenarios.create("Scenario") docs_scenario1 = platform.scenarios.set_docs( "Scenario", "Description of test Scenario" @@ -74,7 +76,7 @@ def test_retrieve_docs(self, platform: ixmp4.Platform): assert platform.scenarios.get_docs("Scenario2") == scenario2.docs - def test_delete_docs(self, platform: ixmp4.Platform): + def test_delete_docs(self, platform: ixmp4.Platform) -> None: scenario = platform.scenarios.create("Scenario") scenario.docs = "Description of test Scenario" scenario.docs = None @@ -86,7 +88,8 @@ def test_delete_docs(self, platform: ixmp4.Platform): assert scenario.docs is None - scenario.docs = "Third description of test Scenario" + # Mypy doesn't recognize del properly, it seems + scenario.docs = "Third description of test Scenario" # type: ignore[unreachable] platform.scenarios.delete_docs("Scenario") assert scenario.docs is None diff --git a/tests/core/test_unit.py b/tests/core/test_unit.py index dac74dbc..5204eb2a 100644 --- a/tests/core/test_unit.py +++ b/tests/core/test_unit.py @@ -1,3 +1,5 @@ +from collections.abc import Iterable + import pandas as pd import pytest @@ -8,13 +10,13 @@ from ..utils import assert_unordered_equality -def create_testcase_units(platform: ixmp4.Platform): +def create_testcase_units(platform: ixmp4.Platform) -> tuple[Unit, Unit]: unit = platform.units.create("Test") unit2 = platform.units.create("Test 2") return unit, unit2 -def df_from_list(units): +def df_from_list(units: Iterable[Unit]) -> pd.DataFrame: return pd.DataFrame( [[u.id, u.name, u.created_at, u.created_by] for u in units], columns=["id", "name", "created_at", "created_by"], @@ -24,7 +26,7 @@ def df_from_list(units): class TestCoreUnit: small = SmallIamcDataset() - def test_delete_unit(self, platform: ixmp4.Platform): + def test_delete_unit(self, platform: ixmp4.Platform) -> None: unit1 = platform.units.create("Test 1") unit2 = platform.units.create("Test 2") unit3 = platform.units.create("Test 3") @@ -47,19 +49,19 @@ def test_delete_unit(self, platform: ixmp4.Platform): with pytest.raises(Unit.DeletionPrevented): platform.units.delete("Unit 1") - def test_retrieve_unit(self, platform: ixmp4.Platform): + def test_retrieve_unit(self, platform: ixmp4.Platform) -> None: unit1 = platform.units.create("Test") unit2 = platform.units.get("Test") assert unit1.id == unit2.id - def test_unit_unqiue(self, platform: ixmp4.Platform): + def test_unit_unqiue(self, platform: ixmp4.Platform) -> None: platform.units.create("Test") with pytest.raises(Unit.NotUnique): platform.units.create("Test") - def test_unit_dimensionless(self, platform: ixmp4.Platform): + def test_unit_dimensionless(self, platform: ixmp4.Platform) -> None: unit1 = platform.units.create("") unit2 = platform.units.get("") @@ -68,7 +70,7 @@ def test_unit_dimensionless(self, platform: ixmp4.Platform): assert "" in platform.units.tabulate().values assert "" in [unit.name for unit in platform.units.list()] - def test_unit_illegal_names(self, platform: ixmp4.Platform): + def test_unit_illegal_names(self, platform: ixmp4.Platform) -> None: with pytest.raises(ValueError, match="Unit name 'dimensionless' is reserved,"): platform.units.create("dimensionless") @@ -77,7 +79,7 @@ def test_unit_illegal_names(self, platform: ixmp4.Platform): ): platform.units.create(" ") - def test_unit_unknown(self, platform: ixmp4.Platform): + def test_unit_unknown(self, platform: ixmp4.Platform) -> None: self.small.load_regions(platform) self.small.load_units(platform) @@ -88,7 +90,7 @@ def test_unit_unknown(self, platform: ixmp4.Platform): with pytest.raises(Unit.NotFound): run.iamc.add(invalid_data, type=ixmp4.DataPoint.Type.ANNUAL) - def test_list_unit(self, platform: ixmp4.Platform): + def test_list_unit(self, platform: ixmp4.Platform) -> None: units = create_testcase_units(platform) unit, _ = units @@ -100,7 +102,7 @@ def test_list_unit(self, platform: ixmp4.Platform): b = [u.id for u in platform.units.list(name="Test")] assert not (set(a) ^ set(b)) - def test_tabulate_unit(self, platform: ixmp4.Platform): + def test_tabulate_unit(self, platform: ixmp4.Platform) -> None: units = create_testcase_units(platform) unit, _ = units @@ -112,7 +114,7 @@ def test_tabulate_unit(self, platform: ixmp4.Platform): b = platform.units.tabulate(name="Test") assert_unordered_equality(a, b, check_dtype=False) - def test_retrieve_docs(self, platform: ixmp4.Platform): + def test_retrieve_docs(self, platform: ixmp4.Platform) -> None: platform.units.create("Unit") docs_unit1 = platform.units.set_docs("Unit", "Description of test Unit") docs_unit2 = platform.units.get_docs("Unit") @@ -127,7 +129,7 @@ def test_retrieve_docs(self, platform: ixmp4.Platform): assert platform.units.get_docs("Unit2") == unit2.docs - def test_delete_docs(self, platform: ixmp4.Platform): + def test_delete_docs(self, platform: ixmp4.Platform) -> None: unit = platform.units.create("Unit") unit.docs = "Description of test Unit" unit.docs = None @@ -139,7 +141,8 @@ def test_delete_docs(self, platform: ixmp4.Platform): assert unit.docs is None - unit.docs = "Third description of test Unit" + # Mypy doesn't recognize del properly, it seems + unit.docs = "Third description of test Unit" # type: ignore[unreachable] platform.units.delete_docs("Unit") assert unit.docs is None diff --git a/tests/core/test_variable.py b/tests/core/test_variable.py index a6bf9dd6..c9624e78 100644 --- a/tests/core/test_variable.py +++ b/tests/core/test_variable.py @@ -1,4 +1,4 @@ -# Variable tests are currently disabled +from collections.abc import Iterable import pandas as pd import pytest @@ -9,7 +9,9 @@ from ..utils import assert_unordered_equality -def create_testcase_iamc_variables(platform): +def create_testcase_iamc_variables( + platform: ixmp4.Platform, +) -> tuple[Variable, Variable]: platform.regions.create("Region", "default") platform.units.create("Unit") @@ -30,7 +32,7 @@ def create_testcase_iamc_variables(platform): return iamc_variable, iamc_variable2 -def df_from_list(iamc_variables): +def df_from_list(iamc_variables: Iterable[Variable]) -> pd.DataFrame: return pd.DataFrame( [[v.id, v.name, v.created_at, v.created_by] for v in iamc_variables], columns=["id", "name", "created_at", "created_by"], @@ -38,7 +40,7 @@ def df_from_list(iamc_variables): class TestCoreVariable: - def test_retrieve_iamc_variable(self, platform: ixmp4.Platform): + def test_retrieve_iamc_variable(self, platform: ixmp4.Platform) -> None: iamc_variable1 = platform.iamc.variables.create("IAMC Variable") platform.regions.create("Region", "default") platform.units.create("Unit") @@ -57,13 +59,13 @@ def test_retrieve_iamc_variable(self, platform: ixmp4.Platform): assert iamc_variable1.id == iamc_variable2.id - def test_iamc_variable_unqiue(self, platform: ixmp4.Platform): + def test_iamc_variable_unqiue(self, platform: ixmp4.Platform) -> None: platform.iamc.variables.create("IAMC Variable") with pytest.raises(Variable.NotUnique): platform.iamc.variables.create("IAMC Variable") - def test_list_iamc_variable(self, platform: ixmp4.Platform): + def test_list_iamc_variable(self, platform: ixmp4.Platform) -> None: iamc_variables = create_testcase_iamc_variables(platform) iamc_variable, _ = iamc_variables @@ -75,7 +77,7 @@ def test_list_iamc_variable(self, platform: ixmp4.Platform): b = [v.id for v in platform.iamc.variables.list(name="IAMC Variable")] assert not (set(a) ^ set(b)) - def test_tabulate_iamc_variable(self, platform: ixmp4.Platform): + def test_tabulate_iamc_variable(self, platform: ixmp4.Platform) -> None: iamc_variables = create_testcase_iamc_variables(platform) iamc_variable, _ = iamc_variables @@ -87,7 +89,7 @@ def test_tabulate_iamc_variable(self, platform: ixmp4.Platform): b = platform.iamc.variables.tabulate(name="IAMC Variable") assert_unordered_equality(a, b, check_dtype=False) - def test_retrieve_docs(self, platform: ixmp4.Platform): + def test_retrieve_docs(self, platform: ixmp4.Platform) -> None: _, iamc_variable2 = create_testcase_iamc_variables(platform) docs_iamc_variable1 = platform.iamc.variables.set_docs( "IAMC Variable", "Description of test IAMC Variable" @@ -104,7 +106,7 @@ def test_retrieve_docs(self, platform: ixmp4.Platform): platform.iamc.variables.get_docs("IAMC Variable 2") == iamc_variable2.docs ) - def test_delete_docs(self, platform: ixmp4.Platform): + def test_delete_docs(self, platform: ixmp4.Platform) -> None: iamc_variable, _ = create_testcase_iamc_variables(platform) iamc_variable.docs = "Description of test IAMC Variable" iamc_variable.docs = None @@ -116,7 +118,8 @@ def test_delete_docs(self, platform: ixmp4.Platform): assert iamc_variable.docs is None - iamc_variable.docs = "Third description of test IAMC Variable" + # Mypy doesn't recognize del properly, it seems + iamc_variable.docs = "Third description of test IAMC Variable" # type: ignore[unreachable] platform.iamc.variables.delete_docs("IAMC Variable") assert iamc_variable.docs is None diff --git a/tests/data/test_count.py b/tests/data/test_count.py index e23b20b5..0637bfd4 100644 --- a/tests/data/test_count.py +++ b/tests/data/test_count.py @@ -1,11 +1,14 @@ from functools import reduce +from typing import Any import pytest import ixmp4 +from ixmp4.data.backend import Backend +from ixmp4.data.db.base import CountKwargs, Enumerator -def deepgetattr(obj, attr): +def deepgetattr(obj: Backend, attr: str) -> Any: return reduce(getattr, attr.split("."), obj) @@ -75,6 +78,11 @@ class TestDataCount: ], ], ) - def test_count(self, db_platform_big: ixmp4.Platform, repo_name, filters): + def test_count( + self, db_platform_big: ixmp4.Platform, repo_name: str, filters: CountKwargs + ) -> None: repo = deepgetattr(db_platform_big.backend, repo_name) + # NOTE this check would not be necessary if db.platform_big.backend was typed as + # a DB backend and deepgetattr() to return DB-layer Enumerator + assert isinstance(repo, Enumerator) assert len(repo.list(**filters)) == repo.count(**filters) diff --git a/tests/data/test_docs.py b/tests/data/test_docs.py index ae529f9b..80cb2e35 100644 --- a/tests/data/test_docs.py +++ b/tests/data/test_docs.py @@ -5,14 +5,14 @@ class TestDataDocs: - def test_get_and_set_modeldocs(self, platform: ixmp4.Platform): + def test_get_and_set_modeldocs(self, platform: ixmp4.Platform) -> None: model = platform.backend.models.create("Model") docs_model = platform.backend.models.docs.set(model.id, "Description of Model") docs_model1 = platform.backend.models.docs.get(model.id) assert docs_model == docs_model1 - def test_change_empty_modeldocs(self, platform: ixmp4.Platform): + def test_change_empty_modeldocs(self, platform: ixmp4.Platform) -> None: model = platform.backend.models.create("Model") with pytest.raises(Docs.NotFound): @@ -30,7 +30,7 @@ def test_change_empty_modeldocs(self, platform: ixmp4.Platform): assert platform.backend.models.docs.get(model.id) == docs_model2 - def test_delete_modeldocs(self, platform: ixmp4.Platform): + def test_delete_modeldocs(self, platform: ixmp4.Platform) -> None: model = platform.backend.models.create("Model") docs_model = platform.backend.models.docs.set( model.id, "Description of test Model" @@ -43,7 +43,7 @@ def test_delete_modeldocs(self, platform: ixmp4.Platform): with pytest.raises(Docs.NotFound): platform.backend.models.docs.get(model.id) - def test_get_and_set_regiondocs(self, platform: ixmp4.Platform): + def test_get_and_set_regiondocs(self, platform: ixmp4.Platform) -> None: region = platform.backend.regions.create("Region", "Hierarchy") docs_region = platform.backend.regions.docs.set( region.id, "Description of test Region" @@ -52,7 +52,7 @@ def test_get_and_set_regiondocs(self, platform: ixmp4.Platform): assert docs_region == docs_region1 - def test_change_empty_regiondocs(self, platform: ixmp4.Platform): + def test_change_empty_regiondocs(self, platform: ixmp4.Platform) -> None: region = platform.backend.regions.create("Region", "Hierarchy") with pytest.raises(Docs.NotFound): @@ -70,7 +70,7 @@ def test_change_empty_regiondocs(self, platform: ixmp4.Platform): assert platform.backend.regions.docs.get(region.id) == docs_region2 - def test_delete_regiondocs(self, platform: ixmp4.Platform): + def test_delete_regiondocs(self, platform: ixmp4.Platform) -> None: region = platform.backend.regions.create("Region", "Hierarchy") docs_region = platform.backend.regions.docs.set( region.id, "Description of test region" @@ -83,7 +83,7 @@ def test_delete_regiondocs(self, platform: ixmp4.Platform): with pytest.raises(Docs.NotFound): platform.backend.regions.docs.get(region.id) - def test_get_and_set_scenariodocs(self, platform: ixmp4.Platform): + def test_get_and_set_scenariodocs(self, platform: ixmp4.Platform) -> None: scenario = platform.backend.scenarios.create("Scenario") docs_scenario = platform.backend.scenarios.docs.set( scenario.id, "Description of Scenario" @@ -91,7 +91,7 @@ def test_get_and_set_scenariodocs(self, platform: ixmp4.Platform): docs_scenario1 = platform.backend.scenarios.docs.get(scenario.id) assert docs_scenario == docs_scenario1 - def test_change_empty_scenariodocs(self, platform: ixmp4.Platform): + def test_change_empty_scenariodocs(self, platform: ixmp4.Platform) -> None: scenario = platform.backend.scenarios.create("Scenario") with pytest.raises(Docs.NotFound): @@ -109,7 +109,7 @@ def test_change_empty_scenariodocs(self, platform: ixmp4.Platform): assert platform.backend.scenarios.docs.get(scenario.id) == docs_scenario2 - def test_delete_scenariodocs(self, platform: ixmp4.Platform): + def test_delete_scenariodocs(self, platform: ixmp4.Platform) -> None: scenario = platform.backend.scenarios.create("Scenario") docs_scenario = platform.backend.scenarios.docs.set( scenario.id, "Description of test Scenario" @@ -122,14 +122,14 @@ def test_delete_scenariodocs(self, platform: ixmp4.Platform): with pytest.raises(Docs.NotFound): platform.backend.scenarios.docs.get(scenario.id) - def test_get_and_set_unitdocs(self, platform: ixmp4.Platform): + def test_get_and_set_unitdocs(self, platform: ixmp4.Platform) -> None: unit = platform.backend.units.create("Unit") docs_unit = platform.backend.units.docs.set(unit.id, "Description of test Unit") docs_unit1 = platform.backend.units.docs.get(unit.id) assert docs_unit == docs_unit1 - def test_change_empty_unitdocs(self, platform: ixmp4.Platform): + def test_change_empty_unitdocs(self, platform: ixmp4.Platform) -> None: unit = platform.backend.units.create("Unit") with pytest.raises(Docs.NotFound): @@ -147,7 +147,7 @@ def test_change_empty_unitdocs(self, platform: ixmp4.Platform): assert platform.backend.units.docs.get(unit.id) == docs_unit2 - def test_delete_unitdocs(self, platform: ixmp4.Platform): + def test_delete_unitdocs(self, platform: ixmp4.Platform) -> None: unit = platform.backend.units.create("Unit") docs_unit = platform.backend.units.docs.set(unit.id, "Description of test Unit") @@ -158,7 +158,7 @@ def test_delete_unitdocs(self, platform: ixmp4.Platform): with pytest.raises(Docs.NotFound): platform.backend.units.docs.get(unit.id) - def test_get_and_set_variabledocs(self, platform: ixmp4.Platform): + def test_get_and_set_variabledocs(self, platform: ixmp4.Platform) -> None: variable = platform.backend.iamc.variables.create("Variable") docs_variable = platform.backend.iamc.variables.docs.set( variable.id, "Description of test Variable" @@ -167,7 +167,7 @@ def test_get_and_set_variabledocs(self, platform: ixmp4.Platform): assert docs_variable == docs_variables1 - def test_change_empty_variabledocs(self, platform: ixmp4.Platform): + def test_change_empty_variabledocs(self, platform: ixmp4.Platform) -> None: variable = platform.backend.iamc.variables.create("Variable") with pytest.raises(Docs.NotFound): @@ -185,7 +185,7 @@ def test_change_empty_variabledocs(self, platform: ixmp4.Platform): assert platform.backend.iamc.variables.docs.get(variable.id) == docs_variable2 - def test_delete_variabledocs(self, platform: ixmp4.Platform): + def test_delete_variabledocs(self, platform: ixmp4.Platform) -> None: variable = platform.backend.iamc.variables.create("Variable") docs_variable = platform.backend.iamc.variables.docs.set( variable.id, "Description of test Variable" @@ -198,7 +198,7 @@ def test_delete_variabledocs(self, platform: ixmp4.Platform): with pytest.raises(Docs.NotFound): platform.backend.iamc.variables.docs.get(variable.id) - def test_get_and_set_indexsetdocs(self, platform: ixmp4.Platform): + def test_get_and_set_indexsetdocs(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") indexset = platform.backend.optimization.indexsets.create( run_id=run.id, name="IndexSet" @@ -210,7 +210,7 @@ def test_get_and_set_indexsetdocs(self, platform: ixmp4.Platform): assert docs_indexset == docs_indexset1 - def test_change_empty_indexsetdocs(self, platform: ixmp4.Platform): + def test_change_empty_indexsetdocs(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") indexset = platform.backend.optimization.indexsets.create( run_id=run.id, name="IndexSet" @@ -237,7 +237,7 @@ def test_change_empty_indexsetdocs(self, platform: ixmp4.Platform): == docs_indexset2 ) - def test_delete_indexsetdocs(self, platform: ixmp4.Platform): + def test_delete_indexsetdocs(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") indexset = platform.backend.optimization.indexsets.create( run_id=run.id, name="IndexSet" @@ -256,7 +256,7 @@ def test_delete_indexsetdocs(self, platform: ixmp4.Platform): with pytest.raises(Docs.NotFound): platform.backend.optimization.indexsets.docs.get(indexset.id) - def test_get_and_set_scalardocs(self, platform: ixmp4.Platform): + def test_get_and_set_scalardocs(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") unit = platform.backend.units.create("Unit") scalar = platform.backend.optimization.scalars.create( @@ -269,7 +269,7 @@ def test_get_and_set_scalardocs(self, platform: ixmp4.Platform): assert docs_scalar == docs_scalar1 - def test_change_empty_scalardocs(self, platform: ixmp4.Platform): + def test_change_empty_scalardocs(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") unit = platform.backend.units.create("Unit") scalar = platform.backend.optimization.scalars.create( @@ -291,7 +291,7 @@ def test_change_empty_scalardocs(self, platform: ixmp4.Platform): assert platform.backend.optimization.scalars.docs.get(scalar.id) == docs_scalar2 - def test_delete_scalardocs(self, platform: ixmp4.Platform): + def test_delete_scalardocs(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") unit = platform.backend.units.create("Unit") scalar = platform.backend.optimization.scalars.create( @@ -308,7 +308,7 @@ def test_delete_scalardocs(self, platform: ixmp4.Platform): with pytest.raises(Docs.NotFound): platform.backend.optimization.scalars.docs.get(scalar.id) - def test_get_and_set_tabledocs(self, platform: ixmp4.Platform): + def test_get_and_set_tabledocs(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") _ = platform.backend.optimization.indexsets.create( run_id=run.id, name="Indexset" @@ -323,7 +323,7 @@ def test_get_and_set_tabledocs(self, platform: ixmp4.Platform): assert docs_table == docs_table1 - def test_change_empty_tabledocs(self, platform: ixmp4.Platform): + def test_change_empty_tabledocs(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") _ = platform.backend.optimization.indexsets.create( run_id=run.id, name="Indexset" @@ -347,7 +347,7 @@ def test_change_empty_tabledocs(self, platform: ixmp4.Platform): assert platform.backend.optimization.tables.docs.get(table.id) == docs_table2 - def test_delete_tabledocs(self, platform: ixmp4.Platform): + def test_delete_tabledocs(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") _ = platform.backend.optimization.indexsets.create( run_id=run.id, name="Indexset" @@ -366,7 +366,7 @@ def test_delete_tabledocs(self, platform: ixmp4.Platform): with pytest.raises(Docs.NotFound): platform.backend.optimization.tables.docs.get(table.id) - def test_get_and_set_parameterdocs(self, platform: ixmp4.Platform): + def test_get_and_set_parameterdocs(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") _ = platform.backend.optimization.indexsets.create( run_id=run.id, name="Indexset" @@ -383,7 +383,7 @@ def test_get_and_set_parameterdocs(self, platform: ixmp4.Platform): assert docs_parameter == docs_parameter1 - def test_change_empty_parameterdocs(self, platform: ixmp4.Platform): + def test_change_empty_parameterdocs(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") _ = platform.backend.optimization.indexsets.create( run_id=run.id, name="Indexset" @@ -413,7 +413,7 @@ def test_change_empty_parameterdocs(self, platform: ixmp4.Platform): == docs_parameter2 ) - def test_delete_parameterdocs(self, platform: ixmp4.Platform): + def test_delete_parameterdocs(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") _ = platform.backend.optimization.indexsets.create( run_id=run.id, name="Indexset" @@ -435,7 +435,9 @@ def test_delete_parameterdocs(self, platform: ixmp4.Platform): with pytest.raises(Docs.NotFound): platform.backend.optimization.parameters.docs.get(parameter.id) - def test_get_and_set_optimizationvariabledocs(self, platform: ixmp4.Platform): + def test_get_and_set_optimizationvariabledocs( + self, platform: ixmp4.Platform + ) -> None: run = platform.backend.runs.create("Model", "Scenario") _ = platform.backend.optimization.indexsets.create( run_id=run.id, name="Indexset" @@ -450,7 +452,9 @@ def test_get_and_set_optimizationvariabledocs(self, platform: ixmp4.Platform): assert docs_variable == docs_variable1 - def test_change_empty_optimizationvariabledocs(self, platform: ixmp4.Platform): + def test_change_empty_optimizationvariabledocs( + self, platform: ixmp4.Platform + ) -> None: run = platform.backend.runs.create("Model", "Scenario") _ = platform.backend.optimization.indexsets.create( run_id=run.id, name="Indexset" @@ -480,7 +484,7 @@ def test_change_empty_optimizationvariabledocs(self, platform: ixmp4.Platform): == docs_variable2 ) - def test_delete_optimizationvariabledocs(self, platform: ixmp4.Platform): + def test_delete_optimizationvariabledocs(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") _ = platform.backend.optimization.indexsets.create( run_id=run.id, name="Indexset" @@ -502,7 +506,7 @@ def test_delete_optimizationvariabledocs(self, platform: ixmp4.Platform): with pytest.raises(Docs.NotFound): platform.backend.optimization.variables.docs.get(variable.id) - def test_get_and_set_equationdocs(self, platform: ixmp4.Platform): + def test_get_and_set_equationdocs(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") _ = platform.backend.optimization.indexsets.create( run_id=run.id, name="Indexset" @@ -517,7 +521,7 @@ def test_get_and_set_equationdocs(self, platform: ixmp4.Platform): assert docs_equation == docs_equation1 - def test_change_empty_equationdocs(self, platform: ixmp4.Platform): + def test_change_empty_equationdocs(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") _ = platform.backend.optimization.indexsets.create( run_id=run.id, name="Indexset" @@ -547,7 +551,7 @@ def test_change_empty_equationdocs(self, platform: ixmp4.Platform): == docs_equation2 ) - def test_delete_optimizationequationdocs(self, platform: ixmp4.Platform): + def test_delete_optimizationequationdocs(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") _ = platform.backend.optimization.indexsets.create( run_id=run.id, name="Indexset" diff --git a/tests/data/test_iamc_datapoint.py b/tests/data/test_iamc_datapoint.py index a74db889..b51a6a53 100644 --- a/tests/data/test_iamc_datapoint.py +++ b/tests/data/test_iamc_datapoint.py @@ -34,10 +34,15 @@ ({"scenario": {"name": "Scenario 2"}}, ("scenario", "__eq__", "Scenario 2")), ], ) -def test_filtering(platform: ixmp4.Platform, filter, exp_filter): +def test_filtering( + platform: ixmp4.Platform, + filter: dict[str, int | list[int] | dict[str, str | list[str]]], + exp_filter: tuple[str, str, int | str | list[int] | list[str]], +) -> None: run1, run2 = filter_dataset.load_dataset(platform) run2.set_as_default() - obs = platform.backend.iamc.datapoints.tabulate(join_parameters=True, **filter) + # Not sure why mypy complains here, maybe about covariance? + obs = platform.backend.iamc.datapoints.tabulate(join_parameters=True, **filter) # type: ignore[arg-type] exp = filter_dataset.datapoints.copy() if exp_filter is not None: @@ -61,8 +66,12 @@ def test_filtering(platform: ixmp4.Platform, filter, exp_filter): {"run": {"default_only": "test"}}, ], ) -def test_invalid_filters(platform: ixmp4.Platform, filter, request): +def test_invalid_filters( + platform: ixmp4.Platform, + filter: dict[str, dict[str, str | bool]], + request: pytest.FixtureRequest, +) -> None: with pytest.raises(BadFilterArguments): - platform.backend.iamc.datapoints.tabulate(**filter) + platform.backend.iamc.datapoints.tabulate(**filter) # type: ignore[arg-type] with pytest.raises(BadFilterArguments): - platform.backend.iamc.datapoints.list(**filter) + platform.backend.iamc.datapoints.list(**filter) # type: ignore[arg-type] diff --git a/tests/data/test_iamc_variable.py b/tests/data/test_iamc_variable.py index 603f31e2..3e0aa759 100644 --- a/tests/data/test_iamc_variable.py +++ b/tests/data/test_iamc_variable.py @@ -9,23 +9,23 @@ class TestDataIamcVariable: filter = FilterIamcDataset() - def test_create_iamc_variable(self, platform: ixmp4.Platform): + def test_create_iamc_variable(self, platform: ixmp4.Platform) -> None: variable = platform.backend.iamc.variables.create("Variable") assert variable.name == "Variable" assert variable.created_at is not None assert variable.created_by == "@unknown" - def test_iamc_variable_unique(self, platform: ixmp4.Platform): + def test_iamc_variable_unique(self, platform: ixmp4.Platform) -> None: platform.backend.iamc.variables.create("Variable") with pytest.raises(Variable.NotUnique): platform.iamc.variables.create("Variable") - def test_iamc_variable_not_found(self, platform: ixmp4.Platform): + def test_iamc_variable_not_found(self, platform: ixmp4.Platform) -> None: with pytest.raises(Variable.NotFound): platform.iamc.variables.get("Variable") - def test_filter_iamc_variable(self, platform: ixmp4.Platform): + def test_filter_iamc_variable(self, platform: ixmp4.Platform) -> None: run1, run2 = self.filter.load_dataset(platform) res = platform.backend.iamc.variables.tabulate(unit={"name": "Unit 1"}) assert sorted(res["name"].tolist()) == ["Variable 1", "Variable 3"] diff --git a/tests/data/test_meta.py b/tests/data/test_meta.py index 099f499a..924c8a82 100644 --- a/tests/data/test_meta.py +++ b/tests/data/test_meta.py @@ -7,7 +7,7 @@ from ..utils import assert_unordered_equality -TEST_ENTRIES = [ +TEST_ENTRIES: list[tuple[str, bool | float | int | str, str]] = [ ("Boolean", True, RunMetaEntry.Type.BOOL), ("Float", 0.2, RunMetaEntry.Type.FLOAT), ("Integer", 1, RunMetaEntry.Type.INT), @@ -16,30 +16,30 @@ TEST_ENTRIES_DF = pd.DataFrame( [[id, key, type, value] for id, (key, value, type) in enumerate(TEST_ENTRIES, 1)], - columns=["id", "key", "type", "value"], + columns=["id", "key", "dtype", "value"], ) TEST_ILLEGAL_META_KEYS = {"model", "scenario", "id", "version", "is_default"} class TestDataMeta: - def test_create_get_entry(self, platform: ixmp4.Platform): + def test_create_get_entry(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") run.set_as_default() for key, value, type in TEST_ENTRIES: - entry = platform.backend.meta.create(run.id, key, value) # type:ignore + entry = platform.backend.meta.create(run.id, key, value) assert entry.key == key assert entry.value == value - assert entry.type == type + assert entry.dtype == type for key, value, type in TEST_ENTRIES: entry = platform.backend.meta.get(run.id, key) assert entry.key == key assert entry.value == value - assert entry.type == type + assert entry.dtype == type - def test_illegal_key(self, platform: ixmp4.Platform): + def test_illegal_key(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") for key in TEST_ILLEGAL_META_KEYS: with pytest.raises(InvalidRunMeta, match="Illegal meta key: " + key): @@ -51,7 +51,7 @@ def test_illegal_key(self, platform: ixmp4.Platform): with pytest.raises(InvalidRunMeta, match=r"Illegal meta key\(s\): " + key): platform.backend.meta.bulk_upsert(df) - def test_entry_unique(self, platform: ixmp4.Platform): + def test_entry_unique(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") platform.backend.meta.create(run.id, "Key", "Value") @@ -61,7 +61,7 @@ def test_entry_unique(self, platform: ixmp4.Platform): with pytest.raises(RunMetaEntry.NotUnique): platform.backend.meta.create(run.id, "Key", 1) - def test_entry_not_found(self, platform: ixmp4.Platform): + def test_entry_not_found(self, platform: ixmp4.Platform) -> None: with pytest.raises(RunMetaEntry.NotFound): platform.backend.meta.get(-1, "Key") @@ -70,7 +70,7 @@ def test_entry_not_found(self, platform: ixmp4.Platform): with pytest.raises(RunMetaEntry.NotFound): platform.backend.meta.get(run.id, "Key") - def test_delete_entry(self, platform: ixmp4.Platform): + def test_delete_entry(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") entry = platform.backend.meta.create(run.id, "Key", "Value") platform.backend.meta.delete(entry.id) @@ -78,26 +78,26 @@ def test_delete_entry(self, platform: ixmp4.Platform): with pytest.raises(RunMetaEntry.NotFound): platform.backend.meta.get(run.id, "Key") - def test_list_entry(self, platform: ixmp4.Platform): + def test_list_entry(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") run.set_as_default() for key, value, _ in TEST_ENTRIES: - entry = platform.backend.meta.create(run.id, key, value) # type:ignore + entry = platform.backend.meta.create(run.id, key, value) entries = platform.backend.meta.list() for (key, value, type), entry in zip(TEST_ENTRIES, entries): assert entry.key == key assert entry.value == value - assert entry.type == type + assert entry.dtype == type - def test_tabulate_entry(self, platform: ixmp4.Platform): + def test_tabulate_entry(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") run.set_as_default() for key, value, _ in TEST_ENTRIES: - platform.backend.meta.create(run.id, key, value) # type:ignore + platform.backend.meta.create(run.id, key, value) true_entries = TEST_ENTRIES_DF.copy() true_entries["run__id"] = run.id @@ -105,16 +105,16 @@ def test_tabulate_entry(self, platform: ixmp4.Platform): entries = platform.backend.meta.tabulate() assert_unordered_equality(entries, true_entries) - def test_tabulate_entries_with_run_filters(self, platform: ixmp4.Platform): + def test_tabulate_entries_with_run_filters(self, platform: ixmp4.Platform) -> None: run1 = platform.runs.create("Model", "Scenario") run1.set_as_default() run2 = platform.runs.create("Model 2", "Scenario 2") # Splitting the loop to more easily correct the id column below for key, value, _ in TEST_ENTRIES: - platform.backend.meta.create(run1.id, key, value) # type:ignore + platform.backend.meta.create(run1.id, key, value) for key, value, _ in TEST_ENTRIES: - platform.backend.meta.create(run2.id, key, value) # type:ignore + platform.backend.meta.create(run2.id, key, value) true_entries1 = TEST_ENTRIES_DF.copy() true_entries1["run__id"] = run1.id @@ -140,12 +140,12 @@ def test_tabulate_entries_with_run_filters(self, platform: ixmp4.Platform): true_entries2, ) - def test_tabulate_entries_with_key_filters(self, platform: ixmp4.Platform): + def test_tabulate_entries_with_key_filters(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") run.set_as_default() for key, value, _ in TEST_ENTRIES: - platform.backend.meta.create(run.id, key, value) # type:ignore + platform.backend.meta.create(run.id, key, value) # Select just some key from TEST_ENTRIES key = TEST_ENTRIES[1][0] @@ -157,7 +157,7 @@ def test_tabulate_entries_with_key_filters(self, platform: ixmp4.Platform): assert_unordered_equality(entry, true_entry, check_dtype=False) - def test_entry_bulk_operations(self, platform: ixmp4.Platform): + def test_entry_bulk_operations(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") run.set_as_default() @@ -165,14 +165,14 @@ def test_entry_bulk_operations(self, platform: ixmp4.Platform): entries["run__id"] = run.id # == Full Addition == - platform.backend.meta.bulk_upsert(entries.drop(columns=["id", "type"])) + platform.backend.meta.bulk_upsert(entries.drop(columns=["id", "dtype"])) ret = platform.backend.meta.tabulate() assert_unordered_equality(entries, ret) # == Partial Removal == # Remove half the data remove_data = entries.head(len(entries) // 2).drop( - columns=["value", "id", "type"] + columns=["value", "id", "dtype"] ) remaining_data = entries.tail(len(entries) // 2).reset_index(drop=True) platform.backend.meta.bulk_delete(remove_data) @@ -191,24 +191,24 @@ def test_entry_bulk_operations(self, platform: ixmp4.Platform): [3, "Integer", -9.9, RunMetaEntry.Type.FLOAT], [4, "String", -9.9, RunMetaEntry.Type.FLOAT], ], - columns=["id", "key", "value", "type"], + columns=["id", "key", "value", "dtype"], ) updated_entries["run__id"] = run.id - updated_entries["value"] = updated_entries["value"].astype("object") # type: ignore + updated_entries["value"] = updated_entries["value"].astype("object") - platform.backend.meta.bulk_upsert(updated_entries.drop(columns=["id", "type"])) + platform.backend.meta.bulk_upsert(updated_entries.drop(columns=["id", "dtype"])) ret = platform.backend.meta.tabulate() assert_unordered_equality(updated_entries, ret, check_like=True) # == Full Removal == - remove_data = entries.drop(columns=["value", "id", "type"]) + remove_data = entries.drop(columns=["value", "id", "dtype"]) platform.backend.meta.bulk_delete(remove_data) ret = platform.backend.meta.tabulate() assert ret.empty - def test_meta_bulk_exceptions(self, platform: ixmp4.Platform): + def test_meta_bulk_exceptions(self, platform: ixmp4.Platform) -> None: entries = pd.DataFrame( [ ["Boolean", -9.9], diff --git a/tests/data/test_model.py b/tests/data/test_model.py index f06b2c21..ad5d4e9e 100644 --- a/tests/data/test_model.py +++ b/tests/data/test_model.py @@ -11,28 +11,28 @@ class TestDataModel: filter = FilterIamcDataset() - def test_create_model(self, platform: ixmp4.Platform): + def test_create_model(self, platform: ixmp4.Platform) -> None: model = platform.backend.models.create("Model") assert model.name == "Model" assert model.created_at is not None assert model.created_by == "@unknown" - def test_model_unique(self, platform: ixmp4.Platform): + def test_model_unique(self, platform: ixmp4.Platform) -> None: platform.backend.models.create("Model") with pytest.raises(Model.NotUnique): platform.models.create("Model") - def test_get_model(self, platform: ixmp4.Platform): + def test_get_model(self, platform: ixmp4.Platform) -> None: model1 = platform.backend.models.create("Model") model2 = platform.backend.models.get("Model") assert model1 == model2 - def test_model_not_found(self, platform: ixmp4.Platform): + def test_model_not_found(self, platform: ixmp4.Platform) -> None: with pytest.raises(Model.NotFound): platform.models.get("Model") - def test_list_model(self, platform: ixmp4.Platform): + def test_list_model(self, platform: ixmp4.Platform) -> None: platform.runs.create("Model 1", "Scenario") platform.runs.create("Model 2", "Scenario") @@ -42,7 +42,7 @@ def test_list_model(self, platform: ixmp4.Platform): assert models[1].id == 2 assert models[1].name == "Model 2" - def test_tabulate_model(self, platform: ixmp4.Platform): + def test_tabulate_model(self, platform: ixmp4.Platform) -> None: platform.runs.create("Model 1", "Scenario") platform.runs.create("Model 2", "Scenario") @@ -59,12 +59,12 @@ def test_tabulate_model(self, platform: ixmp4.Platform): models.drop(columns=["created_at", "created_by"]), true_models ) - def test_map_model(self, platform: ixmp4.Platform): + def test_map_model(self, platform: ixmp4.Platform) -> None: platform.runs.create("Model 1", "Scenario") platform.runs.create("Model 2", "Scenario") assert platform.backend.models.map() == {1: "Model 1", 2: "Model 2"} - def test_filter_model(self, platform: ixmp4.Platform): + def test_filter_model(self, platform: ixmp4.Platform) -> None: run1, run2 = self.filter.load_dataset(platform) run2.set_as_default() diff --git a/tests/data/test_optimization_equation.py b/tests/data/test_optimization_equation.py index 48aa29ae..3cd05969 100644 --- a/tests/data/test_optimization_equation.py +++ b/tests/data/test_optimization_equation.py @@ -11,7 +11,7 @@ from ..utils import assert_unordered_equality, create_indexsets_for_run -def df_from_list(equations: list): +def df_from_list(equations: list[Equation]) -> pd.DataFrame: return pd.DataFrame( [ [ @@ -36,7 +36,7 @@ def df_from_list(equations: list): class TestDataOptimizationEquation: - def test_create_equation(self, platform: ixmp4.Platform): + def test_create_equation(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") # Test normal creation @@ -104,7 +104,7 @@ def test_create_equation(self, platform: ixmp4.Platform): assert equation_3.columns[0].dtype == "object" assert equation_3.columns[1].dtype == "int64" - def test_get_equation(self, platform: ixmp4.Platform): + def test_get_equation(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") (indexset,) = create_indexsets_for_run( platform=platform, run_id=run.id, amount=1 @@ -121,7 +121,7 @@ def test_get_equation(self, platform: ixmp4.Platform): run_id=run.id, name="Equation 2" ) - def test_equation_add_data(self, platform: ixmp4.Platform): + def test_equation_add_data(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") indexset, indexset_2 = create_indexsets_for_run( platform=platform, run_id=run.id @@ -273,7 +273,7 @@ def test_equation_add_data(self, platform: ixmp4.Platform): ) assert_unordered_equality(expected, pd.DataFrame(equation_4.data)) - def test_equation_remove_data(self, platform: ixmp4.Platform): + def test_equation_remove_data(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") (indexset,) = create_indexsets_for_run( platform=platform, run_id=run.id, amount=1 @@ -303,7 +303,7 @@ def test_equation_remove_data(self, platform: ixmp4.Platform): ) assert equation.data == {} - def test_list_equation(self, platform: ixmp4.Platform): + def test_list_equation(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") # Per default, list() lists scalars for `default` version runs: platform.backend.runs.set_as_default_version(run.id) @@ -341,7 +341,7 @@ def test_list_equation(self, platform: ixmp4.Platform): equation_4, ] == platform.backend.optimization.equations.list(run_id=run_2.id) - def test_tabulate_equation(self, platform: ixmp4.Platform): + def test_tabulate_equation(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") indexset, indexset_2 = create_indexsets_for_run( platform=platform, run_id=run.id diff --git a/tests/data/test_optimization_indexset.py b/tests/data/test_optimization_indexset.py index d4cd5bbc..89e6f4d1 100644 --- a/tests/data/test_optimization_indexset.py +++ b/tests/data/test_optimization_indexset.py @@ -48,7 +48,7 @@ def df_from_list(indexsets: list[IndexSet], include_data: bool = False) -> pd.Da class TestDataOptimizationIndexSet: - def test_create_indexset(self, platform: ixmp4.Platform): + def test_create_indexset(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") indexset_1 = platform.backend.optimization.indexsets.create( run_id=run.id, name="Indexset" @@ -62,7 +62,7 @@ def test_create_indexset(self, platform: ixmp4.Platform): run_id=run.id, name="Indexset" ) - def test_get_indexset(self, platform: ixmp4.Platform): + def test_get_indexset(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") create_indexsets_for_run(platform=platform, run_id=run.id, amount=1) indexset = platform.backend.optimization.indexsets.get( @@ -77,7 +77,7 @@ def test_get_indexset(self, platform: ixmp4.Platform): run_id=run.id, name="Indexset 2" ) - def test_list_indexsets(self, platform: ixmp4.Platform): + def test_list_indexsets(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") indexset_1, indexset_2 = create_indexsets_for_run( platform=platform, run_id=run.id @@ -99,7 +99,7 @@ def test_list_indexsets(self, platform: ixmp4.Platform): run_id=run_2.id ) - def test_tabulate_indexsets(self, platform: ixmp4.Platform): + def test_tabulate_indexsets(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") indexset_1, indexset_2 = create_indexsets_for_run( platform=platform, run_id=run.id @@ -147,23 +147,21 @@ def test_tabulate_indexsets(self, platform: ixmp4.Platform): ), ) - def test_add_data(self, platform: ixmp4.Platform): + def test_add_data(self, platform: ixmp4.Platform) -> None: test_data = ["foo", "bar"] run = platform.backend.runs.create("Model", "Scenario") indexset_1, indexset_2 = create_indexsets_for_run( platform=platform, run_id=run.id ) platform.backend.optimization.indexsets.add_data( - indexset_id=indexset_1.id, - data=test_data, # type: ignore + indexset_id=indexset_1.id, data=test_data ) indexset_1 = platform.backend.optimization.indexsets.get( run_id=run.id, name=indexset_1.name ) platform.backend.optimization.indexsets.add_data( - indexset_id=indexset_2.id, - data=test_data, # type: ignore + indexset_id=indexset_2.id, data=test_data ) assert ( @@ -188,7 +186,7 @@ def test_add_data(self, platform: ixmp4.Platform): platform=platform, run_id=run.id, offset=3 ) - test_data_2: list[float | int | str] = [1.2, 3.4, 5.6] + test_data_2 = [1.2, 3.4, 5.6] platform.backend.optimization.indexsets.add_data( indexset_id=indexset_3.id, data=test_data_2 ) @@ -199,7 +197,7 @@ def test_add_data(self, platform: ixmp4.Platform): assert indexset_3.data == test_data_2 assert type(indexset_3.data[0]).__name__ == "float" - test_data_3: list[float | int | str] = [0, 1, 2] + test_data_3 = [0, 1, 2] platform.backend.optimization.indexsets.add_data( indexset_id=indexset_4.id, data=test_data_3 ) diff --git a/tests/data/test_optimization_parameter.py b/tests/data/test_optimization_parameter.py index 7993596f..32c595db 100644 --- a/tests/data/test_optimization_parameter.py +++ b/tests/data/test_optimization_parameter.py @@ -11,7 +11,7 @@ from ..utils import assert_unordered_equality, create_indexsets_for_run -def df_from_list(parameters: list): +def df_from_list(parameters: list[Parameter]) -> pd.DataFrame: return pd.DataFrame( [ [ @@ -36,7 +36,7 @@ def df_from_list(parameters: list): class TestDataOptimizationParameter: - def test_create_parameter(self, platform: ixmp4.Platform): + def test_create_parameter(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") # Test normal creation @@ -106,7 +106,7 @@ def test_create_parameter(self, platform: ixmp4.Platform): assert parameter_3.columns[0].dtype == "object" assert parameter_3.columns[1].dtype == "int64" - def test_get_parameter(self, platform: ixmp4.Platform): + def test_get_parameter(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") create_indexsets_for_run(platform=platform, run_id=run.id, amount=1) parameter = platform.backend.optimization.parameters.create( @@ -121,7 +121,7 @@ def test_get_parameter(self, platform: ixmp4.Platform): run_id=run.id, name="Parameter 2" ) - def test_parameter_add_data(self, platform: ixmp4.Platform): + def test_parameter_add_data(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") unit = platform.backend.units.create("Unit") indexset, indexset_2 = create_indexsets_for_run( @@ -283,7 +283,7 @@ def test_parameter_add_data(self, platform: ixmp4.Platform): ) assert_unordered_equality(expected, pd.DataFrame(parameter_4.data)) - def test_list_parameter(self, platform: ixmp4.Platform): + def test_list_parameter(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") indexset, indexset_2 = create_indexsets_for_run( platform=platform, run_id=run.id @@ -324,7 +324,7 @@ def test_list_parameter(self, platform: ixmp4.Platform): parameter_4, ] == platform.backend.optimization.parameters.list(run_id=run_2.id) - def test_tabulate_parameter(self, platform: ixmp4.Platform): + def test_tabulate_parameter(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") indexset, indexset_2 = create_indexsets_for_run( platform=platform, run_id=run.id diff --git a/tests/data/test_optimization_scalar.py b/tests/data/test_optimization_scalar.py index 138800c7..113f3ca3 100644 --- a/tests/data/test_optimization_scalar.py +++ b/tests/data/test_optimization_scalar.py @@ -3,10 +3,10 @@ import pytest import ixmp4 -from ixmp4 import Scalar +from ixmp4.data.abstract import Scalar -def df_from_list(scalars: list): +def df_from_list(scalars: list[Scalar]) -> pd.DataFrame: return pd.DataFrame( [ [ @@ -33,7 +33,7 @@ def df_from_list(scalars: list): class TestDataOptimizationScalar: - def test_create_scalar(self, platform: ixmp4.Platform): + def test_create_scalar(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") unit = platform.backend.units.create("Unit") unit2 = platform.backend.units.create("Unit 2") @@ -50,7 +50,7 @@ def test_create_scalar(self, platform: ixmp4.Platform): run_id=run.id, name="Scalar", value=2, unit_name=unit2.name ) - def test_get_scalar(self, platform: ixmp4.Platform): + def test_get_scalar(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") unit = platform.backend.units.create("Unit") scalar = platform.backend.optimization.scalars.create( @@ -65,7 +65,7 @@ def test_get_scalar(self, platform: ixmp4.Platform): run_id=run.id, name="Scalar 2" ) - def test_update_scalar(self, platform: ixmp4.Platform): + def test_update_scalar(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") unit = platform.backend.units.create("Unit") unit2 = platform.backend.units.create("Unit 2") @@ -83,7 +83,7 @@ def test_update_scalar(self, platform: ixmp4.Platform): assert ret.unit__id == unit2.id assert ret.value == 20 - def test_list_scalars(self, platform: ixmp4.Platform): + def test_list_scalars(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") unit = platform.backend.units.create("Unit") unit2 = platform.backend.units.create("Unit 2") @@ -96,7 +96,7 @@ def test_list_scalars(self, platform: ixmp4.Platform): assert [scalar_1] == platform.backend.optimization.scalars.list(name="Scalar") assert [scalar_1, scalar_2] == platform.backend.optimization.scalars.list() - def test_tabulate_scalars(self, platform: ixmp4.Platform): + def test_tabulate_scalars(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") unit = platform.backend.units.create("Unit") unit2 = platform.backend.units.create("Unit 2") diff --git a/tests/data/test_optimization_table.py b/tests/data/test_optimization_table.py index 5db61cc6..9f7aab06 100644 --- a/tests/data/test_optimization_table.py +++ b/tests/data/test_optimization_table.py @@ -2,16 +2,16 @@ import pytest import ixmp4 -from ixmp4 import Table from ixmp4.core.exceptions import ( OptimizationDataValidationError, OptimizationItemUsageError, ) +from ixmp4.data.abstract import Table from ..utils import create_indexsets_for_run -def df_from_list(tables: list): +def df_from_list(tables: list[Table]) -> pd.DataFrame: return pd.DataFrame( [ [ @@ -36,7 +36,7 @@ def df_from_list(tables: list): class TestDataOptimizationTable: - def test_create_table(self, platform: ixmp4.Platform): + def test_create_table(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") # Test normal creation @@ -102,7 +102,7 @@ def test_create_table(self, platform: ixmp4.Platform): assert table_3.columns[0].dtype == "object" assert table_3.columns[1].dtype == "int64" - def test_get_table(self, platform: ixmp4.Platform): + def test_get_table(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") _, _ = create_indexsets_for_run(platform=platform, run_id=run.id) table = platform.backend.optimization.tables.create( @@ -115,7 +115,7 @@ def test_get_table(self, platform: ixmp4.Platform): with pytest.raises(Table.NotFound): _ = platform.backend.optimization.tables.get(run_id=run.id, name="Table 2") - def test_table_add_data(self, platform: ixmp4.Platform): + def test_table_add_data(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") indexset_1, indexset_2 = create_indexsets_for_run( platform=platform, run_id=run.id @@ -292,7 +292,7 @@ def test_table_add_data(self, platform: ixmp4.Platform): ) assert table_5.data == test_data_5 - def test_list_table(self, platform: ixmp4.Platform): + def test_list_table(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") create_indexsets_for_run(platform=platform, run_id=run.id) table = platform.backend.optimization.tables.create( @@ -319,7 +319,7 @@ def test_list_table(self, platform: ixmp4.Platform): run_id=run_2.id ) - def test_tabulate_table(self, platform: ixmp4.Platform): + def test_tabulate_table(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") indexset_1, indexset_2 = create_indexsets_for_run( platform=platform, run_id=run.id, offset=2 diff --git a/tests/data/test_optimization_variable.py b/tests/data/test_optimization_variable.py index 0c61ea72..f7969134 100644 --- a/tests/data/test_optimization_variable.py +++ b/tests/data/test_optimization_variable.py @@ -11,7 +11,7 @@ from ..utils import assert_unordered_equality, create_indexsets_for_run -def df_from_list(variables: list): +def df_from_list(variables: list[OptimizationVariable]) -> pd.DataFrame: return pd.DataFrame( [ [ @@ -36,7 +36,7 @@ def df_from_list(variables: list): class TestDataOptimizationVariable: - def test_create_variable(self, platform: ixmp4.Platform): + def test_create_variable(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") # Test creation without indexset @@ -128,7 +128,7 @@ def test_create_variable(self, platform: ixmp4.Platform): assert variable_4.columns[0].dtype == "object" assert variable_4.columns[1].dtype == "int64" - def test_get_variable(self, platform: ixmp4.Platform): + def test_get_variable(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") (indexset,) = create_indexsets_for_run( platform=platform, run_id=run.id, amount=1 @@ -145,7 +145,7 @@ def test_get_variable(self, platform: ixmp4.Platform): run_id=run.id, name="Variable 2" ) - def test_variable_add_data(self, platform: ixmp4.Platform): + def test_variable_add_data(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") indexset, indexset_2 = create_indexsets_for_run( platform=platform, run_id=run.id @@ -297,7 +297,7 @@ def test_variable_add_data(self, platform: ixmp4.Platform): ) assert_unordered_equality(expected, pd.DataFrame(variable_4.data)) - def test_variable_remove_data(self, platform: ixmp4.Platform): + def test_variable_remove_data(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") indexset = platform.backend.optimization.indexsets.create(run.id, "Indexset") platform.backend.optimization.indexsets.add_data( @@ -325,7 +325,7 @@ def test_variable_remove_data(self, platform: ixmp4.Platform): ) assert variable.data == {} - def test_list_variable(self, platform: ixmp4.Platform): + def test_list_variable(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") indexset, indexset_2 = create_indexsets_for_run( platform=platform, run_id=run.id @@ -361,7 +361,7 @@ def test_list_variable(self, platform: ixmp4.Platform): variable_4, ] == platform.backend.optimization.variables.list(run_id=run_2.id) - def test_tabulate_variable(self, platform: ixmp4.Platform): + def test_tabulate_variable(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") indexset, indexset_2 = create_indexsets_for_run( platform=platform, run_id=run.id diff --git a/tests/data/test_region.py b/tests/data/test_region.py index 5823ef40..872c6ef3 100644 --- a/tests/data/test_region.py +++ b/tests/data/test_region.py @@ -11,19 +11,19 @@ class TestDataRegion: filter = FilterIamcDataset() - def test_create_region(self, platform: ixmp4.Platform): + def test_create_region(self, platform: ixmp4.Platform) -> None: region1 = platform.backend.regions.create("Region", "Hierarchy") assert region1.name == "Region" assert region1.hierarchy == "Hierarchy" assert region1.created_at is not None assert region1.created_by == "@unknown" - def test_delete_region(self, platform: ixmp4.Platform): + def test_delete_region(self, platform: ixmp4.Platform) -> None: region1 = platform.backend.regions.create("Region", "Hierarchy") platform.backend.regions.delete(region1.id) assert platform.backend.regions.tabulate().empty - def test_region_unique(self, platform: ixmp4.Platform): + def test_region_unique(self, platform: ixmp4.Platform) -> None: platform.backend.regions.create("Region", "Hierarchy") with pytest.raises(Region.NotUnique): @@ -32,16 +32,16 @@ def test_region_unique(self, platform: ixmp4.Platform): with pytest.raises(Region.NotUnique): platform.regions.create("Region", "Another Hierarchy") - def test_get_region(self, platform: ixmp4.Platform): + def test_get_region(self, platform: ixmp4.Platform) -> None: region1 = platform.backend.regions.create("Region", "Hierarchy") region2 = platform.backend.regions.get("Region") assert region1 == region2 - def test_region_not_found(self, platform: ixmp4.Platform): + def test_region_not_found(self, platform: ixmp4.Platform) -> None: with pytest.raises(Region.NotFound): platform.regions.get("Region") - def test_get_or_create_region(self, platform: ixmp4.Platform): + def test_get_or_create_region(self, platform: ixmp4.Platform) -> None: region1 = platform.backend.regions.create("Region", "Hierarchy") region2 = platform.backend.regions.get_or_create("Region") assert region1.id == region2.id @@ -51,7 +51,7 @@ def test_get_or_create_region(self, platform: ixmp4.Platform): with pytest.raises(Region.NotUnique): platform.backend.regions.get_or_create("Other", hierarchy="Other Hierarchy") - def test_list_region(self, platform: ixmp4.Platform): + def test_list_region(self, platform: ixmp4.Platform) -> None: platform.backend.regions.create("Region 1", "Hierarchy") platform.backend.regions.create("Region 2", "Hierarchy") @@ -63,7 +63,7 @@ def test_list_region(self, platform: ixmp4.Platform): assert regions[1].id == 2 assert regions[1].name == "Region 2" - def test_tabulate_region(self, platform: ixmp4.Platform): + def test_tabulate_region(self, platform: ixmp4.Platform) -> None: platform.backend.regions.create("Region 1", "Hierarchy") platform.backend.regions.create("Region 2", "Hierarchy") @@ -80,7 +80,7 @@ def test_tabulate_region(self, platform: ixmp4.Platform): regions.drop(columns=["created_at", "created_by"]), true_regions ) - def test_filter_region(self, platform: ixmp4.Platform): + def test_filter_region(self, platform: ixmp4.Platform) -> None: run1, run2 = self.filter.load_dataset(platform) res = platform.backend.regions.tabulate( diff --git a/tests/data/test_run.py b/tests/data/test_run.py index 8ae75ccd..e9ba4ebf 100644 --- a/tests/data/test_run.py +++ b/tests/data/test_run.py @@ -8,14 +8,14 @@ class TestDataRun: - def test_create_run(self, platform: ixmp4.Platform): + def test_create_run(self, platform: ixmp4.Platform) -> None: run1 = platform.backend.runs.create("Model", "Scenario") assert run1.model.name == "Model" assert run1.scenario.name == "Scenario" assert run1.version == 1 assert not run1.is_default - def test_create_run_increment_version(self, platform: ixmp4.Platform): + def test_create_run_increment_version(self, platform: ixmp4.Platform) -> None: platform.backend.runs.create("Model", "Scenario") run2 = platform.backend.runs.create("Model", "Scenario") assert run2.model.name == "Model" @@ -23,7 +23,7 @@ def test_create_run_increment_version(self, platform: ixmp4.Platform): assert run2.version == 2 assert not run2.is_default - def test_get_run_versions(self, platform: ixmp4.Platform): + def test_get_run_versions(self, platform: ixmp4.Platform) -> None: run1a = platform.backend.runs.create("Model", "Scenario") run2a = platform.backend.runs.create("Model", "Scenario") platform.backend.runs.set_as_default_version(run2a.id) @@ -40,11 +40,11 @@ def test_get_run_versions(self, platform: ixmp4.Platform): run3b = platform.backend.runs.get("Model", "Scenario", 3) assert run3a.id == run3b.id - def test_get_run_no_default_version(self, platform: ixmp4.Platform): + def test_get_run_no_default_version(self, platform: ixmp4.Platform) -> None: with pytest.raises(NoDefaultRunVersion): platform.backend.runs.get_default_version("Model", "Scenario") - def test_get_or_create_run(self, platform: ixmp4.Platform): + def test_get_or_create_run(self, platform: ixmp4.Platform) -> None: run1 = platform.backend.runs.create("Model", "Scenario") run2 = platform.backend.runs.get_or_create("Model", "Scenario") assert run1.id != run2.id @@ -55,7 +55,7 @@ def test_get_or_create_run(self, platform: ixmp4.Platform): run3 = platform.backend.runs.get_or_create("Model", "Scenario") assert run1.id == run3.id - def test_list_run(self, platform: ixmp4.Platform): + def test_list_run(self, platform: ixmp4.Platform) -> None: run1 = platform.backend.runs.create("Model", "Scenario") platform.backend.runs.create("Model", "Scenario") @@ -72,7 +72,7 @@ def test_list_run(self, platform: ixmp4.Platform): assert run1.id == run.id - def test_tabulate_run(self, platform: ixmp4.Platform): + def test_tabulate_run(self, platform: ixmp4.Platform) -> None: run = platform.backend.runs.create("Model", "Scenario") platform.backend.runs.set_as_default_version(run.id) platform.backend.runs.create("Model", "Scenario") @@ -105,7 +105,7 @@ def test_tabulate_run(self, platform: ixmp4.Platform): runs = platform.backend.runs.tabulate(default_only=False, iamc=True) assert runs.empty - def drop_audit_info(self, df): + def drop_audit_info(self, df: pd.DataFrame) -> None: df.drop( inplace=True, columns=["created_by", "created_at", "updated_by", "updated_at"], diff --git a/tests/data/test_scenario.py b/tests/data/test_scenario.py index ea31547b..225ddd44 100644 --- a/tests/data/test_scenario.py +++ b/tests/data/test_scenario.py @@ -11,28 +11,28 @@ class TestDataScenario: filter = FilterIamcDataset() - def test_create_scenario(self, platform: ixmp4.Platform): + def test_create_scenario(self, platform: ixmp4.Platform) -> None: scenario = platform.backend.scenarios.create("Scenario") assert scenario.name == "Scenario" assert scenario.created_at is not None assert scenario.created_by == "@unknown" - def test_scenario_unique(self, platform: ixmp4.Platform): + def test_scenario_unique(self, platform: ixmp4.Platform) -> None: platform.backend.scenarios.create("Scenario") with pytest.raises(Scenario.NotUnique): platform.scenarios.create("Scenario") - def test_get_scenario(self, platform: ixmp4.Platform): + def test_get_scenario(self, platform: ixmp4.Platform) -> None: scenario1 = platform.backend.scenarios.create("Scenario") scenario2 = platform.backend.scenarios.get("Scenario") assert scenario1 == scenario2 - def test_scenario_not_found(self, platform: ixmp4.Platform): + def test_scenario_not_found(self, platform: ixmp4.Platform) -> None: with pytest.raises(Scenario.NotFound): platform.scenarios.get("Scenario") - def test_list_scenario(self, platform: ixmp4.Platform): + def test_list_scenario(self, platform: ixmp4.Platform) -> None: platform.runs.create("Model", "Scenario 1") platform.runs.create("Model", "Scenario 2") @@ -43,7 +43,7 @@ def test_list_scenario(self, platform: ixmp4.Platform): assert scenarios[1].id == 2 assert scenarios[1].name == "Scenario 2" - def test_tabulate_scenario(self, platform: ixmp4.Platform): + def test_tabulate_scenario(self, platform: ixmp4.Platform) -> None: platform.runs.create("Model", "Scenario 1") platform.runs.create("Model", "Scenario 2") @@ -60,13 +60,13 @@ def test_tabulate_scenario(self, platform: ixmp4.Platform): scenarios.drop(columns=["created_at", "created_by"]), true_scenarios ) - def test_map_scenario(self, platform: ixmp4.Platform): + def test_map_scenario(self, platform: ixmp4.Platform) -> None: platform.runs.create("Model", "Scenario 1") platform.runs.create("Model", "Scenario 2") assert platform.backend.scenarios.map() == {1: "Scenario 1", 2: "Scenario 2"} - def test_filter_scenario(self, platform: ixmp4.Platform): + def test_filter_scenario(self, platform: ixmp4.Platform) -> None: run1, run2 = self.filter.load_dataset(platform) res = platform.backend.scenarios.tabulate( diff --git a/tests/data/test_unit.py b/tests/data/test_unit.py index 9e94e0af..0ef76faa 100644 --- a/tests/data/test_unit.py +++ b/tests/data/test_unit.py @@ -11,19 +11,19 @@ class TestDataUnit: filter = FilterIamcDataset() - def test_create_get_unit(self, platform: ixmp4.Platform): + def test_create_get_unit(self, platform: ixmp4.Platform) -> None: unit1 = platform.backend.units.create("Unit") assert unit1.name == "Unit" unit2 = platform.backend.units.get("Unit") assert unit1.id == unit2.id - def test_delete_unit(self, platform: ixmp4.Platform): + def test_delete_unit(self, platform: ixmp4.Platform) -> None: unit1 = platform.backend.units.create("Unit") platform.backend.units.delete(unit1.id) assert platform.backend.units.tabulate().empty - def test_get_or_create_unit(self, platform: ixmp4.Platform): + def test_get_or_create_unit(self, platform: ixmp4.Platform) -> None: unit1 = platform.backend.units.create("Unit") unit2 = platform.backend.units.get_or_create("Unit") assert unit1.id == unit2.id @@ -32,17 +32,17 @@ def test_get_or_create_unit(self, platform: ixmp4.Platform): assert unit3.name == "Another Unit" assert unit1.id != unit3.id - def test_unit_unique(self, platform: ixmp4.Platform): + def test_unit_unique(self, platform: ixmp4.Platform) -> None: platform.backend.units.create("Unit") with pytest.raises(Unit.NotUnique): platform.backend.units.create("Unit") - def test_unit_not_found(self, platform: ixmp4.Platform): + def test_unit_not_found(self, platform: ixmp4.Platform) -> None: with pytest.raises(Unit.NotFound): platform.backend.units.get("Unit") - def test_list_unit(self, platform: ixmp4.Platform): + def test_list_unit(self, platform: ixmp4.Platform) -> None: platform.backend.units.create("Unit 1") platform.backend.units.create("Unit 2") @@ -54,7 +54,7 @@ def test_list_unit(self, platform: ixmp4.Platform): assert units[1].id == 2 assert units[1].name == "Unit 2" - def test_tabulate_unit(self, platform: ixmp4.Platform): + def test_tabulate_unit(self, platform: ixmp4.Platform) -> None: platform.backend.units.create("Unit 1") platform.backend.units.create("Unit 2") @@ -71,7 +71,7 @@ def test_tabulate_unit(self, platform: ixmp4.Platform): units.drop(columns=["created_at", "created_by"]), true_units ) - def test_filter_unit(self, platform: ixmp4.Platform): + def test_filter_unit(self, platform: ixmp4.Platform) -> None: run1, run2 = self.filter.load_dataset(platform) res = platform.backend.units.tabulate( iamc={ diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index 6b2bc89c..c6c0ecd4 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -17,17 +17,17 @@ class SmallIamcDataset: datetime["step_datetime"] = pd.to_datetime(datetime["step_datetime"]) @classmethod - def load_regions(cls, platform: ixmp4.Platform): + def load_regions(cls, platform: ixmp4.Platform) -> None: for _, name, hierarchy in cls.regions.itertuples(): platform.regions.create(name, hierarchy) @classmethod - def load_units(cls, platform: ixmp4.Platform): + def load_units(cls, platform: ixmp4.Platform) -> None: for _, name in cls.units.itertuples(): platform.units.create(name) @classmethod - def load_dataset(cls, platform: ixmp4.Platform): + def load_dataset(cls, platform: ixmp4.Platform) -> tuple[ixmp4.Run, ixmp4.Run]: cls.load_regions(platform) cls.load_units(platform) @@ -39,11 +39,13 @@ def load_dataset(cls, platform: ixmp4.Platform): datapoints = cls.annual.copy() run1.iamc.add(datapoints, type=ixmp4.DataPoint.Type.ANNUAL) - run1.meta = {"run": 1, "test": 0.1293, "bool": True} + # NOTE mypy doesn't support setters taking a different type than their property + # https://github.com/python/mypy/issues/3004 + run1.meta = {"run": 1, "test": 0.1293, "bool": True} # type: ignore[assignment] datapoints["variable"] = "Variable 4" run2.iamc.add(datapoints, type=ixmp4.DataPoint.Type.ANNUAL) - run2.meta = {"run": 2, "test": "string", "bool": False} + run2.meta = {"run": 2, "test": "string", "bool": False} # type: ignore[assignment] return run1, run2 @@ -53,17 +55,17 @@ class FilterIamcDataset: datapoints = pd.read_csv(here / "filters/datapoints.csv") @classmethod - def load_regions(cls, platform: ixmp4.Platform): + def load_regions(cls, platform: ixmp4.Platform) -> None: for _, name, hierarchy in cls.regions.itertuples(): platform.regions.create(name, hierarchy) @classmethod - def load_units(cls, platform: ixmp4.Platform): + def load_units(cls, platform: ixmp4.Platform) -> None: for _, name in cls.units.itertuples(): platform.units.create(name) @classmethod - def load_dataset(cls, platform: ixmp4.Platform): + def load_dataset(cls, platform: ixmp4.Platform) -> tuple[ixmp4.Run, ixmp4.Run]: cls.load_regions(platform) cls.load_units(platform) @@ -86,17 +88,17 @@ class MediumIamcDataset: run_cols = ["model", "scenario", "version"] @classmethod - def load_regions(cls, platform: ixmp4.Platform): + def load_regions(cls, platform: ixmp4.Platform) -> None: for _, name, hierarchy in cls.regions.itertuples(): platform.regions.create(name, hierarchy) @classmethod - def load_units(cls, platform: ixmp4.Platform): + def load_units(cls, platform: ixmp4.Platform) -> None: for _, name in cls.units.itertuples(): platform.units.create(name) @classmethod - def load_runs(cls, platform: ixmp4.Platform): + def load_runs(cls, platform: ixmp4.Platform) -> None: for _, model, scenario, version, is_default in cls.runs.itertuples(): run = platform.runs.create(model, scenario) if run.version != version: @@ -108,7 +110,7 @@ def load_runs(cls, platform: ixmp4.Platform): @classmethod def load_run_datapoints( cls, platform: ixmp4.Platform, run_tup: tuple[str, str, int], dps: pd.DataFrame - ): + ) -> None: run = platform.runs.get(*run_tup) annual = dps[dps["type"] == "ANNUAL"].dropna(how="all", axis="columns") @@ -124,7 +126,9 @@ def load_run_datapoints( run.iamc.add(datetime, type=ixmp4.DataPoint.Type.DATETIME) @classmethod - def get_run_dps(cls, df: pd.DataFrame, model, scenario, version): + def get_run_dps( + cls, df: pd.DataFrame, model: str, scenario: str, version: int + ) -> pd.DataFrame: dps = df.copy() dps = dps[dps["model"] == model] dps = dps[dps["scenario"] == scenario] @@ -133,7 +137,7 @@ def get_run_dps(cls, df: pd.DataFrame, model, scenario, version): return dps @classmethod - def load_dp_df(cls, platform: ixmp4.Platform, df: pd.DataFrame): + def load_dp_df(cls, platform: ixmp4.Platform, df: pd.DataFrame) -> None: runs = df[cls.run_cols].copy() runs.drop_duplicates(inplace=True) for _, model, scenario, version in runs.itertuples(): @@ -141,11 +145,11 @@ def load_dp_df(cls, platform: ixmp4.Platform, df: pd.DataFrame): cls.load_run_datapoints(platform, (model, scenario, version), dps) @classmethod - def load_datapoints(cls, platform: ixmp4.Platform): + def load_datapoints(cls, platform: ixmp4.Platform) -> None: cls.load_dp_df(platform, cls.datapoints) @classmethod - def load_dataset(cls, platform: ixmp4.Platform): + def load_dataset(cls, platform: ixmp4.Platform) -> None: cls.load_regions(platform) cls.load_units(platform) cls.load_runs(platform) @@ -160,17 +164,17 @@ class BigIamcDataset: run_cols = ["model", "scenario", "version"] @classmethod - def load_regions(cls, platform: ixmp4.Platform): + def load_regions(cls, platform: ixmp4.Platform) -> None: for _, name, hierarchy in cls.regions.itertuples(): platform.regions.create(name, hierarchy) @classmethod - def load_units(cls, platform: ixmp4.Platform): + def load_units(cls, platform: ixmp4.Platform) -> None: for _, name in cls.units.itertuples(): platform.units.create(name) @classmethod - def load_runs(cls, platform: ixmp4.Platform): + def load_runs(cls, platform: ixmp4.Platform) -> None: for _, model, scenario, version, is_default in cls.runs.itertuples(): run = platform.runs.create(model, scenario) if run.version != version: @@ -180,7 +184,7 @@ def load_runs(cls, platform: ixmp4.Platform): run.set_as_default() @classmethod - def load_dp_df(cls, platform: ixmp4.Platform, df: pd.DataFrame): + def load_dp_df(cls, platform: ixmp4.Platform, df: pd.DataFrame) -> None: runs = df[cls.run_cols].copy() runs.drop_duplicates(inplace=True) for _, model, scenario, version in runs.itertuples(): @@ -188,7 +192,9 @@ def load_dp_df(cls, platform: ixmp4.Platform, df: pd.DataFrame): cls.load_run_datapoints(platform, (model, scenario, version), dps) @classmethod - def get_run_dps(cls, df: pd.DataFrame, model, scenario, version): + def get_run_dps( + cls, df: pd.DataFrame, model: str, scenario: str, version: int + ) -> pd.DataFrame: dps = df.copy() dps = dps[dps["model"] == model] dps = dps[dps["scenario"] == scenario] @@ -197,7 +203,7 @@ def get_run_dps(cls, df: pd.DataFrame, model, scenario, version): return dps @classmethod - def rm_dp_df(cls, platform: ixmp4.Platform, df: pd.DataFrame): + def rm_dp_df(cls, platform: ixmp4.Platform, df: pd.DataFrame) -> None: runs = df[cls.run_cols].copy() runs.drop_duplicates(inplace=True) for _, model, scenario, version in runs.itertuples(): @@ -205,11 +211,11 @@ def rm_dp_df(cls, platform: ixmp4.Platform, df: pd.DataFrame): cls.rm_run_datapoints(platform, (model, scenario, version), dps) @classmethod - def load_datapoints(cls, platform: ixmp4.Platform): + def load_datapoints(cls, platform: ixmp4.Platform) -> None: cls.load_dp_df(platform, cls.datapoints) @classmethod - def load_datapoints_half(cls, platform: ixmp4.Platform): + def load_datapoints_half(cls, platform: ixmp4.Platform) -> None: scrambled_dps = cls.datapoints.sample(frac=1) half_dps = scrambled_dps.head(len(scrambled_dps) // 2) half_dps = half_dps.sort_values(by=cls.run_cols) @@ -218,7 +224,7 @@ def load_datapoints_half(cls, platform: ixmp4.Platform): @classmethod def load_run_datapoints( cls, platform: ixmp4.Platform, run_tup: tuple[str, str, int], dps: pd.DataFrame - ): + ) -> None: run = platform.runs.get(*run_tup) annual = dps[dps["type"] == "ANNUAL"].dropna(how="all", axis="columns") @@ -233,7 +239,7 @@ def load_run_datapoints( @classmethod def rm_run_datapoints( cls, platform: ixmp4.Platform, run_tup: tuple[str, str, int], dps: pd.DataFrame - ): + ) -> None: run = platform.runs.get(*run_tup) annual = dps[dps["type"] == "ANNUAL"].dropna(how="all", axis="columns") @@ -246,7 +252,7 @@ def rm_run_datapoints( run.iamc.remove(datetime, type=ixmp4.DataPoint.Type.DATETIME) @classmethod - def load_dataset(cls, platform: ixmp4.Platform): + def load_dataset(cls, platform: ixmp4.Platform) -> None: cls.load_regions(platform) cls.load_units(platform) cls.load_runs(platform) diff --git a/tests/test_api.py b/tests/test_api.py index 0489e7c7..e0e76087 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -12,7 +12,7 @@ class TestApi: small = SmallIamcDataset() medium = MediumIamcDataset() - def assert_res(self, res: httpx.Response, is_success=True): + def assert_res(self, res: httpx.Response, is_success: bool = True) -> None: assert res.is_success == is_success def assert_table_res( @@ -20,8 +20,8 @@ def assert_table_res( res: httpx.Response, no_of_rows: int | None = None, has_columns: list[str] | None = None, - has_data_for_columns: dict[str, list] | None = None, - ): + has_data_for_columns: dict[str, list[str] | list[int]] | None = None, + ) -> None: self.assert_res(res) page = res.json() table = page["results"] @@ -50,36 +50,47 @@ def assert_table_res( def assert_paginated_res( self, - client, - endpoint, - filters: dict | None = None, + client: httpx.Client, + endpoint: str, + filters: dict[str, dict[str, bool]] | None = None, no_of_rows: int | None = None, - ): - total, offset, limit = None, None, None + ) -> None: + total: int | None = None + offset: int | None = None + limit: int | None = None ret_no_of_rows = 0 - while offset is None or offset + limit < total: + while ( + offset is None or limit is None or total is None + ) or offset + limit < total: url = endpoint + "?table=true" - if offset is not None: + if offset is not None and limit is not None: offset += limit url += f"&offset={offset}&limit={limit}" res = client.patch(url, json=filters) self.assert_res(res) page = res.json() pagination = page.pop("pagination") - offset, limit = pagination["offset"], pagination["limit"] - total = page.pop("total") + offset, limit = ( + cast(int, pagination["offset"]), + cast(int, pagination["limit"]), + ) + total = cast(int, page.pop("total")) table = page["results"] data = table["data"] page_no_of_rows = len(data) ret_no_of_rows += page_no_of_rows - num_expected = min(total - offset, limit) + num_expected = ( + 0 + if offset is None or limit is None or total is None + else min(total - offset, limit) + ) assert page_no_of_rows == num_expected if no_of_rows is not None: assert no_of_rows == ret_no_of_rows - def test_index_meta(self, rest_platform: ixmp4.Platform): + def test_index_meta(self, rest_platform: ixmp4.Platform) -> None: self.small.load_dataset(rest_platform) backend = cast(RestBackend, rest_platform.backend) @@ -87,7 +98,7 @@ def test_index_meta(self, rest_platform: ixmp4.Platform): self.assert_table_res( res, no_of_rows=6, - has_columns=["run__id", "key", "type"], + has_columns=["run__id", "key", "dtype"], ) res = backend.client.patch("meta/?table=true&join_run_index=true") @@ -100,7 +111,7 @@ def test_index_meta(self, rest_platform: ixmp4.Platform): res = backend.client.patch("meta/?table=false&join_run_index=true") assert res.status_code == 400 - def test_index_model(self, rest_platform: ixmp4.Platform): + def test_index_model(self, rest_platform: ixmp4.Platform) -> None: self.small.load_dataset(rest_platform) backend = cast(RestBackend, rest_platform.backend) table_endpoint = "iamc/models/?table=True" @@ -134,7 +145,7 @@ def test_index_model(self, rest_platform: ixmp4.Platform): has_data_for_columns={"name": ["Model 1"]}, ) - def test_index_scenario(self, rest_platform: ixmp4.Platform): + def test_index_scenario(self, rest_platform: ixmp4.Platform) -> None: self.small.load_dataset(rest_platform) backend = cast(RestBackend, rest_platform.backend) table_endpoint = "iamc/scenarios/?table=True" @@ -170,7 +181,7 @@ def test_index_scenario(self, rest_platform: ixmp4.Platform): has_data_for_columns={"name": ["Scenario 1", "Scenario 2"]}, ) - def test_index_region(self, rest_platform: ixmp4.Platform): + def test_index_region(self, rest_platform: ixmp4.Platform) -> None: self.small.load_dataset(rest_platform) backend = cast(RestBackend, rest_platform.backend) table_endpoint = "iamc/regions/?table=True" @@ -209,7 +220,7 @@ def test_index_region(self, rest_platform: ixmp4.Platform): has_data_for_columns={"id": [1]}, ) - def test_index_unit(self, rest_platform: ixmp4.Platform): + def test_index_unit(self, rest_platform: ixmp4.Platform) -> None: self.small.load_dataset(rest_platform) backend = cast(RestBackend, rest_platform.backend) table_endpoint = "iamc/units/?table=True" @@ -250,7 +261,7 @@ def test_index_unit(self, rest_platform: ixmp4.Platform): has_data_for_columns={"id": [1, 2]}, ) - def test_paginate_datapoints(self, rest_platform_med: ixmp4.Platform): + def test_paginate_datapoints(self, rest_platform_med: ixmp4.Platform) -> None: client = cast(RestBackend, rest_platform_med.backend).client endpoint = "iamc/datapoints/" filters = {"run": {"default_only": False}} diff --git a/tests/test_auth.py b/tests/test_auth.py index e6f6fb6c..b0b17dc2 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -142,7 +142,12 @@ class TestAuthContext: ), ], ) - def test_guards(self, sqlite_platform: ixmp4.Platform, user, truths): + def test_guards( + self, + sqlite_platform: ixmp4.Platform, + user: User, + truths: dict[str, dict[str, bool]], + ) -> None: mp = sqlite_platform backend = cast(SqlAlchemyBackend, mp.backend) self.small.load_dataset(mp) @@ -226,9 +231,11 @@ def test_guards(self, sqlite_platform: ixmp4.Platform, user, truths): ) with pytest.raises(Forbidden): - run.meta = {"meta": "test"} + # NOTE mypy doesn't support setters taking a different type than + # their property https://github.com/python/mypy/issues/3004 + run.meta = {"meta": "test"} # type: ignore[assignment] - def test_run_audit_info(self, db_platform: ixmp4.Platform): + def test_run_audit_info(self, db_platform: ixmp4.Platform) -> None: backend = cast(SqlAlchemyBackend, db_platform.backend) test_user = User(username="test_audit", is_verified=True, is_superuser=True) @@ -273,10 +280,10 @@ def test_run_audit_info(self, db_platform: ixmp4.Platform): def test_filters( self, db_platform: ixmp4.Platform, - model, - platform_info, - access, - ): + model: str, + platform_info: ManagerPlatformInfo, + access: str | None, + ) -> None: mp = db_platform backend = cast(SqlAlchemyBackend, mp.backend) user = User(username="User Carina", is_verified=True, groups=[6, 7]) @@ -287,7 +294,7 @@ def test_filters( run = mp.runs.create(model, "Scenario") annual_dps = self.small.annual.copy() run.iamc.add(annual_dps, type=ixmp4.DataPoint.Type.ANNUAL) - run.meta = {"meta": "test"} + run.meta = {"meta": "test"} # type: ignore[assignment] run.set_as_default() with backend.auth(user, self.mock_manager, platform_info): @@ -303,7 +310,7 @@ def test_filters( annual_dps.drop(columns=["value"]), type=ixmp4.DataPoint.Type.ANNUAL, ) - run.meta = {"meta": "test"} + run.meta = {"meta": "test"} # type: ignore[assignment] else: with pytest.raises(Forbidden): @@ -319,7 +326,7 @@ def test_filters( ) with pytest.raises(Forbidden): - run.meta = {"meta": "test"} + run.meta = {"meta": "test"} # type: ignore[assignment] else: with pytest.raises((ixmp4.Run.NotFound, ixmp4.Run.NoDefaultVersion)): mp.runs.get(model, "Scenario") @@ -330,7 +337,7 @@ def test_filters( assert mp.scenarios.tabulate().empty -def test_invalid_credentials(): +def test_invalid_credentials() -> None: # TODO: Use testing instance once available. # Using dev for now to reduce load on production environment. # @wronguser cannot exist ("@" is not allowed) and will therefore always be invalid. diff --git a/tests/test_benchmark_filters.py b/tests/test_benchmark_filters.py index 4e1d428e..e928b567 100644 --- a/tests/test_benchmark_filters.py +++ b/tests/test_benchmark_filters.py @@ -1,7 +1,11 @@ +from typing import Any + +import pandas as pd import pytest import ixmp4 +from .conftest import Profiled from .fixtures import BigIamcDataset big = BigIamcDataset() @@ -34,18 +38,24 @@ ], ) def test_filter_datapoints_benchmark( - platform: ixmp4.Platform, profiled, benchmark, filters -): - """Benchmarks a the filtration of `test_data_big`.""" + platform: ixmp4.Platform, + profiled: Profiled, + # NOTE can be specified once https://github.com/ionelmc/pytest-benchmark/issues/212 + # is closed + benchmark: Any, + filters: dict[str, dict[str, bool | str | list[str]]], +) -> None: + """Benchmarks the filtration of `test_data_big`.""" big.load_regions(platform) big.load_units(platform) big.load_runs(platform) big.load_datapoints(platform) - def run(): + def run() -> pd.DataFrame: with profiled(): - return platform.iamc.tabulate(**filters) + # Not sure why mypy complains here, maybe about covariance? + return platform.iamc.tabulate(**filters) # type: ignore[arg-type] df = benchmark.pedantic(run, warmup_rounds=5, rounds=10) assert not df.empty diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index 3b23ee47..0bda6327 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -1,8 +1,13 @@ """This module only contains benchmarks, no assertions are made to validate the results.""" +from typing import Any + +import pandas as pd + import ixmp4 +from .conftest import Profiled from .fixtures import BigIamcDataset @@ -10,47 +15,47 @@ class TestBenchmarks: big = BigIamcDataset() def test_add_datapoints_full_benchmark( - self, platform: ixmp4.Platform, profiled, benchmark - ): + self, platform: ixmp4.Platform, profiled: Profiled, benchmark: Any + ) -> None: """Benchmarks a full insert of `test_data_big`.""" - def setup(): + def setup() -> tuple[tuple[ixmp4.Platform], dict[str, object]]: self.big.load_regions(platform) self.big.load_units(platform) self.big.load_runs(platform) return (platform,), {} - def run(mp): + def run(mp: ixmp4.Platform) -> None: with profiled(): self.big.load_datapoints(mp) benchmark.pedantic(run, setup=setup) def test_add_datapoints_half_unchanged_benchmark( - self, platform: ixmp4.Platform, profiled, benchmark - ): + self, platform: ixmp4.Platform, profiled: Profiled, benchmark: Any + ) -> None: """Benchmarks a full insert of `test_data_big` on a half-filled database.""" - def setup(): + def setup() -> tuple[tuple[ixmp4.Platform], dict[str, object]]: self.big.load_regions(platform) self.big.load_units(platform) self.big.load_runs(platform) self.big.load_datapoints_half(platform) return (platform,), {} - def run(mp): + def run(mp: ixmp4.Platform) -> None: with profiled(): self.big.load_datapoints(mp) benchmark.pedantic(run, setup=setup) def test_add_datapoints_half_insert_half_update_benchmark( - self, platform: ixmp4.Platform, profiled, benchmark - ): + self, platform: ixmp4.Platform, profiled: Profiled, benchmark: Any + ) -> None: """Benchmarks a full insert of `test_data_big` with changed values on a half-filled database.""" - def setup(): + def setup() -> tuple[tuple[ixmp4.Platform, pd.DataFrame], dict[str, object]]: self.big.load_regions(platform) self.big.load_units(platform) self.big.load_runs(platform) @@ -59,7 +64,7 @@ def setup(): datapoints["value"] = -9999 return (platform, datapoints), {} - def run(mp, data): + def run(mp: ixmp4.Platform, data: pd.DataFrame) -> None: with profiled(): self.big.load_dp_df(mp, data) @@ -69,11 +74,11 @@ def run(mp, data): assert ret["value"].unique() == [-9999] def test_remove_datapoints_benchmark( - self, platform: ixmp4.Platform, profiled, benchmark - ): + self, platform: ixmp4.Platform, profiled: Profiled, benchmark: Any + ) -> None: """Benchmarks a full removal of `test_data_big` from a filled database.""" - def setup(): + def setup() -> tuple[tuple[ixmp4.Platform, pd.DataFrame], dict[str, object]]: self.big.load_regions(platform) self.big.load_units(platform) self.big.load_runs(platform) @@ -81,7 +86,7 @@ def setup(): data = self.big.datapoints.copy().drop(columns=["value"]) return (platform, data), {} - def run(mp, data): + def run(mp: ixmp4.Platform, data: pd.DataFrame) -> None: with profiled(): self.big.rm_dp_df(mp, data) @@ -90,18 +95,18 @@ def run(mp, data): assert ret.empty def test_tabulate_datapoints_benchmark( - self, platform: ixmp4.Platform, profiled, benchmark - ): + self, platform: ixmp4.Platform, profiled: Profiled, benchmark: Any + ) -> None: """Benchmarks a full retrieval of `test_data_big` from a filled database.""" - def setup(): + def setup() -> tuple[tuple[ixmp4.Platform], dict[str, object]]: self.big.load_regions(platform) self.big.load_units(platform) self.big.load_runs(platform) self.big.load_datapoints(platform) return (platform,), {} - def run(mp): + def run(mp: ixmp4.Platform) -> None: with profiled(): mp.iamc.tabulate(run={"default_only": False}) diff --git a/tests/utils.py b/tests/utils.py index a35e7804..4e47c8cc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,10 +1,22 @@ +import pandas as pd import pandas.testing as pdt +# Import this from typing when dropping 3.11 +from typing_extensions import TypedDict, Unpack + from ixmp4 import Platform from ixmp4.data.abstract.optimization import IndexSet -def assert_unordered_equality(df1, df2, **kwargs): +# Based on current usage +class AssertKwargs(TypedDict, total=False): + check_like: bool + check_dtype: bool + + +def assert_unordered_equality( + df1: pd.DataFrame, df2: pd.DataFrame, **kwargs: Unpack[AssertKwargs] +) -> None: df1 = df1.sort_index(axis=1) df1 = df1.sort_values(by=list(df1.columns)).reset_index(drop=True) df2 = df2.sort_index(axis=1) @@ -19,6 +31,6 @@ def create_indexsets_for_run( return tuple( platform.backend.optimization.indexsets.create( run_id=run_id, name=f"Indexset {i}" - ) # type: ignore + ) for i in range(offset, offset + amount) )