From 371ab9b883fa6d9547679434ec710a32b72d2c6d Mon Sep 17 00:00:00 2001 From: Vadim Yakshigulov Date: Mon, 25 Mar 2024 11:02:43 +0300 Subject: [PATCH] feat: replace factories with tagged unions --- app/api/task.py | 61 ++----- app/domain/task/__init__.py | 32 ++++ app/domain/task/abstract_task.py | 67 ++++---- app/domain/task/afd/__init__.py | 36 +++++ app/domain/task/afd/algo_name.py | 6 + app/domain/task/afd/config.py | 34 ++++ app/domain/task/afd/result.py | 4 + app/domain/task/fd/__init__.py | 153 ++++++------------ app/domain/task/fd/algo_name.py | 14 ++ app/domain/task/fd/config.py | 67 ++++++-- app/domain/task/fd/result.py | 6 +- app/domain/task/primitive_factory.py | 46 ------ app/domain/task/primitive_name.py | 13 ++ app/domain/task/task_factory.py | 44 ----- app/domain/worker/task/data_profiling_task.py | 19 +-- .../worker/task/resource_intensive_task.py | 1 - pyproject.toml | 1 + tests/domain/task/test_fd.py | 28 ++-- tests/domain/task/test_primitive_factory.py | 11 -- 19 files changed, 319 insertions(+), 324 deletions(-) create mode 100644 app/domain/task/__init__.py create mode 100644 app/domain/task/afd/__init__.py create mode 100644 app/domain/task/afd/algo_name.py create mode 100644 app/domain/task/afd/config.py create mode 100644 app/domain/task/afd/result.py create mode 100644 app/domain/task/fd/algo_name.py delete mode 100644 app/domain/task/primitive_factory.py create mode 100644 app/domain/task/primitive_name.py delete mode 100644 app/domain/task/task_factory.py delete mode 100644 tests/domain/task/test_primitive_factory.py diff --git a/app/api/task.py b/app/api/task.py index e3e1000f..530d9961 100644 --- a/app/api/task.py +++ b/app/api/task.py @@ -1,58 +1,21 @@ +from uuid import UUID from fastapi import APIRouter, HTTPException from pydantic import UUID4 -from typing import Type -from app.domain.task.primitive_factory import PrimitiveName, PrimitiveFactory -from app.domain.task.task_factory import AnyAlgoName -from app.domain.task.abstract_task import AnyTask, AnyRes from app.domain.worker.task.data_profiling_task import data_profiling_task +from app.domain.task import OneOfTaskConfig router = APIRouter(prefix="/task") -def generate_set_task_endpoint( - primitive_name: PrimitiveName, - algo_name: AnyAlgoName, - task_cls: Type[AnyTask], -): - primitive_router = APIRouter(prefix=f"/{primitive_name}", tags=[primitive_name]) +@router.post("") +def set_task( + file_id: UUID4, + config: OneOfTaskConfig, +) -> UUID4: + async_result = data_profiling_task.delay(file_id, config) + return UUID(async_result.id, version=4) - @primitive_router.post( - f"/{algo_name}", - name=f"Set {algo_name} task", - tags=["set task"], - ) - def _( - file_id: UUID4, - config: task_cls.config_model_cls, - ) -> UUID4: - async_result = data_profiling_task.delay( - primitive_name, algo_name, file_id, config - ) - return async_result.id - router.include_router(primitive_router) - - -def generate_get_task_result_endpoint( - primitive_name: PrimitiveName, result_cls: Type[AnyRes] -): - primitive_router = APIRouter(prefix=f"/{primitive_name}", tags=[primitive_name]) - - @primitive_router.get("", name=f"Get {primitive_name} result", tags=["get result"]) - def _(task_id: UUID4) -> result_cls: - raise HTTPException(418, "Not implemented yet") - - router.include_router(primitive_router) - - -for primitive_name in PrimitiveFactory.get_names(): - task_factory = PrimitiveFactory.get_by_name(primitive_name) - for algo_name in task_factory.get_names(): - task_cls = task_factory.get_by_name(algo_name) - generate_set_task_endpoint(primitive_name, algo_name, task_cls) - -for primitive_name in PrimitiveFactory.get_names(): - task_factory = PrimitiveFactory.get_by_name(primitive_name) - generate_get_task_result_endpoint( - primitive_name, task_factory.general_task_cls.result_model_cls - ) +@router.get("/{task_id}") +def retrieve_task(task_id: UUID4) -> None: + raise HTTPException(418, "Not implemented yet") diff --git a/app/domain/task/__init__.py b/app/domain/task/__init__.py new file mode 100644 index 00000000..d0f4aa6b --- /dev/null +++ b/app/domain/task/__init__.py @@ -0,0 +1,32 @@ +from app.domain.task.afd import AfdTask, AfdTaskConfig, AfdTaskResult +from app.domain.task.fd import FdTaskConfig, FdTaskResult +from typing import Annotated, Union, assert_never +from pydantic import Field +from app.domain.task.fd import FdTask +from app.domain.task.primitive_name import PrimitiveName + + +OneOfTaskConfig = Annotated[ + Union[ + FdTaskConfig, + AfdTaskConfig, + ], + Field(discriminator="primitive_name"), +] + +OneOfTaskResult = Annotated[ + Union[ + FdTaskResult, + AfdTaskResult, + ], + Field(discriminator="primitive_name"), +] + + +def match_task_by_primitive_name(primitive_name: PrimitiveName): + match primitive_name: + case PrimitiveName.fd: + return FdTask() + case PrimitiveName.afd: + return AfdTask() + assert_never(primitive_name) diff --git a/app/domain/task/abstract_task.py b/app/domain/task/abstract_task.py index 8595853c..df523df3 100644 --- a/app/domain/task/abstract_task.py +++ b/app/domain/task/abstract_task.py @@ -1,40 +1,51 @@ -from typing import Type from abc import ABC, abstractmethod -from pydantic import BaseModel -from desbordante import Algorithm +from enum import StrEnum +from typing import Any, Protocol +import desbordante import pandas +from pydantic import BaseModel -type AnyAlgo = Algorithm -type AnyConf = BaseModel -type AnyRes = BaseModel +class AlgoConfig(Protocol): + @property + def algo_name(self) -> StrEnum: ... + # forces to use pydantic classes there + model_dump = BaseModel.model_dump -class AbstractTask[Algo: AnyAlgo, Conf: AnyConf, Res: AnyRes](ABC): - algo: Algo - config_model_cls: Type[Conf] - result_model_cls: Type[Res] - def __init__(self, table: pandas.DataFrame) -> None: - try: - self.algo - self.config_model_cls - self.result_model_cls - except AttributeError: - raise NotImplementedError( - "Attributes `algo`, `config_model_cls` and `result_model_cls` must be implemented in non-abstract class" - ) +class TaskConfig(Protocol): + @property + def primitive_name(self) -> StrEnum: ... - self.table = table - self.algo.load_data(table=table) + @property + def config(self) -> AlgoConfig: ... - def execute(self, config: Conf | None = None) -> Res: - options = config.model_dump(exclude_unset=True) if config else {} - self.algo.execute(**options) - return self.collect_result() + # forces to use pydantic classes there + model_dump = BaseModel.model_dump - @abstractmethod - def collect_result(self) -> Res: ... +class TaskResult(Protocol): + @property + def primitive_name(self) -> StrEnum: ... + + result: Any + + # forces to use pydantic classes there + model_dump = BaseModel.model_dump -type AnyTask = AbstractTask[AnyAlgo, AnyConf, AnyRes] + +class Task[C: TaskConfig, R: TaskResult](ABC): + @abstractmethod + def match_algo_by_name(self, algo_name) -> desbordante.Algorithm: ... + + @abstractmethod + def collect_result(self, algo) -> R: ... + + def execute(self, table: pandas.DataFrame, task_config: C) -> R: + algo_config = task_config.config + options = algo_config.model_dump(exclude_unset=True, exclude={"algo_name"}) + algo = self.match_algo_by_name(algo_config.algo_name) + algo.load_data(table=table) + algo.execute(**options) + return self.collect_result(algo) diff --git a/app/domain/task/afd/__init__.py b/app/domain/task/afd/__init__.py new file mode 100644 index 00000000..394ce317 --- /dev/null +++ b/app/domain/task/afd/__init__.py @@ -0,0 +1,36 @@ +from typing import Literal, assert_never +from pydantic import BaseModel +from desbordante.fd import FdAlgorithm # This is not a typo +from app.domain.task.abstract_task import Task +from app.domain.task.primitive_name import PrimitiveName +from .config import OneOfAfdConfig +from .result import AfdAlgoResult, FdModel +from .algo_name import AfdAlgoName +from desbordante.afd.algorithms import Pyro, Tane + + +class BaseAfdTaskModel(BaseModel): + primitive_name: Literal[PrimitiveName.afd] + + +class AfdTaskConfig(BaseAfdTaskModel): + config: OneOfAfdConfig + + +class AfdTaskResult(BaseAfdTaskModel): + result: AfdAlgoResult + + +class AfdTask(Task[AfdTaskConfig, AfdTaskResult]): + def collect_result(self, algo: FdAlgorithm) -> AfdTaskResult: + fds = algo.get_fds() + algo_result = AfdAlgoResult(fds=list(map(FdModel.from_fd, fds))) + return AfdTaskResult(primitive_name=PrimitiveName.afd, result=algo_result) + + def match_algo_by_name(self, algo_name: AfdAlgoName) -> FdAlgorithm: + match algo_name: + case AfdAlgoName.Pyro: + return Pyro() + case AfdAlgoName.Tane: + return Tane() + assert_never(algo_name) diff --git a/app/domain/task/afd/algo_name.py b/app/domain/task/afd/algo_name.py new file mode 100644 index 00000000..5a55ccb6 --- /dev/null +++ b/app/domain/task/afd/algo_name.py @@ -0,0 +1,6 @@ +from enum import StrEnum, auto + + +class AfdAlgoName(StrEnum): + Pyro = auto() + Tane = auto() diff --git a/app/domain/task/afd/config.py b/app/domain/task/afd/config.py new file mode 100644 index 00000000..3c8f2581 --- /dev/null +++ b/app/domain/task/afd/config.py @@ -0,0 +1,34 @@ +from typing import Annotated, Literal, Union +from pydantic import BaseModel, Field + +from app.domain.task.afd.algo_name import AfdAlgoName + + +class BaseAfdConfig(BaseModel): ... + + +class PyroConfig(BaseAfdConfig): + algo_name: Literal[AfdAlgoName.Pyro] + + is_null_equal_null: bool + error: Annotated[float, Field(ge=0, le=1)] + max_lhs: Annotated[int, Field(ge=1, le=10)] + threads: Annotated[int, Field(ge=1, le=8)] + seed: int + + +class TaneConfig(BaseAfdConfig): + algo_name: Literal[AfdAlgoName.Tane] + + is_null_equal_null: bool + error: Annotated[float, Field(ge=0, le=1)] + max_lhs: Annotated[int, Field(ge=1, le=10)] + + +OneOfAfdConfig = Annotated[ + Union[ + PyroConfig, + TaneConfig, + ], + Field(discriminator="algo_name"), +] diff --git a/app/domain/task/afd/result.py b/app/domain/task/afd/result.py new file mode 100644 index 00000000..ce59ecf9 --- /dev/null +++ b/app/domain/task/afd/result.py @@ -0,0 +1,4 @@ +from app.domain.task.fd.result import FdAlgoResult, FdModel + +AfdAlgoResult = FdAlgoResult +FdModel = FdModel diff --git a/app/domain/task/fd/__init__.py b/app/domain/task/fd/__init__.py index c2d6442e..efe776bb 100644 --- a/app/domain/task/fd/__init__.py +++ b/app/domain/task/fd/__init__.py @@ -1,16 +1,7 @@ -from .config import ( - AidConfig, - DFDConfig, - DepminerConfig, - FDepConfig, - FUNConfig, - FastFDsConfig, - FdMineConfig, - HyFDConfig, - PyroConfig, - TaneConfig, -) -from .result import FDModel, FDAlgoResult +from typing import Literal +from pydantic import BaseModel +from app.domain.task.abstract_task import Task +from typing import assert_never from desbordante.fd import FdAlgorithm from desbordante.fd.algorithms import ( Aid, @@ -24,92 +15,50 @@ Pyro, Tane, ) -from app.domain.task.abstract_task import AbstractTask, AnyConf -from app.domain.task.task_factory import TaskFactory -from enum import auto, StrEnum - - -class FDAlgoName(StrEnum): - Aid = auto() - DFD = auto() - Depminer = auto() - FDep = auto() - FUN = auto() - FastFDs = auto() - FdMine = auto() - HyFD = auto() - Pyro = auto() - Tane = auto() - - -class FDTask[FDAlgo: FdAlgorithm, Conf: AnyConf]( - AbstractTask[FDAlgo, Conf, FDAlgoResult] -): - result_model_cls = FDAlgoResult - - def collect_result(self) -> FDAlgoResult: - fds = self.algo.get_fds() - return FDAlgoResult(fds=list(map(FDModel.from_fd, fds))) - - -fd_factory = TaskFactory(FDAlgoName, FDTask) - - -@fd_factory.register_task(FDAlgoName.Aid) -class AidTask(FDTask[Aid, AidConfig]): - config_model_cls = AidConfig - algo = Aid() - - -@fd_factory.register_task(FDAlgoName.DFD) -class DFDTask(FDTask[DFD, DFDConfig]): - config_model_cls = DFDConfig - algo = DFD() - - -@fd_factory.register_task(FDAlgoName.Depminer) -class DepminerTask(FDTask[Depminer, DepminerConfig]): - config_model_cls = DepminerConfig - algo = Depminer() - - -@fd_factory.register_task(FDAlgoName.FDep) -class FDepTask(FDTask[FDep, FDepConfig]): - config_model_cls = FDepConfig - algo = FDep() - - -@fd_factory.register_task(FDAlgoName.FUN) -class FUNTask(FDTask[FUN, FUNConfig]): - config_model_cls = FUNConfig - algo = FUN() - - -@fd_factory.register_task(FDAlgoName.FastFDs) -class FastFDsTask(FDTask[FastFDs, FastFDsConfig]): - config_model_cls = FastFDsConfig - algo = FastFDs() - - -@fd_factory.register_task(FDAlgoName.FdMine) -class FdMineTask(FDTask[FdMine, FdMineConfig]): - config_model_cls = FdMineConfig - algo = FdMine() - - -@fd_factory.register_task(FDAlgoName.HyFD) -class HyFDTask(FDTask[HyFD, HyFDConfig]): - config_model_cls = HyFDConfig - algo = HyFD() - - -@fd_factory.register_task(FDAlgoName.Pyro) -class PyroTask(FDTask[Pyro, PyroConfig]): - config_model_cls = PyroConfig - algo = Pyro() - - -@fd_factory.register_task(FDAlgoName.Tane) -class TaneTask(FDTask[Tane, TaneConfig]): - config_model_cls = TaneConfig - algo = Tane() +from app.domain.task.fd.algo_name import FdAlgoName +from app.domain.task.primitive_name import PrimitiveName +from .config import OneOfFdAlgoConfig +from .result import FdAlgoResult, FdModel + + +class BaseFdTaskModel(BaseModel): + primitive_name: Literal[PrimitiveName.fd] + + +class FdTaskConfig(BaseFdTaskModel): + config: OneOfFdAlgoConfig + + +class FdTaskResult(BaseFdTaskModel): + result: FdAlgoResult + + +class FdTask(Task[FdTaskConfig, FdTaskResult]): + def collect_result(self, algo: FdAlgorithm) -> FdTaskResult: + fds = algo.get_fds() + algo_result = FdAlgoResult(fds=list(map(FdModel.from_fd, fds))) + return FdTaskResult(primitive_name=PrimitiveName.fd, result=algo_result) + + def match_algo_by_name(self, algo_name: FdAlgoName) -> FdAlgorithm: + match algo_name: + case FdAlgoName.Aid: + return Aid() + case FdAlgoName.DFD: + return DFD() + case FdAlgoName.Depminer: + return Depminer() + case FdAlgoName.FDep: + return FDep() + case FdAlgoName.FUN: + return FUN() + case FdAlgoName.FastFDs: + return FastFDs() + case FdAlgoName.FdMine: + return FdMine() + case FdAlgoName.HyFD: + return HyFD() + case FdAlgoName.Pyro: + return Pyro() + case FdAlgoName.Tane: + return Tane() + assert_never(algo_name) diff --git a/app/domain/task/fd/algo_name.py b/app/domain/task/fd/algo_name.py new file mode 100644 index 00000000..93514588 --- /dev/null +++ b/app/domain/task/fd/algo_name.py @@ -0,0 +1,14 @@ +from enum import StrEnum, auto + + +class FdAlgoName(StrEnum): + Aid = auto() + DFD = auto() + Depminer = auto() + FDep = auto() + FUN = auto() + FastFDs = auto() + FdMine = auto() + HyFD = auto() + Pyro = auto() + Tane = auto() diff --git a/app/domain/task/fd/config.py b/app/domain/task/fd/config.py index e6bf0899..f1a70234 100644 --- a/app/domain/task/fd/config.py +++ b/app/domain/task/fd/config.py @@ -1,45 +1,67 @@ -from pydantic import Field -from typing import Annotated +from pydantic import BaseModel, Field +from typing import Annotated, Literal, Union -from app.domain.common.optional_model import OptionalModel +from app.domain.task.fd.algo_name import FdAlgoName -class AidConfig(OptionalModel): +# Should be OptionalModel with required field `algo_name` +class BaseFdConfig(BaseModel): ... + + +class AidConfig(BaseFdConfig): + algo_name: Literal[FdAlgoName.Aid] + is_null_equal_null: bool -class DFDConfig(OptionalModel): +class DFDConfig(BaseFdConfig): + algo_name: Literal[FdAlgoName.DFD] + is_null_equal_null: bool threads: Annotated[int, Field(ge=1, le=8)] -class DepminerConfig(OptionalModel): +class DepminerConfig(BaseFdConfig): + algo_name: Literal[FdAlgoName.Depminer] + is_null_equal_null: bool -class FDepConfig(OptionalModel): +class FDepConfig(BaseFdConfig): + algo_name: Literal[FdAlgoName.FDep] + is_null_equal_null: bool -class FUNConfig(OptionalModel): +class FUNConfig(BaseFdConfig): + algo_name: Literal[FdAlgoName.FUN] + is_null_equal_null: bool -class FastFDsConfig(OptionalModel): +class FastFDsConfig(BaseFdConfig): + algo_name: Literal[FdAlgoName.FastFDs] + is_null_equal_null: bool max_lhs: Annotated[int, Field(ge=1, le=10)] threads: Annotated[int, Field(ge=1, le=8)] -class FdMineConfig(OptionalModel): +class FdMineConfig(BaseFdConfig): + algo_name: Literal[FdAlgoName.FdMine] + is_null_equal_null: bool -class HyFDConfig(OptionalModel): +class HyFDConfig(BaseFdConfig): + algo_name: Literal[FdAlgoName.HyFD] + is_null_equal_null: bool -class PyroConfig(OptionalModel): +class PyroConfig(BaseFdConfig): + algo_name: Literal[FdAlgoName.Pyro] + is_null_equal_null: bool error: Annotated[float, Field(ge=0, le=1)] max_lhs: Annotated[int, Field(ge=1, le=10)] @@ -47,7 +69,26 @@ class PyroConfig(OptionalModel): seed: int -class TaneConfig(OptionalModel): +class TaneConfig(BaseFdConfig): + algo_name: Literal[FdAlgoName.Tane] + is_null_equal_null: bool error: Annotated[float, Field(ge=0, le=1)] max_lhs: Annotated[int, Field(ge=1, le=10)] + + +OneOfFdAlgoConfig = Annotated[ + Union[ + AidConfig, + DFDConfig, + DepminerConfig, + FDepConfig, + FUNConfig, + FastFDsConfig, + FdMineConfig, + HyFDConfig, + PyroConfig, + TaneConfig, + ], + Field(discriminator="algo_name"), +] diff --git a/app/domain/task/fd/result.py b/app/domain/task/fd/result.py index 744a8009..d38d2f0e 100644 --- a/app/domain/task/fd/result.py +++ b/app/domain/task/fd/result.py @@ -2,7 +2,7 @@ from desbordante.fd import FD -class FDModel(BaseModel): +class FdModel(BaseModel): @classmethod def from_fd(cls, fd: FD): return cls(lhs_indices=fd.lhs_indices, rhs_index=fd.rhs_index) @@ -11,5 +11,5 @@ def from_fd(cls, fd: FD): rhs_index: int -class FDAlgoResult(BaseModel): - fds: list[FDModel] +class FdAlgoResult(BaseModel): + fds: list[FdModel] diff --git a/app/domain/task/primitive_factory.py b/app/domain/task/primitive_factory.py deleted file mode 100644 index bafbb165..00000000 --- a/app/domain/task/primitive_factory.py +++ /dev/null @@ -1,46 +0,0 @@ -from enum import auto -from enum import StrEnum -from app.domain.task.task_factory import AnyTaskFactory -from app.domain.task.fd import fd_factory -from typing import Iterable - - -class PrimitiveName(StrEnum): - fd = auto() - # afd = auto() - # ar = auto() - # ac = auto() - # fd_verification = auto() - # mfd_verification = auto() - # statistics = auto() - # ucc = auto() - # ucc_verification = auto() - - -class PrimitiveFactory[F: AnyTaskFactory]: - primitives: dict[PrimitiveName, AnyTaskFactory] = {} - - @classmethod - def register(cls, name: PrimitiveName, factory: F) -> F: - cls.primitives[name] = factory - return factory - - @classmethod - def get_by_name(cls, name: PrimitiveName) -> AnyTaskFactory: - factory = cls.primitives.get(name, None) - if not factory: - raise ValueError( - f"Can't find task factory by provided primitive name: {name}. Do you forgot to register it in PrimitiveFactory?" - ) - return factory - - @classmethod - def get_all(cls) -> Iterable[AnyTaskFactory]: - return cls.primitives.values() - - @classmethod - def get_names(cls) -> Iterable[PrimitiveName]: - return cls.primitives.keys() - - -PrimitiveFactory.register(PrimitiveName.fd, fd_factory) diff --git a/app/domain/task/primitive_name.py b/app/domain/task/primitive_name.py new file mode 100644 index 00000000..1959867a --- /dev/null +++ b/app/domain/task/primitive_name.py @@ -0,0 +1,13 @@ +from enum import StrEnum, auto + + +class PrimitiveName(StrEnum): + fd = auto() + afd = auto() + # ar = auto() + # ac = auto() + # fd_verification = auto() + # mfd_verification = auto() + # statistics = auto() + # ucc = auto() + # ucc_verification = auto() diff --git a/app/domain/task/task_factory.py b/app/domain/task/task_factory.py deleted file mode 100644 index 68b01945..00000000 --- a/app/domain/task/task_factory.py +++ /dev/null @@ -1,44 +0,0 @@ -from enum import StrEnum -from app.domain.task.abstract_task import AnyTask -from typing import Iterable, Type - -type AnyAlgoName = StrEnum - - -class TaskFactory[E: AnyAlgoName, T: Type[AnyTask]]: - def __init__(self, enum_used_as_keys: Type[E], general_task_cls: T) -> None: - self.tasks: dict[E, T] = {} - self.enum_used_as_keys = enum_used_as_keys - - try: - general_task_cls.result_model_cls - except AttributeError: - raise ValueError( - "Attribute `result_model_cls` must be implemented in general_task_cls" - ) - - self.general_task_cls = general_task_cls - - def register_task(self, task_type: E): - def decorator(task_cls: Type[AnyTask]): - self.tasks[task_type] = task_cls - return task_cls - - return decorator - - def get_by_name(self, name: E) -> T: - task = self.tasks.get(name, None) - if not task: - raise ValueError( - f"Can't find task by provided algorithm name: {name}. Do you forgot to register it in TaskFactory?" - ) - return task - - def get_all(self) -> Iterable[T]: - return self.tasks.values() - - def get_names(self) -> Iterable[E]: - return self.tasks.keys() - - -type AnyTaskFactory = TaskFactory[AnyAlgoName, Type[AnyTask]] diff --git a/app/domain/worker/task/data_profiling_task.py b/app/domain/worker/task/data_profiling_task.py index a9551c07..621ef930 100644 --- a/app/domain/worker/task/data_profiling_task.py +++ b/app/domain/worker/task/data_profiling_task.py @@ -1,10 +1,10 @@ import logging +from typing import Any from app.db.session import get_session +from app.domain.task import OneOfTaskConfig +from app.domain.task import match_task_by_primitive_name from app.worker import worker -from app.domain.task.abstract_task import AnyConf, AnyRes -from app.domain.task.primitive_factory import PrimitiveName, PrimitiveFactory -from app.domain.task.task_factory import AnyAlgoName from app.domain.worker.task.resource_intensive_task import ResourceIntensiveTask from pydantic import UUID4 import pandas as pd @@ -13,20 +13,15 @@ @worker.task(base=ResourceIntensiveTask, ignore_result=True, max_retries=0) def data_profiling_task( - primitive_name: PrimitiveName, - algo_name: AnyAlgoName, file_id: UUID4, - config: AnyConf, -) -> AnyRes: - task_factory = PrimitiveFactory.get_by_name(primitive_name) - task_cls = task_factory.get_by_name(algo_name) - + config: OneOfTaskConfig, +) -> Any: df = pd.read_csv( "tests/datasets/university_fd.csv", sep=",", header=0 ) # TODO: Replace with actual file (by file_id) in future - task = task_cls(df) - result = task.execute(config) + task = match_task_by_primitive_name(config.primitive_name) + result = task.execute(df, config) # type: ignore return result diff --git a/app/domain/worker/task/resource_intensive_task.py b/app/domain/worker/task/resource_intensive_task.py index f19035f1..fe938175 100644 --- a/app/domain/worker/task/resource_intensive_task.py +++ b/app/domain/worker/task/resource_intensive_task.py @@ -16,4 +16,3 @@ def before_start(self, task_id, args, kwargs) -> None: resource.setrlimit( resource.RLIMIT_AS, (self.soft_memory_limit, self.hard_memory_limit) ) - super().before_start(task_id, args, kwargs) diff --git a/pyproject.toml b/pyproject.toml index a80eafb4..182a7520 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ watchdog = "^4.0.0" pytest-cov = "^4.1.0" ipykernel = "^6.29.3" polyfactory = "^2.15.0" +pyright = "^1.1.355" [build-system] requires = ["poetry-core"] diff --git a/tests/domain/task/test_fd.py b/tests/domain/task/test_fd.py index a98a16ce..868bc676 100644 --- a/tests/domain/task/test_fd.py +++ b/tests/domain/task/test_fd.py @@ -1,28 +1,26 @@ -from app.domain.task.fd import fd_factory +from app.domain.task.fd import FdTask, FdTaskConfig import pytest import pandas as pd import logging from polyfactory.factories.pydantic_factory import ModelFactory - -@pytest.mark.parametrize("task_cls", fd_factory.get_all()) -@pytest.mark.parametrize( - "table", [pd.read_csv("tests/datasets/university_fd.csv", sep=",", header=0)] -) -def test_with_default_params(task_cls, table): - task = task_cls(table) - result = task.execute() - logging.info(result) +# TODO: change when optional fields are suported +# @pytest.mark.parametrize( +# "table", [pd.read_csv("tests/datasets/university_fd.csv", sep=",", header=0)] +# ) +# def test_with_default_params(task_cls, table): +# task = FdTask() +# result = task.execute(table) +# logging.info(result) -@pytest.mark.parametrize("task_cls", fd_factory.get_all()) @pytest.mark.parametrize( "table", [pd.read_csv("tests/datasets/university_fd.csv", sep=",", header=0)] ) -def test_with_faked_params(task_cls, table): - task = task_cls(table) - config_factory = ModelFactory.create_factory(model=task_cls.config_model_cls) +def test_with_faked_params(table): + task = FdTask() + config_factory = ModelFactory.create_factory(FdTaskConfig) config = config_factory.build(factory_use_construct=True) logging.info(config) - result = task.execute(config) + result = task.execute(table, config) logging.info(result) diff --git a/tests/domain/task/test_primitive_factory.py b/tests/domain/task/test_primitive_factory.py deleted file mode 100644 index 9b9bb05a..00000000 --- a/tests/domain/task/test_primitive_factory.py +++ /dev/null @@ -1,11 +0,0 @@ -from app.domain.task.primitive_factory import PrimitiveFactory, PrimitiveName -import pytest - - -@pytest.mark.parametrize( - "primitive_name", [primitive_name.value for primitive_name in PrimitiveName] -) -def test_get_task_by_primitive_name(primitive_name: str): - task_factory = PrimitiveFactory.get_by_name(primitive_name) - for name in task_factory.enum_used_as_keys: - task_factory.get_by_name(name.value)