Skip to content

Commit

Permalink
refactor: Use dataclasses for all SQLAlchemy database classes
Browse files Browse the repository at this point in the history
With dataclasses, a `__init__` function is created for all database
classes. This adds autocomplete and code recommendations and type
checkers like mypy can check the passed types.

Since the `__hash__` method was removed during this change, the intersection
method to get common projects doesn't work anymore. I changed it to a proper
database join, which is faster.

In addition, many type annotations were added. Some mypy and pylint errors
and warnings were fixed in the tests.
  • Loading branch information
MoritzWeber0 committed Feb 9, 2024
1 parent c808389 commit 8d48fbf
Show file tree
Hide file tree
Showing 65 changed files with 636 additions and 423 deletions.
2 changes: 1 addition & 1 deletion backend/capellacollab/core/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
SessionLocal = orm.sessionmaker(autocommit=False, autoflush=False, bind=engine)


class Base(orm.DeclarativeBase):
class Base(orm.MappedAsDataclass, orm.DeclarativeBase):
type_annotation_map = {
dict[str, str]: postgresql.JSONB,
dict[str, t.Any]: postgresql.JSONB,
Expand Down
34 changes: 18 additions & 16 deletions backend/capellacollab/core/database/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def migrate_db(engine, database_url: str):
create_coffee_machine_model(session)


def initialize_admin_user(db):
def initialize_admin_user(db: orm.Session):
LOGGER.info("Initialized adminuser %s", config["initial"]["admin"])
admin_user = users_crud.create_user(
db=db,
Expand All @@ -94,7 +94,7 @@ def initialize_admin_user(db):
events_crud.create_user_creation_event(db, admin_user)


def initialize_default_project(db):
def initialize_default_project(db: orm.Session):
LOGGER.info("Initialized project 'default'")
projects_crud.create_project(
db=db,
Expand All @@ -104,7 +104,7 @@ def initialize_default_project(db):
)


def initialize_coffee_machine_project(db):
def initialize_coffee_machine_project(db: orm.Session):
LOGGER.info("Initialize project 'Coffee Machine'")
projects_crud.create_project(
db=db,
Expand All @@ -114,7 +114,7 @@ def initialize_coffee_machine_project(db):
)


def create_tools(db):
def create_tools(db: orm.Session):
LOGGER.info("Initialized tools")
registry = config["docker"]["registry"]
if os.getenv("DEVELOPMENT_MODE", "").lower() in ("1", "true", "t"):
Expand All @@ -131,12 +131,12 @@ def create_tools(db):
)
tools_crud.create_tool(db, papyrus)

tools_crud.create_version(db, papyrus.id, "6.1")
tools_crud.create_version(db, papyrus.id, "6.0")
tools_crud.create_version(db, papyrus, "6.1")
tools_crud.create_version(db, papyrus, "6.0")

tools_crud.create_nature(db, papyrus.id, "UML 2.5")
tools_crud.create_nature(db, papyrus.id, "SysML 1.4")
tools_crud.create_nature(db, papyrus.id, "SysML 1.1")
tools_crud.create_nature(db, papyrus, "UML 2.5")
tools_crud.create_nature(db, papyrus, "SysML 1.4")
tools_crud.create_nature(db, papyrus, "SysML 1.1")

else:
# Use public Github images per default
Expand All @@ -154,21 +154,22 @@ def create_tools(db):
docker_image_template=f"{registry}/jupyter-notebook:$version",
)
tools_crud.create_tool(db, jupyter)
assert jupyter.integrations
integrations_crud.update_integrations(
db,
jupyter.integrations,
integrations_models.PatchToolIntegrations(jupyter=True),
)

default_version = tools_crud.create_version(db, capella.id, "6.0.0", True)
tools_crud.create_version(db, capella.id, "5.2.0")
tools_crud.create_version(db, capella.id, "5.0.0")
default_version = tools_crud.create_version(db, capella, "6.0.0", True)
tools_crud.create_version(db, capella, "5.2.0")
tools_crud.create_version(db, capella, "5.0.0")

tools_crud.create_version(db, jupyter.id, "python-3.11")
tools_crud.create_nature(db, jupyter.id, "notebooks")
tools_crud.create_version(db, jupyter, "python-3.11")
tools_crud.create_nature(db, jupyter, "notebooks")

default_nature = tools_crud.create_nature(db, capella.id, "model")
tools_crud.create_nature(db, capella.id, "library")
default_nature = tools_crud.create_nature(db, capella, "model")
tools_crud.create_nature(db, capella, "library")

for model in toolmodels_crud.get_models(db):
toolmodels_crud.set_tool_for_model(db, model, capella)
Expand All @@ -184,6 +185,7 @@ def create_t4c_instance_and_repositories(db):
version = tools_crud.get_version_by_tool_id_version_name(
db, tool.id, "5.2.0"
)
assert version
default_instance = settings_t4c_models.DatabaseT4CInstance(
name="default",
license="placeholder",
Expand Down
8 changes: 5 additions & 3 deletions backend/capellacollab/events/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,16 @@ def create_event(
raise ValueError(
f"Event type must of one of the following: {allowed_types}"
)

event = models.DatabaseUserHistoryEvent(
user_id=user.id,
user=user,
event_type=event_type,
execution_time=datetime.datetime.now(datetime.UTC),
executor_id=executor.id if executor else None,
project_id=project.id if project else None,
executor=executor,
project=project,
reason=reason,
)

db.add(event)
db.commit()

Expand Down
41 changes: 23 additions & 18 deletions backend/capellacollab/events/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import datetime
import enum
import typing as t

import pydantic
import sqlalchemy as sa
Expand All @@ -14,10 +13,6 @@
from capellacollab.projects import models as projects_models
from capellacollab.users import models as users_models

if t.TYPE_CHECKING:
from capellacollab.projects.models import DatabaseProject
from capellacollab.users.models import DatabaseUser


class EventType(enum.Enum):
CREATED_USER = "CreatedUser"
Expand Down Expand Up @@ -55,27 +50,37 @@ class HistoryEvent(BaseHistoryEvent):
class DatabaseUserHistoryEvent(database.Base):
__tablename__ = "user_history_events"

id: orm.Mapped[int] = orm.mapped_column(primary_key=True, index=True)
id: orm.Mapped[int] = orm.mapped_column(
init=False, primary_key=True, index=True
)

user_id: orm.Mapped[int] = orm.mapped_column(sa.ForeignKey("users.id"))
user: orm.Mapped["DatabaseUser"] = orm.relationship(
user_id: orm.Mapped[int] = orm.mapped_column(
sa.ForeignKey("users.id"),
init=False,
)
user: orm.Mapped[users_models.DatabaseUser] = orm.relationship(
back_populates="events", foreign_keys=[user_id]
)

event_type: orm.Mapped[EventType]
reason: orm.Mapped[str | None] = orm.mapped_column(default=None)

executor_id: orm.Mapped[int | None] = orm.mapped_column(
sa.ForeignKey("users.id")
sa.ForeignKey("users.id"),
init=False,
)
executor: orm.Mapped["DatabaseUser"] = orm.relationship(
foreign_keys=[executor_id]
executor: orm.Mapped[users_models.DatabaseUser | None] = orm.relationship(
default=None, foreign_keys=[executor_id]
)

project_id: orm.Mapped[int | None] = orm.mapped_column(
sa.ForeignKey("projects.id")
)
project: orm.Mapped["DatabaseProject"] = orm.relationship(
foreign_keys=[project_id]
sa.ForeignKey("projects.id"),
init=False,
)
project: orm.Mapped[
projects_models.DatabaseProject | None
] = orm.relationship(default=None, foreign_keys=[project_id])

execution_time: orm.Mapped[datetime.datetime]
event_type: orm.Mapped[EventType]
reason: orm.Mapped[str | None]
execution_time: orm.Mapped[datetime.datetime] = orm.mapped_column(
default=datetime.datetime.now(datetime.UTC)
)
2 changes: 1 addition & 1 deletion backend/capellacollab/health/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def project_status(db: orm.Session = fastapi.Depends(database.get_db)):
def _create_tool_model_status_tasks(
db: orm.Session,
logger: logging.LoggerAdapter,
model: toolmodels_models.DatabaseCapellaModel,
model: toolmodels_models.DatabaseToolModel,
) -> models.ToolModelStatusTasks:
return models.ToolModelStatusTasks(
primary_git_repository_status=asyncio.create_task(
Expand Down
4 changes: 3 additions & 1 deletion backend/capellacollab/notices/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ class NoticeResponse(CreateNoticeRequest):
class DatabaseNotice(database.Base):
__tablename__ = "notices"

id: orm.Mapped[int] = orm.mapped_column(primary_key=True, index=True)
id: orm.Mapped[int] = orm.mapped_column(
init=False, primary_key=True, index=True
)
title: orm.Mapped[str]
message: orm.Mapped[str]
level: orm.Mapped[NoticeLevel]
30 changes: 30 additions & 0 deletions backend/capellacollab/projects/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from sqlalchemy import orm

from capellacollab.core import database
from capellacollab.projects.users import models as project_users_models
from capellacollab.users import models as users_models

from . import models

Expand Down Expand Up @@ -40,6 +42,34 @@ def get_project_by_slug(
).scalar_one_or_none()


def get_common_projects_for_users(
db: orm.Session,
user1: users_models.DatabaseUser,
user2: users_models.DatabaseUser,
) -> abc.Sequence[models.DatabaseProject]:
user1_table = orm.aliased(project_users_models.ProjectUserAssociation)
user2_table = orm.aliased(project_users_models.ProjectUserAssociation)

return (
db.execute(
sa.select(models.DatabaseProject)
.join(
user1_table,
models.DatabaseProject.id == user1_table.project_id,
)
.join(
user2_table,
models.DatabaseProject.id == user2_table.project_id,
)
.where(user1_table.user_id == user1.id)
.where(user2_table.user_id == user2.id)
.distinct()
)
.scalars()
.all()
)


def update_project(
db: orm.Session,
project: models.DatabaseProject,
Expand Down
21 changes: 13 additions & 8 deletions backend/capellacollab/projects/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from capellacollab.projects.users import models as project_users_models

if t.TYPE_CHECKING:
from capellacollab.projects.toolmodels.models import DatabaseCapellaModel
from capellacollab.projects.toolmodels.models import DatabaseToolModel
from capellacollab.projects.users.models import ProjectUserAssociation


Expand Down Expand Up @@ -105,20 +105,25 @@ class DatabaseProject(database.Base):
__tablename__ = "projects"

id: orm.Mapped[int] = orm.mapped_column(
unique=True, primary_key=True, index=True
init=False, unique=True, primary_key=True, index=True
)

name: orm.Mapped[str] = orm.mapped_column(unique=True, index=True)
slug: orm.Mapped[str] = orm.mapped_column(unique=True, index=True)
description: orm.Mapped[str | None]
visibility: orm.Mapped[Visibility]
type: orm.Mapped[ProjectType]

description: orm.Mapped[str | None] = orm.mapped_column(default=None)
visibility: orm.Mapped[Visibility] = orm.mapped_column(
default=Visibility.PRIVATE
)
type: orm.Mapped[ProjectType] = orm.mapped_column(
default=ProjectType.GENERAL
)

users: orm.Mapped[list[ProjectUserAssociation]] = orm.relationship(
back_populates="project"
default_factory=list, back_populates="project"
)
models: orm.Mapped[list[DatabaseCapellaModel]] = orm.relationship(
back_populates="project"
models: orm.Mapped[list[DatabaseToolModel]] = orm.relationship(
default_factory=list, back_populates="project"
)

is_archived: orm.Mapped[bool] = orm.mapped_column(default=False)
4 changes: 2 additions & 2 deletions backend/capellacollab/projects/toolmodels/backups/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_pipeline_by_id(


def get_pipelines_for_tool_model(
db: orm.Session, model: toolmodels_models.DatabaseCapellaModel
db: orm.Session, model: toolmodels_models.DatabaseToolModel
) -> abc.Sequence[models.DatabaseBackup]:
return (
db.execute(
Expand All @@ -42,7 +42,7 @@ def get_pipelines_for_tool_model(


def get_first_pipeline_for_tool_model(
db: orm.Session, model: toolmodels_models.DatabaseCapellaModel
db: orm.Session, model: toolmodels_models.DatabaseToolModel
) -> models.DatabaseBackup | None:
return (
db.execute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

def get_existing_pipeline(
pipeline_id: int,
model: toolmodels_models.DatabaseCapellaModel = fastapi.Depends(
model: toolmodels_models.DatabaseToolModel = fastapi.Depends(
toolmodels_injectables.get_existing_capella_model
),
db: orm.Session = fastapi.Depends(database.get_db),
Expand Down
15 changes: 9 additions & 6 deletions backend/capellacollab/projects/toolmodels/backups/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)

if t.TYPE_CHECKING:
from capellacollab.projects.toolmodels.models import DatabaseCapellaModel
from capellacollab.projects.toolmodels.models import DatabaseToolModel
from capellacollab.projects.toolmodels.modelsources.git.models import (
DatabaseGitModel,
)
Expand Down Expand Up @@ -51,7 +51,7 @@ class Backup(pydantic.BaseModel):
class DatabaseBackup(database.Base):
__tablename__ = "backups"
id: orm.Mapped[int] = orm.mapped_column(
primary_key=True, index=True, autoincrement=True
init=False, primary_key=True, index=True, autoincrement=True
)

created_by: orm.Mapped[str]
Expand All @@ -64,22 +64,25 @@ class DatabaseBackup(database.Base):
run_nightly: orm.Mapped[bool]

git_model_id: orm.Mapped[int] = orm.mapped_column(
sa.ForeignKey("git_models.id")
sa.ForeignKey("git_models.id"), init=False
)
git_model: orm.Mapped["DatabaseGitModel"] = orm.relationship()

t4c_model_id: orm.Mapped[int] = orm.mapped_column(
sa.ForeignKey("t4c_models.id")
sa.ForeignKey("t4c_models.id"), init=False
)
t4c_model: orm.Mapped["DatabaseT4CModel"] = orm.relationship()

model_id: orm.Mapped[int] = orm.mapped_column(sa.ForeignKey("models.id"))
model: orm.Mapped["DatabaseCapellaModel"] = orm.relationship()
model_id: orm.Mapped[int] = orm.mapped_column(
sa.ForeignKey("models.id"), init=False
)
model: orm.Mapped["DatabaseToolModel"] = orm.relationship()

runs: orm.Mapped[
list["runs_models.DatabasePipelineRun"]
] = orm.relationship(
"DatabasePipelineRun",
back_populates="pipeline",
cascade="all, delete-orphan",
default_factory=list,
)
4 changes: 2 additions & 2 deletions backend/capellacollab/projects/toolmodels/backups/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

@router.get("", response_model=list[models.Backup])
def get_pipelines(
model: toolmodels_models.DatabaseCapellaModel = fastapi.Depends(
model: toolmodels_models.DatabaseToolModel = fastapi.Depends(
toolmodels_injectables.get_existing_capella_model
),
db: orm.Session = fastapi.Depends(database.get_db),
Expand All @@ -67,7 +67,7 @@ def get_pipeline(
@router.post("", response_model=models.Backup)
def create_backup(
body: models.CreateBackup,
capella_model: toolmodels_models.DatabaseCapellaModel = fastapi.Depends(
capella_model: toolmodels_models.DatabaseToolModel = fastapi.Depends(
toolmodels_injectables.get_existing_capella_model
),
db: orm.Session = fastapi.Depends(database.get_db),
Expand Down
Loading

0 comments on commit 8d48fbf

Please sign in to comment.