Skip to content

Commit

Permalink
Merge pull request #1030 from DSD-DBS/feat-enable-mypy-in-pre-commit
Browse files Browse the repository at this point in the history
feat: Enable mypy in pre-commit
  • Loading branch information
MoritzWeber0 authored Oct 9, 2023
2 parents 2cc8459 + 6583080 commit 89c50bb
Show file tree
Hide file tree
Showing 37 changed files with 302 additions and 114 deletions.
12 changes: 12 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ repos:
- id: isort
entry: bash -c "cd backend && isort ."
types: [python]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.5.1
hooks:
- id: mypy
types_or: [python, spec]
files: "^backend/capellacollab"
exclude: "^backend/capellacollab/alembic/"
args: [--config-file=./backend/pyproject.toml]
additional_dependencies:
- fastapi
- pydantic
- sqlalchemy
- repo: local
hooks:
- id: pylint
Expand Down
4 changes: 4 additions & 0 deletions backend/capellacollab/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from capellacollab.core import logging as core_logging
from capellacollab.core.database import engine, migration
from capellacollab.core.logging import exceptions as logging_exceptions
from capellacollab.projects.toolmodels import (
exceptions as toolmodels_exceptions,
)
from capellacollab.projects.toolmodels.backups import (
exceptions as backups_exceptions,
)
Expand Down Expand Up @@ -133,6 +136,7 @@ async def healthcheck():

def register_exceptions():
tools_exceptions.register_exceptions(app)
toolmodels_exceptions.register_exceptions(app)
git_exceptions.register_exceptions(app)
gitlab_exceptions.register_exceptions(app)
git_handler_exceptions.register_exceptions(app)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: Copyright DB Netz AG and the capella-collab-manager contributors
# SPDX-License-Identifier: Apache-2.0

"""Make tool name required
Revision ID: 1a4208c18909
Revises: d8cf851562cd
Create Date: 2023-09-19 11:25:16.343948
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "1a4208c18909"
down_revision = "d8cf851562cd"
branch_labels = None
depends_on = None


def upgrade():
op.alter_column(
"tools",
"name",
existing_type=sa.String(),
nullable=False,
)
2 changes: 1 addition & 1 deletion backend/capellacollab/core/authentication/jwt_bearer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class JWTBearer(security.HTTPBearer):
def __init__(self, auto_error: bool = True):
super().__init__(auto_error=auto_error)

async def __call__(
async def __call__( # type: ignore
self, request: fastapi.Request
) -> dict[str, t.Any] | None:
credentials: security.HTTPAuthorizationCredentials | None = (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: Copyright DB Netz AG and the capella-collab-manager contributors
# SPDX-License-Identifier: Apache-2.0

import typing as t

from capellacollab.config import config

Expand All @@ -13,8 +14,8 @@
cfg = config["authentication"]["azure"]


def get_jwk_cfg(token: str) -> dict[str, any]:
def get_jwk_cfg(token: str) -> dict[str, t.Any]:
return {
"audience": cfg["client"]["id"],
"key": KeyStore.key_for_token(token).dict(),
"key": KeyStore.key_for_token(token).model_dump(),
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,8 @@
from jose import jwt

from capellacollab.config import config
from capellacollab.core.authentication.provider.models import (
InvalidTokenError,
JSONWebKeySet,
KeyIDNotFoundError,
)

from .. import models as provider_models

log = logging.getLogger(__name__)
cfg = config["authentication"]["azure"]
Expand All @@ -38,7 +35,7 @@ def __init__(

self.jwks_uri = jwks_uri
self.algorithms = algorithms
self.public_keys: dict[t.Any, t.Any] = {}
self.public_keys: dict[str, provider_models.JSONWebKey] = {}
self.key_refresh_interval = key_refresh_interval
self.public_keys_last_refreshed: float = 0
self.refresh_keys()
Expand All @@ -56,15 +53,15 @@ def refresh_keys(self) -> None:
except Exception:
log.error("Could not retrieve JWKS data from %s", self.jwks_uri)
return
jwks = JSONWebKeySet.parse_raw(resp.text)
jwks = provider_models.JSONWebKeySet.parse_raw(resp.text)
self.public_keys_last_refreshed = time.time()
self.public_keys.clear()
for key in jwks.keys:
self.public_keys[key.kid] = key

def key_for_token(
self, token: str, *, in_retry: int = 0
) -> dict[str, t.Any]:
) -> provider_models.JSONWebKey:
# Before we do anything, the validation keys may need to be refreshed.
# If so, refresh them.
if self.keys_need_refresh():
Expand All @@ -75,7 +72,9 @@ def key_for_token(
try:
unverified_claims = jwt.get_unverified_header(token)
except Exception:
raise InvalidTokenError("Unable to parse key ID from token")
raise provider_models.InvalidTokenError(
"Unable to parse key ID from token"
)

# See if we have the key identified by this key ID.
try:
Expand All @@ -85,7 +84,7 @@ def key_for_token(
# haven't refreshed keys yet), then try to refresh the keys and try
# again.
if in_retry:
raise KeyIDNotFoundError()
raise provider_models.KeyIDNotFoundError()
self.refresh_keys()
return self.key_for_token(token, in_retry=1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ async def api_get_token(
)
access_token = token["id_token"]

username = get_username(JWTBearer().validate_token(access_token))
validated_token = JWTBearer().validate_token(access_token)
assert validated_token

username = get_username(validated_token)

if user := users_crud.get_user_by_name(db, username):
users_crud.update_last_login(db, user)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ def get_jwk_cfg(token: str) -> dict[str, t.Any]:
return {
"algorithms": ["RS256"],
"audience": cfg["audience"] or cfg["client"]["id"],
"key": KeyStore.key_for_token(token).dict(),
"key": KeyStore.key_for_token(token).model_dump(),
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from capellacollab.config import config
from capellacollab.core.authentication.provider import models

from .. import models as provider_models

log = logging.getLogger(__name__)
cfg = config["authentication"]["oauth"]

Expand All @@ -35,7 +37,7 @@ def __init__(
self.get_jwks_uri = get_jwks_uri
self.jwks_uri = ""
self.algorithms = algorithms
self.public_keys: dict[t.Any, t.Any] = {}
self.public_keys: dict[str, provider_models.JSONWebKey] = {}
self.key_refresh_interval = key_refresh_interval
self.public_keys_last_refreshed: float = 0

Expand All @@ -62,7 +64,7 @@ def refresh_keys(self) -> None:

def key_for_token(
self, token: str, *, in_retry: int = 0
) -> dict[str, t.Any]:
) -> provider_models.JSONWebKey:
# Before we do anything, the validation keys may need to be refreshed.
# If so, refresh them.
if self.keys_need_refresh():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ async def api_get_token(
body: TokenRequest, db: orm.Session = fastapi.Depends(database.get_db)
):
token = get_token(body.code)
access_token = token["id_token"]

username = get_username(JWTBearer().validate_token(token["access_token"]))
validated_token = JWTBearer().validate_token(access_token)
assert validated_token

username = get_username(validated_token)

if user := users_crud.get_user_by_name(db, username):
users_crud.update_last_login(db, user)
Expand Down
3 changes: 2 additions & 1 deletion backend/capellacollab/core/database/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: Copyright DB Netz AG and the capella-collab-manager contributors
# SPDX-License-Identifier: Apache-2.0

import typing as t

import pydantic
import sqlalchemy as sa
Expand All @@ -23,7 +24,7 @@ class Base(orm.DeclarativeBase):
from . import models # isort:skip # pylint: disable=unused-import


def get_db() -> orm.Session:
def get_db() -> t.Iterator[orm.Session]:
with SessionLocal() as session:
yield session

Expand Down
10 changes: 8 additions & 2 deletions backend/capellacollab/core/database/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ def create_tools(db):

def create_t4c_instance_and_repositories(db):
tool = tools_crud.get_tool_by_name(db, "Capella")
assert tool

version = tools_crud.get_version_by_tool_id_version_name(
db, tool.id, "5.2.0"
)
Expand Down Expand Up @@ -189,13 +191,17 @@ def create_t4c_instance_and_repositories(db):

def create_models(db: orm.Session):
capella_tool = tools_crud.get_tool_by_name(db, "Capella")
assert capella_tool

default_project = projects_crud.get_project_by_slug(db, "default")
assert default_project

for version in ["5.0.0", "5.2.0", "6.0.0"]:
capella_model = toolmodels_crud.create_model(
db=db,
project=projects_crud.get_project_by_slug(db, "default"),
project=default_project,
post_model=toolmodels_models.PostCapellaModel(
name=f"Meldody Model Test {version}",
name=f"Melody Model Test {version}",
description="",
tool_id=capella_tool.id,
),
Expand Down
14 changes: 7 additions & 7 deletions backend/capellacollab/projects/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@


import logging
from collections import abc

import fastapi
import slugify
Expand Down Expand Up @@ -39,7 +38,7 @@
)


@router.get("", response_model=abc.Sequence[models.Project], tags=["Projects"])
@router.get("", response_model=list[models.Project], tags=["Projects"])
def get_projects(
user: users_models.DatabaseUser = fastapi.Depends(
users_injectables.get_own_user
Expand All @@ -49,18 +48,19 @@ def get_projects(
log: logging.LoggerAdapter = fastapi.Depends(
core_logging.get_request_logger
),
) -> abc.Sequence[models.DatabaseProject]:
) -> list[models.DatabaseProject]:
if auth_injectables.RoleVerification(
required_role=users_models.Role.ADMIN, verify=False
)(token, db):
log.debug("Fetching all projects")
return crud.get_projects(db)
return list(crud.get_projects(db))

projects = [
association.project
for association in user.projects
if not association.project.visibility == models.Visibility.INTERNAL
] + crud.get_internal_projects(db)
]
projects.extend(crud.get_internal_projects(db))

log.debug("Fetching the following projects: %s", projects)
return projects
Expand Down Expand Up @@ -89,7 +89,7 @@ def update_project(
if crud.get_project_by_slug(db, new_slug) and project.slug != new_slug:
raise fastapi.HTTPException(
status_code=status.HTTP_409_CONFLICT,
details={
detail={
"reason": "A project with a similar name already exists.",
"technical": "Slug already used",
},
Expand Down Expand Up @@ -137,7 +137,7 @@ def create_project(
project = crud.create_project(
db,
post_project.name,
post_project.description,
post_project.description or "",
post_project.visibility,
)

Expand Down
4 changes: 4 additions & 0 deletions backend/capellacollab/projects/toolmodels/backups/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from capellacollab.tools import crud as tools_crud
from capellacollab.users import models as users_models

from .. import exceptions as toolmodels_exceptions
from . import core, crud, exceptions, injectables, models
from .runs import routes as runs_routes

Expand Down Expand Up @@ -100,6 +101,9 @@ def create_backup(
)

if body.run_nightly:
if not capella_model.version_id:
raise toolmodels_exceptions.VersionIdNotSetError(capella_model.id)

reference = operators.get_operator().create_cronjob(
image=tools_crud.get_backup_image_for_tool_version(
db, capella_model.version_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from capellacollab.config import config
from capellacollab.core import database
from capellacollab.core.logging import loki
from capellacollab.projects.toolmodels import (
exceptions as toolmodels_exceptions,
)
from capellacollab.projects.toolmodels.backups import core as backups_core
from capellacollab.sessions import operators
from capellacollab.tools import crud as tools_crud
Expand Down Expand Up @@ -54,20 +57,21 @@ def _schedule_pending_jobs():
pending_run.pipeline.model.slug,
)
try:
model = pending_run.pipeline.model

if not model.version_id:
raise toolmodels_exceptions.VersionIdNotSetError(model.id)

job_name = operators.get_operator().create_job(
image=tools_crud.get_backup_image_for_tool_version(
db, pending_run.pipeline.model.version_id
db, model.version_id
),
command="backup",
labels={
"app.capellacollab/projectSlug": pending_run.pipeline.model.project.slug,
"app.capellacollab/projectID": str(
pending_run.pipeline.model.project.id
),
"app.capellacollab/modelSlug": pending_run.pipeline.model.slug,
"app.capellacollab/modelID": str(
pending_run.pipeline.model.id
),
"app.capellacollab/projectSlug": model.project.slug,
"app.capellacollab/projectID": str(model.project.id),
"app.capellacollab/modelSlug": model.slug,
"app.capellacollab/modelID": str(model.id),
"app.capellacollab/pipelineID": str(
pending_run.pipeline.id
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

def check_last_pipeline_run_status(
db: orm.Session, model: toolmodel_models.DatabaseCapellaModel
) -> runs_models.PipelineRunStatus:
) -> runs_models.PipelineRunStatus | None:
if pipeline := crud.get_first_pipeline_for_tool_model(db, model):
# Only consider first pipeline for monitoring, usually there is only one pipeline.
if pipeline_run := runs_crud.get_last_pipeline_run_of_pipeline(
Expand Down
Loading

0 comments on commit 89c50bb

Please sign in to comment.