diff --git a/Makefile b/Makefile
index 9a120b01..9b39e3b2 100644
--- a/Makefile
+++ b/Makefile
@@ -63,7 +63,7 @@ format:
 
 ## Run all tests in project
 test:
-	poetry run pytest -o log_cli=true --verbosity=2 --showlocals --log-cli-level=INFO --cov=app --cov-report term
+	poetry run pytest -o log_cli=true --verbosity=2 --showlocals --log-cli-level=INFO --test-alembic --cov=app --cov-report term
 
 .DEFAULT_GOAL := help
 # See for explanation.
diff --git a/app/api/task.py b/app/api/task.py
index 097064b6..26ebe814 100644
--- a/app/api/task.py
+++ b/app/api/task.py
@@ -1,7 +1,7 @@
 from uuid import UUID
 from fastapi import APIRouter, HTTPException
-from pydantic import UUID4
 from app.domain.file.dataset import DatasetORM
+from app.domain.task.task import TaskModel, TaskORM, TaskStatus
 from app.domain.worker.task.data_profiling_task import data_profiling_task
 from app.domain.task import OneOfTaskConfig
 from sqlalchemy_mixins.activerecord import ModelNotFoundError
@@ -11,18 +11,34 @@
 
 @router.post("")
 def set_task(
-    dataset_id: UUID4,
+    dataset_id: UUID,
     config: OneOfTaskConfig,
-) -> UUID4:
+) -> UUID:
     try:
         DatasetORM.find_or_fail(dataset_id)
     except ModelNotFoundError:
         raise HTTPException(404, "Dataset not found")
 
-    async_result = data_profiling_task.delay(dataset_id, config)
-    return UUID(async_result.id, version=4)
+    task_orm = TaskORM.create(
+        status=TaskStatus.CREATED,
+        config=config.model_dump(),
+        dataset_id=dataset_id,
+    )
+    task_id = task_orm.id  # type: ignore
+
+    data_profiling_task.delay(
+        task_id=task_id,
+        dataset_id=dataset_id,
+        config=config,
+    )
+
+    return task_id
 
 
 @router.get("/{task_id}")
-def retrieve_task(task_id: UUID4) -> None:
-    raise HTTPException(418, "Not implemented yet")
+def retrieve_task(task_id: UUID) -> TaskModel:
+    try:
+        task_orm = TaskORM.find_or_fail(task_id)
+        return TaskModel.model_validate(task_orm)
+    except ModelNotFoundError:
+        raise HTTPException(404, "Task not found")
diff --git a/app/db/migrations/env.py b/app/db/migrations/env.py
index d67e97f6..6ad7bb57 100644
--- a/app/db/migrations/env.py
+++ b/app/db/migrations/env.py
@@ -4,7 +4,8 @@
 from app.settings import settings
 from app.db import ORMBase
 from app.domain.file.file import FileORM  # noqa: F401
-from app.domain.file.dataset import DatasetModel  # noqa: F401
+from app.domain.file.dataset import DatasetORM  # noqa: F401
+from app.domain.task.task import TaskORM  # noqa: F401
 
 # this is the Alembic Config object, which provides
 # access to the values within the .ini file in use.
diff --git a/app/db/migrations/versions/7dc9a3441d07_add_task.py b/app/db/migrations/versions/7dc9a3441d07_add_task.py
new file mode 100644
index 00000000..8ff511e7
--- /dev/null
+++ b/app/db/migrations/versions/7dc9a3441d07_add_task.py
@@ -0,0 +1,67 @@
+"""add task
+
+Revision ID: 7dc9a3441d07
+Revises: 6a59d47fe978
+Create Date: 2024-04-02 04:04:09.759025
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision: str = "7dc9a3441d07"
+down_revision: Union[str, None] = "6a59d47fe978"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table(
+        "task",
+        sa.Column("id", sa.Uuid(), nullable=False),
+        sa.Column(
+            "status",
+            sa.Enum("FAILED", "CREATED", "RUNNING", "COMPLETED", name="taskstatus"),
+            nullable=False,
+        ),
+        sa.Column("config", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
+        sa.Column("result", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
+        sa.Column("dataset_id", sa.Uuid(), nullable=False),
+        sa.Column("raised_exception_name", sa.String(), nullable=True),
+        sa.Column(
+            "failure_reason",
+            sa.Enum(
+                "MEMORY_LIMIT_EXCEEDED",
+                "TIME_LIMIT_EXCEEDED",
+                "WORKER_KILLED_BY_SIGNAL",
+                "OTHER",
+                name="taskfailurereason",
+            ),
+            nullable=True,
+        ),
+        sa.Column("traceback", sa.String(), nullable=True),
+        sa.Column("created_at", sa.TIMESTAMP(), nullable=False),
+        sa.Column("updated_at", sa.TIMESTAMP(), nullable=False),
+        sa.ForeignKeyConstraint(
+            ["dataset_id"],
+            ["dataset.id"],
+        ),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    # ### end Alembic commands ###
+
+
+# ADJUSTED! See: https://github.com/sqlalchemy/alembic/issues/278#issuecomment-907283386
+def downgrade() -> None:
+    op.drop_table("task")
+
+    taskfailurereason = sa.Enum(name="taskfailurereason")
+    taskfailurereason.drop(op.get_bind(), checkfirst=True)
+
+    taskstatus = sa.Enum(name="taskstatus")
+    taskstatus.drop(op.get_bind(), checkfirst=True)
diff --git a/app/domain/file/dataset.py b/app/domain/file/dataset.py
index 91d119d1..42e7bb27 100644
--- a/app/domain/file/dataset.py
+++ b/app/domain/file/dataset.py
@@ -5,6 +5,10 @@
 from app.db import ORMBase
 from app.db.session import ORMBaseModel
 from app.domain.file.file import FileModel, FileORM
+import typing
+
+if typing.TYPE_CHECKING:
+    from app.domain.task.task import TaskORM
 
 
 class DatasetORM(ORMBase):
@@ -17,8 +21,11 @@ class DatasetORM(ORMBase):
     file_id: Mapped[UUID] = mapped_column(ForeignKey("file.id"), nullable=False)
     file: Mapped[FileORM] = relationship("FileORM")
 
-    # user = relationship("UserORM")
-    # task = relationship("TaskORM")
+    related_tasks: Mapped[list["TaskORM"]] = relationship(
+        "TaskORM", back_populates="dataset"
+    )
+
+    # owner = relationship("UserORM")
 
 
 class DatasetModel(ORMBaseModel):
diff --git a/app/domain/task/task.py b/app/domain/task/task.py
new file mode 100644
index 00000000..09010069
--- /dev/null
+++ b/app/domain/task/task.py
@@ -0,0 +1,60 @@
+from enum import StrEnum, auto
+import typing
+from uuid import UUID, uuid4
+from sqlalchemy.orm import Mapped, mapped_column
+from app.db import ORMBase
+from app.db.session import ORMBaseModel
+from sqlalchemy import ForeignKey
+from sqlalchemy.orm import relationship
+from app.domain.file.dataset import DatasetModel
+from app.domain.task import OneOfTaskConfig, OneOfTaskResult
+
+from sqlalchemy.dialects.postgresql import JSONB
+
+if typing.TYPE_CHECKING:
+    from app.domain.file.dataset import DatasetORM
+
+
+class TaskStatus(StrEnum):
+    FAILED = auto()
+    CREATED = auto()
+    RUNNING = auto()
+    COMPLETED = auto()
+
+
+class TaskFailureReason(StrEnum):
+    MEMORY_LIMIT_EXCEEDED = auto()
+    TIME_LIMIT_EXCEEDED = auto()
+    WORKER_KILLED_BY_SIGNAL = auto()
+    OTHER = auto()
+
+
+class TaskORM(ORMBase):
+    __tablename__ = "task"
+    id: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4)
+
+    status: Mapped[TaskStatus]
+    config: Mapped[OneOfTaskConfig] = mapped_column(JSONB)
+    result: Mapped[OneOfTaskResult | None] = mapped_column(JSONB, default=None)
+
+    dataset_id: Mapped[UUID] = mapped_column(ForeignKey("dataset.id"), nullable=False)
+    dataset: Mapped["DatasetORM"] = relationship(
+        "DatasetORM", back_populates="related_tasks"
+    )
+
+    # Only set if the task failed
+    raised_exception_name: Mapped[str | None] = mapped_column(default=None)
+    failure_reason: Mapped[TaskFailureReason | None] = mapped_column(default=None)
+    traceback: Mapped[str | None] = mapped_column(default=None)
+
+
+class TaskModel(ORMBaseModel):
+    id: UUID
+    status: TaskStatus
+    config: OneOfTaskConfig
+    result: OneOfTaskResult | None
+    dataset: DatasetModel
+
+    raised_exception_name: str | None
+    failure_reason: TaskFailureReason | None
+    traceback: str | None
diff --git a/app/domain/worker/task/data_profiling_task.py b/app/domain/worker/task/data_profiling_task.py
index fc2bb5c5..82f928d8 100644
--- a/app/domain/worker/task/data_profiling_task.py
+++ b/app/domain/worker/task/data_profiling_task.py
@@ -1,20 +1,22 @@
-import logging
 from typing import Any
+from uuid import UUID
 from app.db.session import no_pooling
 from app.domain.file.dataset import DatasetORM
-from app.domain.task import OneOfTaskConfig
+from app.domain.task import OneOfTaskConfig, OneOfTaskResult
 from app.domain.task import match_task_by_primitive_name
+from app.domain.task.task import TaskFailureReason, TaskORM, TaskStatus
 from app.worker import worker
 from app.domain.worker.task.resource_intensive_task import ResourceIntensiveTask
-from pydantic import UUID4
 import pandas as pd
 from celery.signals import task_failure, task_prerun, task_postrun
+from celery.exceptions import SoftTimeLimitExceeded, TimeLimitExceeded, WorkerLostError
 
 
 @worker.task(base=ResourceIntensiveTask, ignore_result=True, max_retries=0)
 def data_profiling_task(
-    dataset_id: UUID4,
+    task_id: UUID,
+    dataset_id: UUID,
     config: OneOfTaskConfig,
 ) -> Any:
     with no_pooling():
@@ -37,53 +39,52 @@ def data_profiling_task(
 
 @task_prerun.connect(sender=data_profiling_task)
 def task_prerun_notifier(
-    sender,
-    task_id,
-    task,
-    args,
     kwargs,
     **_,
 ):
-    # TODO: Create Task in database and set status to "running" or similar
+    db_task_id: UUID = kwargs["task_id"]
     with no_pooling():
-        ...
-    logging.critical(
-        f"From task_prerun_notifier ==> Running just before add() executes, {sender}"
-    )
+        task_orm = TaskORM.find_or_fail(db_task_id)
+        task_orm.update(status=TaskStatus.RUNNING)  # type: ignore
 
 
 @task_postrun.connect(sender=data_profiling_task)
 def task_postrun_notifier(
-    sender,
-    task_id,
-    task,
-    args,
     kwargs,
-    retval,
+    retval: OneOfTaskResult,
     **_,
 ):
+    db_task_id: UUID = kwargs["task_id"]
     with no_pooling():
-        ...
-
-    # TODO: Update Task in database and set status to "completed" or similar
-    logging.critical(f"From task_postrun_notifier ==> Ok, done!, {sender}")
+        task_orm = TaskORM.find_or_fail(db_task_id)  # type: ignore
+        task_orm.update(
+            status=TaskStatus.COMPLETED,  # type: ignore
+            result=retval.model_dump(),
+        )
 
 
 @task_failure.connect(sender=data_profiling_task)
 def task_failure_notifier(
-    sender,
-    task_id,
-    exception,
-    args,
     kwargs,
+    exception,
     traceback,
-    einfo,
     **_,
 ):
-    with no_pooling():
-        ...
-    # TODO: Update Task in database and set status to "failed" or similar
+    # TODO: test all possible exceptions
+    task_failure_reason = TaskFailureReason.OTHER
+    if isinstance(exception, (TimeLimitExceeded, SoftTimeLimitExceeded)):
+        task_failure_reason = TaskFailureReason.TIME_LIMIT_EXCEEDED
+    elif isinstance(exception, MemoryError):
+        task_failure_reason = TaskFailureReason.MEMORY_LIMIT_EXCEEDED
+    elif isinstance(exception, WorkerLostError):
+        task_failure_reason = TaskFailureReason.WORKER_KILLED_BY_SIGNAL
 
-    logging.critical(
-        f"From task_failure_notifier ==> Task failed successfully! 😅, {sender}"
-    )
+    db_task_id: UUID = kwargs["task_id"]
+    with no_pooling():
+        task_orm = TaskORM.find_or_fail(db_task_id)  # type: ignore
+        task_orm.update(
+            status=TaskStatus.FAILED,  # type: ignore
+            raised_exception_name=exception.__class__.__name__,
+            failure_reason=task_failure_reason,  # type: ignore
+            traceback=traceback,
+        )
diff --git a/pyproject.toml b/pyproject.toml
index 30b148d3..b964ced1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,6 +38,7 @@ pytest-cov = "^4.1.0"
 ipykernel = "^6.29.3"
 polyfactory = "^2.15.0"
 pyright = "^1.1.355"
+pytest-alembic = "^0.11.0"
 
 [build-system]
 requires = ["poetry-core"]
diff --git a/tests/conftest.py b/tests/conftest.py
index 758baf92..26bace98 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,4 +1,5 @@
 import pytest
+from pytest_alembic import Config
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy_utils import database_exists, create_database
@@ -27,3 +28,9 @@ def prepare_db():
 def session():
     session = sessionmaker(test_engine, expire_on_commit=False)
     yield session
+
+
+@pytest.fixture
+def alembic_config():
+    options = {"file": "app/settings/alembic.ini"}
+    return Config(config_options=options)