feat: add celery task for data profiling and update API
toadharvard committed Mar 19, 2024
1 parent 0d276d7 commit d238f24
Showing 15 changed files with 174 additions and 77 deletions.
1 change: 0 additions & 1 deletion .env.example
@@ -1,5 +1,4 @@
 POSTGRES_DIALECT_DRIVER=postgresql+psycopg
-
 POSTGRES_USER=admin
 POSTGRES_PASSWORD=admin
 POSTGRES_HOST=localhost
5 changes: 2 additions & 3 deletions app/api/ping.py
@@ -1,10 +1,9 @@
 from typing import Literal

 from fastapi import APIRouter

 router = APIRouter()

-
 @router.get("/ping")
-def ping() -> Literal["Pong"]:
-    return "Pong"
+def ping() -> Literal["Pong!"]:
+    return "Pong!"
31 changes: 10 additions & 21 deletions app/api/task.py
@@ -1,21 +1,13 @@
-from fastapi import APIRouter, HTTPException, Depends
+from fastapi import APIRouter, HTTPException
 from pydantic import UUID4
-from typing import Type, Annotated
-import pandas as pd
-from uuid import uuid4
+from typing import Type
 from app.domain.task.primitive_factory import PrimitiveName, PrimitiveFactory
 from app.domain.task.task_factory import AnyAlgoName
 from app.domain.task.abstract_task import AnyTask, AnyRes
-from app.domain.common.optional_fields import OptionalFields
+from app.domain.worker.task.data_profiling_task import data_profiling_task

 router = APIRouter(prefix="/task")

-repo = {}
-
-
-def get_df_by_file_id(file_id: UUID4) -> pd.DataFrame:
-    return pd.read_csv("tests/datasets/university_fd.csv", sep=",", header=0)
-

 def generate_set_task_endpoint(
     primitive_name: PrimitiveName,
@@ -30,13 +22,13 @@ def generate_set_task_endpoint(
         tags=["set task"],
     )
     def _(
-        df: Annotated[pd.DataFrame, Depends(get_df_by_file_id)],
-        config: OptionalFields[task_cls.config_model_cls],
+        file_id: UUID4,
+        config: task_cls.config_model_cls,
     ) -> UUID4:
-        task = task_cls(df)
-        task_id = uuid4()
-        repo[task_id] = (task, config)
-        return task_id
+        async_result = data_profiling_task.delay(
+            primitive_name, algo_name, file_id, config
+        )
+        return async_result.id

     router.include_router(primitive_router)

@@ -48,10 +40,7 @@ def generate_get_task_result_endpoint(

     @primitive_router.get("", name=f"Get {primitive_name} result", tags=["get result"])
     def _(task_id: UUID4) -> result_cls:
-        task, config = repo.get(task_id, (None, None))
-        if not task:
-            raise HTTPException(404, "Task not found")
-        return task.execute(config)
+        raise HTTPException(418, "Not implemented yet")

     router.include_router(primitive_router)

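Note: setting a task no longer runs profiling in-process (the repo dict and get_df_by_file_id helper are gone); the endpoint now enqueues a Celery task and returns the broker-side task id, and fetching results is stubbed until persistence lands. The core of the new endpoint body, in isolation:

    # Mirrors the endpoint body above: enqueue, then hand back the broker's id.
    async_result = data_profiling_task.delay(primitive_name, algo_name, file_id, config)
    task_id = async_result.id  # UUID4 returned to the client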
26 changes: 21 additions & 5 deletions app/db/session.py
@@ -1,13 +1,29 @@
 from contextlib import contextmanager
 from typing import Generator
 from sqlalchemy.orm import Session, sessionmaker
 from sqlalchemy import create_engine

 from app.settings import settings
+from sqlalchemy.pool import NullPool
+
-engine = create_engine(url=settings.postgres_dsn.unicode_string())
-SessionLocal = sessionmaker(bind=engine, autoflush=False)
-
+default_engine = create_engine(url=settings.postgres_dsn.unicode_string())
+engine_without_pool = create_engine(
+    url=settings.postgres_dsn.unicode_string(),
+    poolclass=NullPool,
+)
+
+SessionLocal = sessionmaker(bind=default_engine)
+SessionLocalWithoutPool = sessionmaker(bind=engine_without_pool)
+
 @contextmanager
-def get_session() -> Generator[Session, None, None]:
-    with SessionLocal() as session:
+def get_session(with_pool=True) -> Generator[Session, None, None]:
+    """
+    Returns a generator that yields a session object for database operations.
+
+    Parameters:
+        with_pool (bool): A flag to determine if the session uses a connection pool.
+        Set to False when used in a Celery task. Defaults to True.
+    """
+    maker = SessionLocal if with_pool else SessionLocalWithoutPool
+    with maker() as session:
         yield session
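Note: the second engine exists because Celery's prefork workers fork the parent process, and pooled connections shared across forks misbehave; NullPool checks out a fresh connection each time instead. A minimal usage sketch (the query is a hypothetical placeholder):

    from sqlalchemy import text
    from app.db.session import get_session

    # Inside a Celery worker process: with_pool=False routes through the NullPool engine.
    with get_session(with_pool=False) as session:
        session.execute(text("SELECT 1"))  # placeholder query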
@@ -12,8 +12,3 @@ def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
             field.default = None

         cls.model_rebuild(force=True)
-
-
-class OptionalFields:
-    def __class_getitem__(cls, item):
-        return type(f"{item.__name__}WithOptionalFields", (item, OptionalModel), {})
24 changes: 13 additions & 11 deletions app/domain/task/fd/config.py
@@ -1,51 +1,53 @@
-from pydantic import BaseModel, Field
+from pydantic import Field
 from typing import Annotated

+from app.domain.common.optional_model import OptionalModel
+

-class AidConfig(BaseModel):
+class AidConfig(OptionalModel):
     is_null_equal_null: bool


-class DFDConfig(BaseModel):
+class DFDConfig(OptionalModel):
     is_null_equal_null: bool
     threads: Annotated[int, Field(ge=1, le=8)]


-class DepminerConfig(BaseModel):
+class DepminerConfig(OptionalModel):
     is_null_equal_null: bool


-class FDepConfig(BaseModel):
+class FDepConfig(OptionalModel):
     is_null_equal_null: bool


-class FUNConfig(BaseModel):
+class FUNConfig(OptionalModel):
     is_null_equal_null: bool


-class FastFDsConfig(BaseModel):
+class FastFDsConfig(OptionalModel):
     is_null_equal_null: bool
     max_lhs: Annotated[int, Field(ge=1, le=10)]
     threads: Annotated[int, Field(ge=1, le=8)]


-class FdMineConfig(BaseModel):
+class FdMineConfig(OptionalModel):
     is_null_equal_null: bool


-class HyFDConfig(BaseModel):
+class HyFDConfig(OptionalModel):
     is_null_equal_null: bool


-class PyroConfig(BaseModel):
+class PyroConfig(OptionalModel):
     is_null_equal_null: bool
     error: Annotated[float, Field(ge=0, le=1)]
     max_lhs: Annotated[int, Field(ge=1, le=10)]
     threads: Annotated[int, Field(ge=1, le=8)]
     seed: int


-class TaneConfig(BaseModel):
+class TaneConfig(OptionalModel):
     is_null_equal_null: bool
     error: Annotated[float, Field(ge=0, le=1)]
     max_lhs: Annotated[int, Field(ge=1, le=10)]
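Note: rebasing the configs from BaseModel onto OptionalModel makes every declared field optional: the `__pydantic_init_subclass__` hook shown earlier sets each field's default to None and rebuilds the model, so clients can post partial configs. A hedged sketch of the effect:

    # Assumes OptionalModel nulls out field defaults as shown above:
    cfg = PyroConfig(max_lhs=5)  # validates with the other fields unset
    print(cfg.error, cfg.seed)   # None None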
6 changes: 3 additions & 3 deletions app/domain/task/task_factory.py
@@ -5,8 +5,8 @@
 type AnyAlgoName = StrEnum


-class TaskFactory[E: AnyAlgoName, T: AnyTask]:
-    def __init__(self, enum_used_as_keys: Type[E], general_task_cls: Type[T]) -> None:
+class TaskFactory[E: AnyAlgoName, T: Type[AnyTask]]:
+    def __init__(self, enum_used_as_keys: Type[E], general_task_cls: T) -> None:
         self.tasks: dict[E, T] = {}
         self.enum_used_as_keys = enum_used_as_keys
@@ -41,4 +41,4 @@ def get_names(self) -> Iterable[E]:
         return self.tasks.keys()


-type AnyTaskFactory = TaskFactory[AnyAlgoName, AnyTask]
+type AnyTaskFactory = TaskFactory[AnyAlgoName, Type[AnyTask]]
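Note: the tightened type parameter (T bound to Type[AnyTask] rather than AnyTask) reflects that the factory maps enum members to task classes, not instances. Resolution at a call site is a two-step lookup, as the new worker task below does:

    task_factory = PrimitiveFactory.get_by_name(primitive_name)  # factory for the primitive
    task_cls = task_factory.get_by_name(algo_name)               # concrete task class
    task = task_cls(df)                                          # instantiated with a DataFrame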
1 change: 1 addition & 0 deletions app/domain/worker/task/__init__.py
@@ -0,0 +1 @@
+from .data_profiling_task import data_profiling_task as data_profiling_task
10 changes: 0 additions & 10 deletions app/domain/worker/task/data_profiling.py

This file was deleted.

85 changes: 85 additions & 0 deletions app/domain/worker/task/data_profiling_task.py
@@ -0,0 +1,85 @@
+import logging
+
+from app.db.session import get_session
+from app.worker import worker
+from app.domain.task.abstract_task import AnyConf, AnyRes
+from app.domain.task.primitive_factory import PrimitiveName, PrimitiveFactory
+from app.domain.task.task_factory import AnyAlgoName
+from app.domain.worker.task.resource_intensive_task import ResourceIntensiveTask
+from pydantic import UUID4
+import pandas as pd
+from celery.signals import task_failure, task_prerun, task_postrun
+
+
+@worker.task(base=ResourceIntensiveTask, ignore_result=True, max_retries=0)
+def data_profiling_task(
+    primitive_name: PrimitiveName,
+    algo_name: AnyAlgoName,
+    file_id: UUID4,
+    config: AnyConf,
+) -> AnyRes:
+    task_factory = PrimitiveFactory.get_by_name(primitive_name)
+    task_cls = task_factory.get_by_name(algo_name)
+
+    df = pd.read_csv(
+        "tests/datasets/university_fd.csv", sep=",", header=0
+    )  # TODO: Replace with actual file (by file_id) in future
+
+    task = task_cls(df)
+    result = task.execute(config)
+    return result
+
+
+@task_prerun.connect(sender=data_profiling_task)
+def task_prerun_notifier(
+    sender,
+    task_id,
+    task,
+    args,
+    kwargs,
+    **_,
+):
+    # TODO: Create Task in database and set status to "running" or similar
+    with get_session(with_pool=False) as session:
+        session
+
+    logging.critical(
+        f"From task_prerun_notifier ==> Running just before add() executes, {sender}"
+    )
+
+
+@task_postrun.connect(sender=data_profiling_task)
+def task_postrun_notifier(
+    sender,
+    task_id,
+    task,
+    args,
+    kwargs,
+    retval,
+    **_,
+):
+    with get_session(with_pool=False) as session:
+        session
+
+    # TODO: Update Task in database and set status to "completed" or similar
+    logging.critical(f"From task_postrun_notifier ==> Ok, done!, {sender}")
+
+
+@task_failure.connect(sender=data_profiling_task)
+def task_failure_notifier(
+    sender,
+    task_id,
+    exception,
+    args,
+    kwargs,
+    traceback,
+    einfo,
+    **_,
+):
+    with get_session(with_pool=False) as session:
+        session
+    # TODO: Update Task in database and set status to "failed" or similar
+
+    logging.critical(
+        f"From task_failure_notifier ==> Task failed successfully! 😅, {sender}"
+    )
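Note: each handler is connected with sender=data_profiling_task, so the hooks fire only for this task rather than for every task the worker runs, and ignore_result=True means the returned AnyRes is discarded rather than stored in a result backend, which matches the 418 stub in the API above. The bare session statements are placeholders; a hypothetical shape for what the prerun hook might grow into (TaskRecord and its fields are invented for illustration):

    # Hypothetical future body for task_prerun_notifier:
    with get_session(with_pool=False) as session:
        session.add(TaskRecord(id=task_id, status="running"))  # TaskRecord is illustrative
        session.commit()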
13 changes: 0 additions & 13 deletions app/domain/worker/task/dummy.py

This file was deleted.

19 changes: 19 additions & 0 deletions app/domain/worker/task/resource_intensive_task.py
@@ -0,0 +1,19 @@
+from celery import Task
+import resource
+from app.settings import settings
+
+
+class ResourceIntensiveTask(Task):
+    # Overrides the default Celery time limits, see: https://docs.celeryq.dev/en/stable/userguide/workers.html#time-limits
+    time_limit = settings.worker_hard_time_limit_in_seconds
+    soft_time_limit = settings.worker_soft_time_limit_in_seconds
+
+    # Custom memory limits, enforced via the `resource` module
+    hard_memory_limit = settings.worker_hard_memory_limit
+    soft_memory_limit = settings.worker_soft_memory_limit
+
+    def before_start(self, task_id, args, kwargs) -> None:
+        resource.setrlimit(
+            resource.RLIMIT_AS, (self.soft_memory_limit, self.hard_memory_limit)
+        )
+        super().before_start(task_id, args, kwargs)
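Note: RLIMIT_AS caps the process's virtual address space, so an allocation that would push past the soft limit fails and surfaces in Python as MemoryError (the `resource` module is Unix-only). A standalone sketch with arbitrary limits:

    import resource

    # Cap this process at 1 GiB soft / 2 GiB hard of address space.
    resource.setrlimit(resource.RLIMIT_AS, (1 * 1024**3, 2 * 1024**3))

    try:
        buf = bytearray(2 * 1024**3)  # allocation above the soft limit
    except MemoryError:
        print("allocation rejected by RLIMIT_AS")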
4 changes: 2 additions & 2 deletions app/settings/__init__.py
@@ -1,3 +1,3 @@
-from .settings import Settings
+from .settings import get_settings

-settings = Settings()
+settings = get_settings()
5 changes: 5 additions & 0 deletions app/settings/celery_config.py
@@ -1 +1,6 @@
 broker_connection_retry_on_startup = True
+task_serializer = "pickle"
+result_serializer = "pickle"
+event_serializer = "json"
+accept_content = ["application/json", "application/x-python-serialize"]
+result_accept_content = ["application/json", "application/x-python-serialize"]
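Note: pickle serialization is what lets the task take pydantic model instances (and other arbitrary Python objects) as arguments; the default JSON serializer cannot encode them, and accept_content must list pickle's MIME type (application/x-python-serialize) or workers will reject those messages. Pickle is only safe while the broker is trusted. A minimal round-trip sketch of the same mechanism kombu applies on the wire:

    import pickle
    from app.domain.task.fd.config import PyroConfig

    restored = pickle.loads(pickle.dumps(PyroConfig(seed=42)))
    print(type(restored).__name__)  # PyroConfig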
16 changes: 13 additions & 3 deletions app/settings/settings.py
@@ -1,25 +1,30 @@
 from functools import cached_property

 from dotenv import load_dotenv, find_dotenv
-from pydantic import AmqpDsn, PostgresDsn
+from pydantic import AmqpDsn, PostgresDsn, Field, ByteSize
 from pydantic_settings import BaseSettings

 load_dotenv(find_dotenv(".env"))


 class Settings(BaseSettings):
+    # Postgres settings
     postgres_dialect_driver: str = "postgresql"
-
     postgres_user: str
     postgres_password: str
     postgres_host: str
     postgres_db: str
     postgres_port: int = 5432

+    # RabbitMQ settings
     rabbitmq_default_user: str
     rabbitmq_default_password: str
     rabbitmq_host: str
     rabbitmq_port: int = 5672
+    # Worker limits
+    worker_soft_time_limit_in_seconds: int = Field(default=60, gt=0)
+    worker_hard_time_limit_in_seconds: int = Field(default=120, gt=0)
+    worker_soft_memory_limit: ByteSize = "2GB"
+    worker_hard_memory_limit: ByteSize = "4GB"

     @cached_property
     def rabbitmq_dsn(self) -> AmqpDsn:
@@ -41,3 +46,8 @@ def postgres_dsn(self) -> PostgresDsn:
             port=self.postgres_port,
             path=self.postgres_db,
         )
+
+
+def get_settings():
+    # TODO: create different settings based on environment (production, testing, etc.)
+    return Settings()
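Note: ByteSize lets the memory limits be written as human-readable strings; pydantic v2 parses "2GB" as a decimal-unit integer byte count (use "GiB" for binary units), so the validated values plug straight into resource.setrlimit. A quick sketch:

    from pydantic import BaseModel, ByteSize

    class Limits(BaseModel):
        soft: ByteSize

    print(int(Limits(soft="2GB").soft))   # 2000000000
    print(int(Limits(soft="2GiB").soft))  # 2147483648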
