Skip to content

Commit

Permalink
Enable professional-grade mypy (#129)
Browse files Browse the repository at this point in the history
* Complete type hints
* Bump mypy to latest release version
* Enable mypy for docs and tests
* Note reasons for type: ignore for future reference
* DRY enumerate kwargs
* DRY more type hints using TypeAlias
* Exclude if TYPE_CHECKING from coverage
* Add mypy badge to README
* Make scalar.unit non-optional
* Bump pandas version in poetry.lock
  • Loading branch information
glatterf42 authored Dec 19, 2024
1 parent 362d503 commit e252b7c
Show file tree
Hide file tree
Showing 238 changed files with 3,915 additions and 2,289 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.12.0
rev: v1.13.0
hooks:
- id: mypy
entry: bash -c "poetry run mypy ."
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Copyright (c) 2023-2024 IIASA - Energy, Climate, and Environment Program (ECE)
[![license: MIT](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://github.com/iiasa/ixmp4/blob/main/LICENSE)
[![python](https://img.shields.io/badge/python-3.10_|_3.11_|_3.12_|_3.13-blue?logo=python&logoColor=white)](https://github.com/iiasa/ixmp4)
[![Code style: ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
[![Checked with mypy](https://www.mypy-lang.org/static/mypy_badge.svg)](https://mypy-lang.org/)

## Overview

Expand Down
2 changes: 1 addition & 1 deletion doc/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
exclude_patterns: list[str] = []


# -- Options for HTML output -------------------------------------------------
Expand Down
6 changes: 3 additions & 3 deletions ixmp4/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def login(
prompt=True,
hide_input=True,
),
):
) -> None:
try:
auth = ManagerAuth(username, password, str(settings.manager_url))
user = auth.get_user()
Expand All @@ -62,7 +62,7 @@ def login(


@app.command()
def logout():
def logout() -> None:
if typer.confirm(
"Are you sure you want to log out and delete locally saved credentials?"
):
Expand All @@ -80,7 +80,7 @@ def test(
with_backend: Optional[bool] = False,
with_benchmarks: Optional[bool] = False,
dry: Optional[bool] = False,
):
) -> None:
opts = [
"--cov-report",
"xml:.coverage.xml",
Expand Down
39 changes: 23 additions & 16 deletions ixmp4/cli/platforms.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import re
from collections.abc import Generator, Iterator
from itertools import cycle
from pathlib import Path
from typing import Generator, Optional
from typing import Any, Optional, TypeVar

import typer
from rich.progress import Progress, track
Expand All @@ -21,15 +22,15 @@
app = typer.Typer()


def validate_name(name: str):
def validate_name(name: str) -> str:
match = re.match(r"^[\w\-_]*$", name)
if match is None:
raise typer.BadParameter("Platform name must be slug-like.")
else:
return name


def validate_dsn(dsn: str | None):
def validate_dsn(dsn: str | None) -> str | None:
if dsn is None:
return None
match = re.match(r"^(sqlite|postgresql\+psycopg|https|http)(\:\/\/)", dsn)
Expand All @@ -41,7 +42,7 @@ def validate_dsn(dsn: str | None):
return dsn


def prompt_sqlite_dsn(name: str):
def prompt_sqlite_dsn(name: str) -> str:
path = sqlite.get_database_path(name)
dsn = sqlite.get_dsn(path)
if path.exists():
Expand Down Expand Up @@ -75,7 +76,7 @@ def add(
help="Data source name. Can be a http(s) URL or a database connection string.",
callback=validate_dsn,
),
):
) -> None:
try:
settings.toml.get_platform(name)
raise typer.BadParameter(
Expand All @@ -95,11 +96,11 @@ def add(
utils.good("\nPlatform added successfully.")


def prompt_sqlite_removal(dsn: str):
def prompt_sqlite_removal(dsn: str) -> None:
path = Path(dsn.replace("sqlite://", ""))
path_str = typer.style(path, fg=typer.colors.CYAN)
if typer.confirm(
"Do you want to remove the associated database file at " f"{path_str} as well?" # type: ignore
"Do you want to remove the associated database file at " f"{path_str} as well?"
):
path.unlink()
utils.echo("\nDatabase file deleted.")
Expand All @@ -112,7 +113,7 @@ def remove(
name: str = typer.Argument(
..., help="The string identifier of the platform to remove."
),
):
) -> None:
try:
platform = settings.toml.get_platform(name)
except PlatformNotFound:
Expand All @@ -127,7 +128,7 @@ def remove(
settings.toml.remove_platform(name)


def tabulate_toml_platforms(platforms: list[TomlPlatformInfo]):
def tabulate_toml_platforms(platforms: list[TomlPlatformInfo]) -> None:
toml_path_str = typer.style(settings.toml.path, fg=typer.colors.CYAN)
utils.echo(f"\nPlatforms registered in '{toml_path_str}'")
if len(platforms):
Expand All @@ -140,7 +141,7 @@ def tabulate_toml_platforms(platforms: list[TomlPlatformInfo]):

def tabulate_manager_platforms(
platforms: list[ManagerPlatformInfo],
):
) -> None:
manager_url_str = typer.style(settings.manager.url, fg=typer.colors.CYAN)
utils.echo(f"\nPlatforms accessible via '{manager_url_str}'")
utils.echo("\nName".ljust(21) + "Access".ljust(10) + "Notice")
Expand All @@ -154,7 +155,7 @@ def tabulate_manager_platforms(


@app.command("list", help="Lists all registered platforms.")
def list_():
def list_() -> None:
tabulate_toml_platforms(settings.toml.list_platforms())
if settings.manager is not None:
tabulate_manager_platforms(settings.manager.list_platforms())
Expand All @@ -166,7 +167,8 @@ def list_():
"revision."
)
)
def upgrade():
def upgrade() -> None:
platform_list: list[ManagerPlatformInfo] | list[TomlPlatformInfo]
if settings.managed:
utils.echo(
f"Establishing self-signed admin connection to '{settings.manager_url}'."
Expand Down Expand Up @@ -232,7 +234,7 @@ def generate(
num_datapoints: int = typer.Option(
30_000, "--datapoints", help="Number of mock datapoints to generate."
),
):
) -> None:
try:
platform = Platform(platform_name)
except PlatformNotFound:
Expand Down Expand Up @@ -270,7 +272,12 @@ def generate(
utils.good("Done!")


def create_cycle(generator: Generator, name: str, total: int):
T = TypeVar("T")


def create_cycle(
generator: Generator[T, Any, None], name: str, total: int
) -> Iterator[T]:
return cycle(
[
m
Expand All @@ -283,7 +290,7 @@ def create_cycle(generator: Generator, name: str, total: int):
)


def generate_data(generator: MockDataGenerator):
def generate_data(generator: MockDataGenerator) -> None:
model_names = create_cycle(
generator.yield_model_names(), "Model", generator.num_models
)
Expand All @@ -301,7 +308,7 @@ def generate_data(generator: MockDataGenerator):
progress.advance(task, len(df))


def _shorten(value: str, length: int):
def _shorten(value: str, length: int) -> str:
"""Shorten and adjust a string to a given length adding `...` if necessary"""
if len(value) > length - 4:
value = value[: length - 4] + "..."
Expand Down
6 changes: 4 additions & 2 deletions ixmp4/cli/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional

import typer
import uvicorn # type: ignore[import]
import uvicorn
from fastapi.openapi.utils import get_openapi

from ixmp4.conf import settings
Expand Down Expand Up @@ -33,7 +33,9 @@ def start(


@app.command()
def dump_schema(output_file: Optional[typer.FileTextWrite] = typer.Option(None, "-o")):
def dump_schema(
output_file: Optional[typer.FileTextWrite] = typer.Option(None, "-o"),
) -> None:
schema = get_openapi(
title=v1.title,
version=v1.version,
Expand Down
4 changes: 1 addition & 3 deletions ixmp4/conf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,4 @@
from ixmp4.conf.settings import Settings

load_dotenv()
# strict typechecking fails due to a bug
# https://docs.pydantic.dev/visual_studio_code/#adding-a-default-with-field
settings = Settings() # type: ignore
settings = Settings()
38 changes: 22 additions & 16 deletions ixmp4/conf/auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
from collections.abc import Generator
from datetime import datetime, timedelta
from typing import Any, cast
from uuid import uuid4

import httpx
Expand All @@ -13,10 +15,11 @@


class BaseAuth(object):
def __call__(self, *args, **kwargs):
# This should never be called
def __call__(self, *args: Any, **kwargs: Any) -> httpx.Request:
raise NotImplementedError

def auth_flow(self, request):
def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, Any, None]:
yield self(request)

def get_user(self) -> User:
Expand All @@ -39,7 +42,7 @@ def __init__(self, secret: str, username: str = "ixmp4"):
)
self.token = self.get_local_jwt()

def __call__(self, r):
def __call__(self, r: httpx.Request) -> httpx.Request:
try:
jwt.decode(self.token, self.secret, algorithms=["HS256"])
except (jwt.InvalidTokenError, jwt.ExpiredSignatureError):
Expand All @@ -48,7 +51,7 @@ def __call__(self, r):
r.headers["Authorization"] = "Bearer " + self.token
return r

def get_local_jwt(self):
def get_local_jwt(self) -> str:
self.jti = uuid4().hex
return jwt.encode(
{
Expand All @@ -63,7 +66,7 @@ def get_local_jwt(self):
algorithm="HS256",
)

def get_expiration_timestamp(self):
def get_expiration_timestamp(self) -> int:
return int((datetime.now() + timedelta(minutes=15)).timestamp())

def get_user(self) -> User:
Expand All @@ -72,11 +75,11 @@ def get_user(self) -> User:


class AnonymousAuth(BaseAuth, httpx.Auth):
def __init__(self):
def __init__(self) -> None:
self.user = anonymous_user
logger.info("Connecting to service anonymously and without credentials.")

def __call__(self, r):
def __call__(self, r: httpx.Request) -> httpx.Request:
return r

def get_user(self) -> User:
Expand All @@ -97,7 +100,7 @@ def __init__(
self.password = password
self.obtain_jwt()

def __call__(self, r):
def __call__(self, r: httpx.Request) -> httpx.Request:
try:
jwt.decode(
self.access_token,
Expand All @@ -109,7 +112,7 @@ def __call__(self, r):
r.headers["Authorization"] = "Bearer " + self.access_token
return r

def obtain_jwt(self):
def obtain_jwt(self) -> None:
res = self.client.post(
"/token/obtain/",
json={
Expand All @@ -133,7 +136,7 @@ def obtain_jwt(self):
self.set_user(self.access_token)
self.refresh_token = json["refresh"]

def refresh_or_reobtain_jwt(self):
def refresh_or_reobtain_jwt(self) -> None:
try:
jwt.decode(
self.refresh_token,
Expand All @@ -143,7 +146,7 @@ def refresh_or_reobtain_jwt(self):
except jwt.ExpiredSignatureError:
self.obtain_jwt()

def refresh_jwt(self):
def refresh_jwt(self) -> None:
res = self.client.post(
"/token/refresh/",
json={
Expand All @@ -157,13 +160,16 @@ def refresh_jwt(self):
self.access_token = res.json()["access"]
self.set_user(self.access_token)

def decode_token(self, token: str):
return jwt.decode(
token,
options={"verify_signature": False, "verify_exp": False},
def decode_token(self, token: str) -> dict[str, Any]:
return cast(
dict[str, Any],
jwt.decode(
token,
options={"verify_signature": False, "verify_exp": False},
),
)

def set_user(self, token: str):
def set_user(self, token: str) -> None:
token_dict = self.decode_token(token)
user_dict = token_dict["user"]
self.user = User(**user_dict, jti=token_dict.get("jti"))
Expand Down
10 changes: 5 additions & 5 deletions ixmp4/conf/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,31 @@


class Credentials(object):
credentials: dict
credentials: dict[str, dict[str, str]]

def __init__(self, toml_file: Path) -> None:
self.path = toml_file
self.load()

def load(self):
def load(self) -> None:
self.credentials = toml.load(self.path)

def dump(self):
def dump(self) -> None:
f = self.path.open("w+")
toml.dump(self.credentials, f)

def get(self, key: str) -> tuple[str, str]:
c = self.credentials[key]
return (c["username"], c["password"])

def set(self, key: str, username: str, password: str):
def set(self, key: str, username: str, password: str) -> None:
self.credentials[key] = {
"username": username,
"password": password,
}
self.dump()

def clear(self, key: str):
def clear(self, key: str) -> None:
with suppress(KeyError):
del self.credentials[key]
self.dump()
Loading

0 comments on commit e252b7c

Please sign in to comment.