Skip to content

Commit

Permalink
refactor: Use dataclasses for all SQLAlchemy database classes
Browse files Browse the repository at this point in the history
With dataclasses, a `__init__` function is created for all database
classes. This adds autocomplete and code recommendations and type
checkers like mypy can check the passed types.
  • Loading branch information
MoritzWeber0 committed Feb 7, 2024
1 parent c808389 commit c74d737
Show file tree
Hide file tree
Showing 24 changed files with 149 additions and 112 deletions.
2 changes: 1 addition & 1 deletion backend/capellacollab/core/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
SessionLocal = orm.sessionmaker(autocommit=False, autoflush=False, bind=engine)


class Base(orm.DeclarativeBase):
class Base(orm.MappedAsDataclass, orm.DeclarativeBase):
type_annotation_map = {
dict[str, str]: postgresql.JSONB,
dict[str, t.Any]: postgresql.JSONB,
Expand Down
32 changes: 16 additions & 16 deletions backend/capellacollab/core/database/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def migrate_db(engine, database_url: str):
create_coffee_machine_model(session)


def initialize_admin_user(db):
def initialize_admin_user(db: orm.Session):
LOGGER.info("Initialized adminuser %s", config["initial"]["admin"])
admin_user = users_crud.create_user(
db=db,
Expand All @@ -94,7 +94,7 @@ def initialize_admin_user(db):
events_crud.create_user_creation_event(db, admin_user)


def initialize_default_project(db):
def initialize_default_project(db: orm.Session):
LOGGER.info("Initialized project 'default'")
projects_crud.create_project(
db=db,
Expand All @@ -104,7 +104,7 @@ def initialize_default_project(db):
)


def initialize_coffee_machine_project(db):
def initialize_coffee_machine_project(db: orm.Session):
LOGGER.info("Initialize project 'Coffee Machine'")
projects_crud.create_project(
db=db,
Expand All @@ -114,7 +114,7 @@ def initialize_coffee_machine_project(db):
)


def create_tools(db):
def create_tools(db: orm.Session):
LOGGER.info("Initialized tools")
registry = config["docker"]["registry"]
if os.getenv("DEVELOPMENT_MODE", "").lower() in ("1", "true", "t"):
Expand All @@ -131,12 +131,12 @@ def create_tools(db):
)
tools_crud.create_tool(db, papyrus)

tools_crud.create_version(db, papyrus.id, "6.1")
tools_crud.create_version(db, papyrus.id, "6.0")
tools_crud.create_version(db, papyrus, "6.1")
tools_crud.create_version(db, papyrus, "6.0")

tools_crud.create_nature(db, papyrus.id, "UML 2.5")
tools_crud.create_nature(db, papyrus.id, "SysML 1.4")
tools_crud.create_nature(db, papyrus.id, "SysML 1.1")
tools_crud.create_nature(db, papyrus, "UML 2.5")
tools_crud.create_nature(db, papyrus, "SysML 1.4")
tools_crud.create_nature(db, papyrus, "SysML 1.1")

else:
# Use public Github images per default
Expand All @@ -160,15 +160,15 @@ def create_tools(db):
integrations_models.PatchToolIntegrations(jupyter=True),
)

default_version = tools_crud.create_version(db, capella.id, "6.0.0", True)
tools_crud.create_version(db, capella.id, "5.2.0")
tools_crud.create_version(db, capella.id, "5.0.0")
default_version = tools_crud.create_version(db, capella, "6.0.0", True)
tools_crud.create_version(db, capella, "5.2.0")
tools_crud.create_version(db, capella, "5.0.0")

tools_crud.create_version(db, jupyter.id, "python-3.11")
tools_crud.create_nature(db, jupyter.id, "notebooks")
tools_crud.create_version(db, jupyter, "python-3.11")
tools_crud.create_nature(db, jupyter, "notebooks")

default_nature = tools_crud.create_nature(db, capella.id, "model")
tools_crud.create_nature(db, capella.id, "library")
default_nature = tools_crud.create_nature(db, capella, "model")
tools_crud.create_nature(db, capella, "library")

for model in toolmodels_crud.get_models(db):
toolmodels_crud.set_tool_for_model(db, model, capella)
Expand Down
8 changes: 5 additions & 3 deletions backend/capellacollab/events/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,16 @@ def create_event(
raise ValueError(
f"Event type must of one of the following: {allowed_types}"
)

event = models.DatabaseUserHistoryEvent(
user_id=user.id,
user=user,
event_type=event_type,
execution_time=datetime.datetime.now(datetime.UTC),
executor_id=executor.id if executor else None,
project_id=project.id if project else None,
executor=executor,
project=project,
reason=reason,
)

db.add(event)
db.commit()

Expand Down
40 changes: 23 additions & 17 deletions backend/capellacollab/events/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@
from capellacollab.projects import models as projects_models
from capellacollab.users import models as users_models

if t.TYPE_CHECKING:
from capellacollab.projects.models import DatabaseProject
from capellacollab.users.models import DatabaseUser


class EventType(enum.Enum):
CREATED_USER = "CreatedUser"
Expand Down Expand Up @@ -55,27 +51,37 @@ class HistoryEvent(BaseHistoryEvent):
class DatabaseUserHistoryEvent(database.Base):
__tablename__ = "user_history_events"

id: orm.Mapped[int] = orm.mapped_column(primary_key=True, index=True)
id: orm.Mapped[int] = orm.mapped_column(
init=False, primary_key=True, index=True
)

user_id: orm.Mapped[int] = orm.mapped_column(sa.ForeignKey("users.id"))
user: orm.Mapped["DatabaseUser"] = orm.relationship(
user_id: orm.Mapped[int] = orm.mapped_column(
sa.ForeignKey("users.id"),
init=False,
)
user: orm.Mapped[users_models.DatabaseUser] = orm.relationship(
back_populates="events", foreign_keys=[user_id]
)

event_type: orm.Mapped[EventType]
reason: orm.Mapped[str | None] = orm.mapped_column(default=None)

executor_id: orm.Mapped[int | None] = orm.mapped_column(
sa.ForeignKey("users.id")
sa.ForeignKey("users.id"),
init=False,
)
executor: orm.Mapped["DatabaseUser"] = orm.relationship(
foreign_keys=[executor_id]
executor: orm.Mapped[users_models.DatabaseUser | None] = orm.relationship(
default=None, foreign_keys=[executor_id]
)

project_id: orm.Mapped[int | None] = orm.mapped_column(
sa.ForeignKey("projects.id")
)
project: orm.Mapped["DatabaseProject"] = orm.relationship(
foreign_keys=[project_id]
sa.ForeignKey("projects.id"),
init=False,
)
project: orm.Mapped[
projects_models.DatabaseProject | None
] = orm.relationship(default=None, foreign_keys=[project_id])

execution_time: orm.Mapped[datetime.datetime]
event_type: orm.Mapped[EventType]
reason: orm.Mapped[str | None]
execution_time: orm.Mapped[datetime.datetime] = orm.mapped_column(
default=datetime.datetime.now(datetime.UTC)
)
4 changes: 3 additions & 1 deletion backend/capellacollab/notices/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ class NoticeResponse(CreateNoticeRequest):
class DatabaseNotice(database.Base):
__tablename__ = "notices"

id: orm.Mapped[int] = orm.mapped_column(primary_key=True, index=True)
id: orm.Mapped[int] = orm.mapped_column(
init=False, primary_key=True, index=True
)
title: orm.Mapped[str]
message: orm.Mapped[str]
level: orm.Mapped[NoticeLevel]
17 changes: 11 additions & 6 deletions backend/capellacollab/projects/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,20 +105,25 @@ class DatabaseProject(database.Base):
__tablename__ = "projects"

id: orm.Mapped[int] = orm.mapped_column(
unique=True, primary_key=True, index=True
init=False, unique=True, primary_key=True, index=True
)

name: orm.Mapped[str] = orm.mapped_column(unique=True, index=True)
slug: orm.Mapped[str] = orm.mapped_column(unique=True, index=True)
description: orm.Mapped[str | None]
visibility: orm.Mapped[Visibility]
type: orm.Mapped[ProjectType]

description: orm.Mapped[str | None] = orm.mapped_column(default=None)
visibility: orm.Mapped[Visibility] = orm.mapped_column(
default=Visibility.PRIVATE
)
type: orm.Mapped[ProjectType] = orm.mapped_column(
default=ProjectType.GENERAL
)

users: orm.Mapped[list[ProjectUserAssociation]] = orm.relationship(
back_populates="project"
default_factory=list, back_populates="project"
)
models: orm.Mapped[list[DatabaseCapellaModel]] = orm.relationship(
back_populates="project"
default_factory=list, back_populates="project"
)

is_archived: orm.Mapped[bool] = orm.mapped_column(default=False)
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Backup(pydantic.BaseModel):
class DatabaseBackup(database.Base):
__tablename__ = "backups"
id: orm.Mapped[int] = orm.mapped_column(
primary_key=True, index=True, autoincrement=True
init=False, primary_key=True, index=True, autoincrement=True
)

created_by: orm.Mapped[str]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class PipelineRunStatus(enum.Enum):
class DatabasePipelineRun(Base):
__tablename__ = "pipeline_run"
id: orm.Mapped[int] = orm.mapped_column(
primary_key=True, index=True, autoincrement=True
init=False, primary_key=True, index=True, autoincrement=True
)
reference_id: orm.Mapped[str | None]

Expand Down
2 changes: 1 addition & 1 deletion backend/capellacollab/projects/toolmodels/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class DatabaseCapellaModel(database.Base):
__table_args__ = (sa.UniqueConstraint("project_id", "slug"),)

id: orm.Mapped[int] = orm.mapped_column(
primary_key=True, index=True, unique=True
init=False, primary_key=True, index=True, unique=True
)

name: orm.Mapped[str] = orm.mapped_column(index=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class DatabaseT4CModel(database.Base):
)

id: orm.Mapped[int] = orm.mapped_column(
unique=True, primary_key=True, index=True
init=False, unique=True, primary_key=True, index=True
)
name: orm.Mapped[str] = orm.mapped_column(index=True)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class DatabaseToolModelRestrictions(database.Base):
__tablename__ = "model_restrictions"

id: orm.Mapped[int] = orm.mapped_column(
primary_key=True, index=True, unique=True
init=False, primary_key=True, index=True, unique=True
)

model_id: orm.Mapped[int] = orm.mapped_column(sa.ForeignKey("models.id"))
Expand Down
2 changes: 1 addition & 1 deletion backend/capellacollab/settings/configuration/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class DatabaseConfiguration(database.Base):
__tablename__ = "configuration"

id: orm.Mapped[int] = orm.mapped_column(
unique=True, primary_key=True, index=True
init=False, unique=True, primary_key=True, index=True
)

name: orm.Mapped[str] = orm.mapped_column(unique=True, index=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def validate_license_url(value: str | None):
class DatabasePureVariantsLicenses(database.Base):
__tablename__ = "pure_variants"

id: orm.Mapped[int] = orm.mapped_column(primary_key=True, index=True)
id: orm.Mapped[int] = orm.mapped_column(
init=False, primary_key=True, index=True
)
license_server_url: orm.Mapped[str | None]
license_key_filename: orm.Mapped[str | None]

Expand Down
2 changes: 1 addition & 1 deletion backend/capellacollab/settings/modelsources/git/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class DatabaseGitInstance(database.Base):
__tablename__ = "git_instances"

id: orm.Mapped[int] = orm.mapped_column(
primary_key=True, index=True, autoincrement=True
init=False, primary_key=True, index=True, autoincrement=True
)

name: orm.Mapped[str]
Expand Down
21 changes: 11 additions & 10 deletions backend/capellacollab/settings/modelsources/t4c/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,14 @@ class DatabaseT4CInstance(database.Base):
__tablename__ = "t4c_instances"

id: orm.Mapped[int] = orm.mapped_column(
primary_key=True, index=True, autoincrement=True
init=False, primary_key=True, index=True, autoincrement=True
)

name: orm.Mapped[str] = orm.mapped_column(unique=True)

license: orm.Mapped[str]
host: orm.Mapped[str]
port: orm.Mapped[int] = orm.mapped_column(
sa.CheckConstraint("port >= 0 AND port <= 65535"), default=2036
)
cdo_port: orm.Mapped[int] = orm.mapped_column(
sa.CheckConstraint("cdo_port >= 0 AND cdo_port <= 65535"),
default=12036,
)

http_port: orm.Mapped[int | None] = orm.mapped_column(
sa.CheckConstraint("http_port >= 0 AND http_port <= 65535"),
)
Expand All @@ -68,8 +62,6 @@ class DatabaseT4CInstance(database.Base):
username: orm.Mapped[str]
password: orm.Mapped[str]

protocol: orm.Mapped[Protocol] = orm.mapped_column(default=Protocol.tcp)

version_id: orm.Mapped[int] = orm.mapped_column(
sa.ForeignKey("versions.id")
)
Expand All @@ -79,6 +71,15 @@ class DatabaseT4CInstance(database.Base):
back_populates="instance", cascade="all, delete"
)

port: orm.Mapped[int] = orm.mapped_column(
sa.CheckConstraint("port >= 0 AND port <= 65535"), default=2036
)
cdo_port: orm.Mapped[int] = orm.mapped_column(
sa.CheckConstraint("cdo_port >= 0 AND cdo_port <= 65535"),
default=12036,
)
protocol: orm.Mapped[Protocol] = orm.mapped_column(default=Protocol.tcp)

is_archived: orm.Mapped[bool] = orm.mapped_column(default=False)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ class DatabaseT4CRepository(database.Base):
__table_args__ = (sa.UniqueConstraint("instance_id", "name"),)

id: orm.Mapped[int] = orm.mapped_column(
primary_key=True, index=True, autoincrement=True, unique=True
init=False,
primary_key=True,
index=True,
autoincrement=True,
unique=True,
)
name: orm.Mapped[str]

Expand Down
10 changes: 5 additions & 5 deletions backend/capellacollab/tools/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def create_tool(
db: orm.Session, tool: models.DatabaseTool
) -> models.DatabaseTool:
tool.integrations = integrations_models.DatabaseToolIntegrations(
pure_variants=False, t4c=False, jupyter=False
tool=tool, pure_variants=False, t4c=False, jupyter=False
)
db.add(tool)
db.commit()
Expand Down Expand Up @@ -151,7 +151,7 @@ def update_version(

def create_version(
db: orm.Session,
tool_id: int,
tool: models.DatabaseTool,
name: str,
is_recommended: bool = False,
is_deprecated: bool = False,
Expand All @@ -160,7 +160,7 @@ def create_version(
name=name,
is_recommended=is_recommended,
is_deprecated=is_deprecated,
tool_id=tool_id,
tool=tool,
)
db.add(version)
db.commit()
Expand Down Expand Up @@ -223,9 +223,9 @@ def get_natures_by_tool_id(


def create_nature(
db: orm.Session, tool_id: int, name: str
db: orm.Session, tool: models.DatabaseTool, name: str
) -> models.DatabaseNature:
nature = models.DatabaseNature(name=name, tool_id=tool_id)
nature = models.DatabaseNature(name=name, tool=tool)
db.add(nature)
db.commit()
return nature
Expand Down
6 changes: 4 additions & 2 deletions backend/capellacollab/tools/integrations/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ class PatchToolIntegrations(pydantic.BaseModel):
class DatabaseToolIntegrations(database.Base):
__tablename__ = "tool_integrations"

id: orm.Mapped[int] = orm.mapped_column(primary_key=True)
id: orm.Mapped[int] = orm.mapped_column(init=False, primary_key=True)

tool_id: orm.Mapped[int] = orm.mapped_column(sa.ForeignKey("tools.id"))
tool_id: orm.Mapped[int] = orm.mapped_column(
sa.ForeignKey("tools.id"), init=False
)
tool: orm.Mapped[DatabaseTool] = orm.relationship(
back_populates="integrations"
)
Expand Down
Loading

0 comments on commit c74d737

Please sign in to comment.