Skip to content

Commit

Permalink
feat: add task to db and tests for migrations
Browse files Browse the repository at this point in the history
  • Loading branch information
toadharvard committed Apr 2, 2024
1 parent 27bb0bd commit c96c914
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 44 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ format:

## Run all tests in project
test:
poetry run pytest -o log_cli=true --verbosity=2 --showlocals --log-cli-level=INFO --cov=app --cov-report term
poetry run pytest -o log_cli=true --verbosity=2 --showlocals --log-cli-level=INFO --test-alembic --cov=app --cov-report term

.DEFAULT_GOAL := help
# See <https://gist.github.com/klmr/575726c7e05d8780505a> for explanation.
Expand Down
30 changes: 23 additions & 7 deletions app/api/task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from uuid import UUID
from fastapi import APIRouter, HTTPException
from pydantic import UUID4
from app.domain.file.dataset import DatasetORM
from app.domain.task.task import TaskModel, TaskORM, TaskStatus
from app.domain.worker.task.data_profiling_task import data_profiling_task
from app.domain.task import OneOfTaskConfig
from sqlalchemy_mixins.activerecord import ModelNotFoundError
Expand All @@ -11,18 +11,34 @@

@router.post("")
def set_task(
dataset_id: UUID4,
dataset_id: UUID,
config: OneOfTaskConfig,
) -> UUID4:
) -> UUID:
try:
DatasetORM.find_or_fail(dataset_id)
except ModelNotFoundError:
raise HTTPException(404, "Dataset not found")

async_result = data_profiling_task.delay(dataset_id, config)
return UUID(async_result.id, version=4)
task_orm = TaskORM.create(
status=TaskStatus.CREATED,
config=config.model_dump(),
dataset_id=dataset_id,
)
task_id = task_orm.id # type: ignore

data_profiling_task.delay(
task_id=task_id,
dataset_id=dataset_id,
config=config,
)

return task_id


@router.get("/{task_id}")
def retrieve_task(task_id: UUID4) -> None:
raise HTTPException(418, "Not implemented yet")
def retrieve_task(task_id: UUID) -> TaskModel:
    """Fetch a single task by its id.

    Args:
        task_id: Primary key of the task row.

    Returns:
        The task serialized as a ``TaskModel``.

    Raises:
        HTTPException: 404 if no task with ``task_id`` exists.
    """
    # Keep only the fallible lookup inside the try block; validation errors
    # from model_validate should propagate as-is, not be masked by the 404 path.
    try:
        task_orm = TaskORM.find_or_fail(task_id)
    except ModelNotFoundError as exc:
        # Chain the original cause (PEP 3134) so debugging retains context.
        raise HTTPException(404, "Task not found") from exc
    return TaskModel.model_validate(task_orm)
3 changes: 2 additions & 1 deletion app/db/migrations/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from app.settings import settings
from app.db import ORMBase
from app.domain.file.file import FileORM # noqa: F401
from app.domain.file.dataset import DatasetModel # noqa: F401
from app.domain.file.dataset import DatasetORM # noqa: F401
from app.domain.task.task import TaskORM # noqa: F401

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
Expand Down
67 changes: 67 additions & 0 deletions app/db/migrations/versions/7dc9a3441d07_add_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""add task
Revision ID: 7dc9a3441d07
Revises: 6a59d47fe978
Create Date: 2024-04-02 04:04:09.759025
"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "7dc9a3441d07"
down_revision: Union[str, None] = "6a59d47fe978"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Create the ``task`` table (schema mirrors ``TaskORM``)."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "task",
        sa.Column("id", sa.Uuid(), nullable=False),
        # NOTE(review): sa.Enum stores the Python enum *member names*
        # (FAILED, CREATED, ...), matching TaskStatus in app/domain/task/task.py.
        sa.Column(
            "status",
            sa.Enum("FAILED", "CREATED", "RUNNING", "COMPLETED", name="taskstatus"),
            nullable=False,
        ),
        # Task configuration and (optional) result are stored as JSONB documents.
        sa.Column("config", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
        sa.Column("result", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column("dataset_id", sa.Uuid(), nullable=False),
        # Failure details — nullable because they are populated only for failed tasks.
        sa.Column("raised_exception_name", sa.String(), nullable=True),
        sa.Column(
            "failure_reason",
            sa.Enum(
                "MEMORY_LIMIT_EXCEEDED",
                "TIME_LIMIT_EXCEEDED",
                "WORKER_KILLED_BY_SIGNAL",
                "OTHER",
                name="taskfailurereason",
            ),
            nullable=True,
        ),
        sa.Column("traceback", sa.String(), nullable=True),
        # presumably maintained by ORMBase timestamp mixin — TODO confirm
        sa.Column("created_at", sa.TIMESTAMP(), nullable=False),
        sa.Column("updated_at", sa.TIMESTAMP(), nullable=False),
        sa.ForeignKeyConstraint(
            ["dataset_id"],
            ["dataset.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    # ### end Alembic commands ###


# ADJUSTED! See: https://github.com/sqlalchemy/alembic/issues/278#issuecomment-907283386
def downgrade() -> None:
    """Drop the ``task`` table and its PostgreSQL enum types.

    Alembic autogenerate does not emit DROP TYPE for enums, so the two
    enum types created by :func:`upgrade` are dropped explicitly here
    (see the linked issue above).
    """
    # Table must go first — the enum types cannot be dropped while a column
    # still references them.
    op.drop_table("task")

    # checkfirst=True keeps the downgrade idempotent if a type is already gone.
    taskfailurereason = sa.Enum(name="taskfailurereason")
    taskfailurereason.drop(op.get_bind(), checkfirst=True)

    taskstatus = sa.Enum(name="taskstatus")
    taskstatus.drop(op.get_bind(), checkfirst=True)
11 changes: 9 additions & 2 deletions app/domain/file/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from app.db import ORMBase
from app.db.session import ORMBaseModel
from app.domain.file.file import FileModel, FileORM
import typing

if typing.TYPE_CHECKING:
from app.domain.task.task import TaskORM


class DatasetORM(ORMBase):
Expand All @@ -17,8 +21,11 @@ class DatasetORM(ORMBase):
file_id: Mapped[UUID] = mapped_column(ForeignKey("file.id"), nullable=False)
file: Mapped[FileORM] = relationship("FileORM")

# user = relationship("UserORM")
# task = relationship("TaskORM")
related_tasks: Mapped[list["TaskORM"]] = relationship(
"TaskORM", back_populates="dataset"
)

# owner = relationship("UserORM")


class DatasetModel(ORMBaseModel):
Expand Down
60 changes: 60 additions & 0 deletions app/domain/task/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from enum import StrEnum, auto
import typing
from uuid import UUID, uuid4
from sqlalchemy.orm import Mapped, mapped_column
from app.db import ORMBase
from app.db.session import ORMBaseModel
from sqlalchemy import ForeignKey
from sqlalchemy.orm import relationship
from app.domain.file.dataset import DatasetModel
from app.domain.task import OneOfTaskConfig, OneOfTaskResult

from sqlalchemy.dialects.postgresql import JSONB

if typing.TYPE_CHECKING:
from app.domain.file.dataset import DatasetORM


class TaskStatus(StrEnum):
FAILED = auto()
CREATED = auto()
RUNNING = auto()
COMPLETED = auto()


class TaskFailureReason(StrEnum):
MEMORY_LIMIT_EXCEEDED = auto()
TIME_LIMIT_EXCEEDED = auto()
WORKER_KILLED_BY_SIGNAL = auto()
OTHER = auto()


class TaskORM(ORMBase):
    """SQLAlchemy model for a data-profiling task and its outcome."""

    __tablename__ = "task"
    # Application-generated UUID primary key (uuid4 default runs client-side).
    id: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4)

    status: Mapped[TaskStatus]
    # Config and result are persisted as JSONB documents.
    config: Mapped[OneOfTaskConfig] = mapped_column(JSONB)
    result: Mapped[OneOfTaskResult | None] = mapped_column(JSONB, default=None)

    # Many tasks may reference one dataset; back_populates pairs with
    # DatasetORM.related_tasks.
    dataset_id: Mapped[UUID] = mapped_column(ForeignKey("dataset.id"), nullable=False)
    dataset: Mapped["DatasetORM"] = relationship(
        "DatasetORM", back_populates="related_tasks"
    )

    # Failure details — set only if the task failed; otherwise remain None.
    raised_exception_name: Mapped[str | None] = mapped_column(default=None)
    failure_reason: Mapped[TaskFailureReason | None] = mapped_column(default=None)
    traceback: Mapped[str | None] = mapped_column(default=None)


class TaskModel(ORMBaseModel):
    """Pydantic (API-facing) representation of a ``TaskORM`` row."""

    id: UUID
    status: TaskStatus
    config: OneOfTaskConfig
    result: OneOfTaskResult | None
    # Related dataset, serialized through DatasetModel; presumably populated
    # via model_validate over the ORM relationship — TODO confirm ORMBaseModel
    # enables from_attributes.
    dataset: DatasetModel

    # Failure details; None unless the task failed (mirrors TaskORM).
    raised_exception_name: str | None
    failure_reason: TaskFailureReason | None
    traceback: str | None
67 changes: 34 additions & 33 deletions app/domain/worker/task/data_profiling_task.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
import logging
from typing import Any
from uuid import UUID

from app.db.session import no_pooling
from app.domain.file.dataset import DatasetORM
from app.domain.task import OneOfTaskConfig
from app.domain.task import OneOfTaskConfig, OneOfTaskResult
from app.domain.task import match_task_by_primitive_name
from app.domain.task.task import TaskFailureReason, TaskORM, TaskStatus
from app.worker import worker
from app.domain.worker.task.resource_intensive_task import ResourceIntensiveTask
from pydantic import UUID4
import pandas as pd
from celery.signals import task_failure, task_prerun, task_postrun
from celery.exceptions import SoftTimeLimitExceeded, TimeLimitExceeded, WorkerLostError


@worker.task(base=ResourceIntensiveTask, ignore_result=True, max_retries=0)
def data_profiling_task(
dataset_id: UUID4,
task_id: UUID,
dataset_id: UUID,
config: OneOfTaskConfig,
) -> Any:
with no_pooling():
Expand All @@ -37,53 +39,52 @@ def data_profiling_task(

@task_prerun.connect(sender=data_profiling_task)
def task_prerun_notifier(
sender,
task_id,
task,
args,
kwargs,
**_,
):
# TODO: Create Task in database and set status to "running" or similar
db_task_id: UUID = kwargs["task_id"]
with no_pooling():
...
logging.critical(
f"From task_prerun_notifier ==> Running just before add() executes, {sender}"
)
task_orm = TaskORM.find_or_fail(db_task_id)
task_orm.update(status=TaskStatus.RUNNING) # type: ignore


@task_postrun.connect(sender=data_profiling_task)
def task_postrun_notifier(
sender,
task_id,
task,
args,
kwargs,
retval,
retval: OneOfTaskResult,
**_,
):
db_task_id: UUID = kwargs["task_id"]
with no_pooling():
...

# TODO: Update Task in database and set status to "completed" or similar
logging.critical(f"From task_postrun_notifier ==> Ok, done!, {sender}")
task_orm = TaskORM.find_or_fail(db_task_id) # type: ignore
task_orm.update(
status=TaskStatus.COMPLETED, # type: ignore
result=retval.model_dump(),
)


@task_failure.connect(sender=data_profiling_task)
def task_failure_notifier(
sender,
task_id,
exception,
args,
kwargs,
exception,
traceback,
einfo,
**_,
):
with no_pooling():
...
# TODO: Update Task in database and set status to "failed" or similar
# TODO: test all possible exceptions
task_failure_reason = TaskFailureReason.OTHER
if exception in (TimeLimitExceeded, SoftTimeLimitExceeded):
task_failure_reason = TaskFailureReason.TIME_LIMIT_EXCEEDED
if exception is MemoryError:
task_failure_reason = TaskFailureReason.MEMORY_LIMIT_EXCEEDED
if exception is WorkerLostError:
task_failure_reason = TaskFailureReason.WORKER_KILLED_BY_SIGNAL

logging.critical(
f"From task_failure_notifier ==> Task failed successfully! 😅, {sender}"
)
db_task_id: UUID = kwargs["task_id"]
with no_pooling():
task_orm = TaskORM.find_or_fail(db_task_id) # type: ignore
task_orm.update(
status=TaskStatus.FAILED, # type: ignore
raised_exception_name=exception.__class__.__name__,
failure_reason=task_failure_reason, # type: ignore
traceback=traceback,
)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pytest-cov = "^4.1.0"
ipykernel = "^6.29.3"
polyfactory = "^2.15.0"
pyright = "^1.1.355"
pytest-alembic = "^0.11.0"

[build-system]
requires = ["poetry-core"]
Expand Down
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from pytest_alembic import Config
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import database_exists, create_database
Expand Down Expand Up @@ -27,3 +28,9 @@ def prepare_db():
def session():
session = sessionmaker(test_engine, expire_on_commit=False)
yield session


@pytest.fixture
def alembic_config():
    """pytest-alembic configuration pointing at the project's alembic.ini."""
    return Config(config_options={"file": "app/settings/alembic.ini"})

0 comments on commit c96c914

Please sign in to comment.