Skip to content

Commit

Permalink
feat: replace factories with tagged unions
Browse files Browse the repository at this point in the history
  • Loading branch information
toadharvard committed Mar 25, 2024
1 parent d238f24 commit 371ab9b
Show file tree
Hide file tree
Showing 19 changed files with 319 additions and 324 deletions.
61 changes: 12 additions & 49 deletions app/api/task.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,21 @@
from uuid import UUID
from fastapi import APIRouter, HTTPException
from pydantic import UUID4
from typing import Type
from app.domain.task.primitive_factory import PrimitiveName, PrimitiveFactory
from app.domain.task.task_factory import AnyAlgoName
from app.domain.task.abstract_task import AnyTask, AnyRes
from app.domain.worker.task.data_profiling_task import data_profiling_task
from app.domain.task import OneOfTaskConfig

router = APIRouter(prefix="/task")


def generate_set_task_endpoint(
primitive_name: PrimitiveName,
algo_name: AnyAlgoName,
task_cls: Type[AnyTask],
):
primitive_router = APIRouter(prefix=f"/{primitive_name}", tags=[primitive_name])
# Enqueue a data-profiling job for the given file id and return the job's id.
# `config` is the tagged union of all task configs, so the request payload
# itself selects which primitive is run.
# (Written as comments on purpose: a docstring on an endpoint function is
# picked up as the operation description in the generated API docs.)
@router.post("")
def set_task(file_id: UUID4, config: OneOfTaskConfig) -> UUID4:
    queued = data_profiling_task.delay(file_id, config)
    # `queued.id` is fed to UUID() as its hex-string argument so the response
    # matches the declared UUID4 return type.
    return UUID(queued.id, version=4)

@primitive_router.post(
f"/{algo_name}",
name=f"Set {algo_name} task",
tags=["set task"],
)
def _(
file_id: UUID4,
config: task_cls.config_model_cls,
) -> UUID4:
async_result = data_profiling_task.delay(
primitive_name, algo_name, file_id, config
)
return async_result.id

router.include_router(primitive_router)


def generate_get_task_result_endpoint(
primitive_name: PrimitiveName, result_cls: Type[AnyRes]
):
primitive_router = APIRouter(prefix=f"/{primitive_name}", tags=[primitive_name])

@primitive_router.get("", name=f"Get {primitive_name} result", tags=["get result"])
def _(task_id: UUID4) -> result_cls:
raise HTTPException(418, "Not implemented yet")

router.include_router(primitive_router)


for primitive_name in PrimitiveFactory.get_names():
task_factory = PrimitiveFactory.get_by_name(primitive_name)
for algo_name in task_factory.get_names():
task_cls = task_factory.get_by_name(algo_name)
generate_set_task_endpoint(primitive_name, algo_name, task_cls)

for primitive_name in PrimitiveFactory.get_names():
task_factory = PrimitiveFactory.get_by_name(primitive_name)
generate_get_task_result_endpoint(
primitive_name, task_factory.general_task_cls.result_model_cls
)
# Placeholder endpoint for fetching a task by id: it always answers with
# HTTP 418 until the retrieval logic is written.
@router.get("/{task_id}")
def retrieve_task(task_id: UUID4) -> None:
    raise HTTPException(status_code=418, detail="Not implemented yet")
32 changes: 32 additions & 0 deletions app/domain/task/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from app.domain.task.afd import AfdTask, AfdTaskConfig, AfdTaskResult
from app.domain.task.fd import FdTaskConfig, FdTaskResult
from typing import Annotated, Union, assert_never
from pydantic import Field
from app.domain.task.fd import FdTask
from app.domain.task.primitive_name import PrimitiveName


# Tagged union of every task configuration model. The `primitive_name` field
# of the incoming data is the discriminator that selects the member to
# validate against.
OneOfTaskConfig = Annotated[
    Union[FdTaskConfig, AfdTaskConfig],
    Field(discriminator="primitive_name"),
]

# Tagged union of every task result model, discriminated by the
# `primitive_name` field.
OneOfTaskResult = Annotated[
    Union[FdTaskResult, AfdTaskResult],
    Field(discriminator="primitive_name"),
]


def match_task_by_primitive_name(primitive_name: PrimitiveName):
match primitive_name:
case PrimitiveName.fd:
return FdTask()
case PrimitiveName.afd:
return AfdTask()
assert_never(primitive_name)
67 changes: 39 additions & 28 deletions app/domain/task/abstract_task.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,51 @@
from typing import Type
from abc import ABC, abstractmethod
from pydantic import BaseModel
from desbordante import Algorithm
from enum import StrEnum
from typing import Any, Protocol
import desbordante
import pandas
from pydantic import BaseModel


type AnyAlgo = Algorithm
type AnyConf = BaseModel
type AnyRes = BaseModel
class AlgoConfig(Protocol):
    """Structural type for the configuration of a single algorithm."""

    @property
    def algo_name(self) -> StrEnum:
        """Name of the algorithm this configuration is meant for."""

    # Reusing pydantic's own `model_dump` as the required member forces
    # implementers to be pydantic classes.
    model_dump = BaseModel.model_dump

class AbstractTask[Algo: AnyAlgo, Conf: AnyConf, Res: AnyRes](ABC):
algo: Algo
config_model_cls: Type[Conf]
result_model_cls: Type[Res]

def __init__(self, table: pandas.DataFrame) -> None:
try:
self.algo
self.config_model_cls
self.result_model_cls
except AttributeError:
raise NotImplementedError(
"Attributes `algo`, `config_model_cls` and `result_model_cls` must be implemented in non-abstract class"
)
class TaskConfig(Protocol):
@property
def primitive_name(self) -> StrEnum: ...

self.table = table
self.algo.load_data(table=table)
@property
def config(self) -> AlgoConfig: ...

def execute(self, config: Conf | None = None) -> Res:
options = config.model_dump(exclude_unset=True) if config else {}
self.algo.execute(**options)
return self.collect_result()
# forces to use pydantic classes there
model_dump = BaseModel.model_dump

@abstractmethod
def collect_result(self) -> Res: ...

class TaskResult(Protocol):
    """Structural type for the result produced by a task."""

    @property
    def primitive_name(self) -> StrEnum:
        """Name of the primitive that produced this result."""

    # Payload of the result; its concrete type is left open here.
    result: Any

    # Reusing pydantic's own `model_dump` as the required member forces
    # implementers to be pydantic classes.
    model_dump = BaseModel.model_dump

type AnyTask = AbstractTask[AnyAlgo, AnyConf, AnyRes]

class Task[C: TaskConfig, R: TaskResult](ABC):
@abstractmethod
def match_algo_by_name(self, algo_name) -> desbordante.Algorithm: ...

@abstractmethod
def collect_result(self, algo) -> R: ...

def execute(self, table: pandas.DataFrame, task_config: C) -> R:
algo_config = task_config.config
options = algo_config.model_dump(exclude_unset=True, exclude={"algo_name"})
algo = self.match_algo_by_name(algo_config.algo_name)
algo.load_data(table=table)
algo.execute(**options)
return self.collect_result(algo)
36 changes: 36 additions & 0 deletions app/domain/task/afd/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Literal, assert_never
from pydantic import BaseModel
from desbordante.fd import FdAlgorithm # This is not a typo
from app.domain.task.abstract_task import Task
from app.domain.task.primitive_name import PrimitiveName
from .config import OneOfAfdConfig
from .result import AfdAlgoResult, FdModel
from .algo_name import AfdAlgoName
from desbordante.afd.algorithms import Pyro, Tane


# Shared parent of the AFD task config and result models.
# `primitive_name` is pinned to the single literal `PrimitiveName.afd`, which
# makes it usable as the discriminator tag of a tagged union.
# (Comments rather than a docstring: a model docstring would end up as the
# description in the generated JSON schema.)
class BaseAfdTaskModel(BaseModel):
    primitive_name: Literal[PrimitiveName.afd]


# Configuration of an AFD task: the inherited `primitive_name` tag plus the
# configuration of one specific AFD algorithm.
class AfdTaskConfig(BaseAfdTaskModel):
    config: OneOfAfdConfig


# Result of an AFD task: the inherited `primitive_name` tag plus the
# algorithm's output.
class AfdTaskResult(BaseAfdTaskModel):
    result: AfdAlgoResult


class AfdTask(Task[AfdTaskConfig, AfdTaskResult]):
def collect_result(self, algo: FdAlgorithm) -> AfdTaskResult:
fds = algo.get_fds()
algo_result = AfdAlgoResult(fds=list(map(FdModel.from_fd, fds)))
return AfdTaskResult(primitive_name=PrimitiveName.afd, result=algo_result)

def match_algo_by_name(self, algo_name: AfdAlgoName) -> FdAlgorithm:
match algo_name:
case AfdAlgoName.Pyro:
return Pyro()
case AfdAlgoName.Tane:
return Tane()
assert_never(algo_name)
6 changes: 6 additions & 0 deletions app/domain/task/afd/algo_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from enum import StrEnum, auto


class AfdAlgoName(StrEnum):
Pyro = auto()
Tane = auto()
34 changes: 34 additions & 0 deletions app/domain/task/afd/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Annotated, Literal, Union
from pydantic import BaseModel, Field

from app.domain.task.afd.algo_name import AfdAlgoName


# Common parent of every AFD algorithm configuration; it declares no fields
# of its own.
class BaseAfdConfig(BaseModel):
    pass


# Options accepted for the Pyro algorithm. Every field is required.
class PyroConfig(BaseAfdConfig):
    # Discriminator tag: pinned to the single literal AfdAlgoName.Pyro.
    algo_name: Literal[AfdAlgoName.Pyro]

    is_null_equal_null: bool
    # Must lie in the closed interval [0, 1].
    error: Annotated[float, Field(ge=0, le=1)]
    # Must lie in the closed interval [1, 10].
    max_lhs: Annotated[int, Field(ge=1, le=10)]
    # Must lie in the closed interval [1, 8].
    threads: Annotated[int, Field(ge=1, le=8)]
    seed: int


# Options accepted for the Tane algorithm. Every field is required.
class TaneConfig(BaseAfdConfig):
    # Discriminator tag: pinned to the single literal AfdAlgoName.Tane.
    algo_name: Literal[AfdAlgoName.Tane]

    is_null_equal_null: bool
    # Must lie in the closed interval [0, 1].
    error: Annotated[float, Field(ge=0, le=1)]
    # Must lie in the closed interval [1, 10].
    max_lhs: Annotated[int, Field(ge=1, le=10)]


# Tagged union of the AFD algorithm configurations. The `algo_name` field of
# the incoming data is the discriminator that selects PyroConfig or TaneConfig.
OneOfAfdConfig = Annotated[
    Union[PyroConfig, TaneConfig],
    Field(discriminator="algo_name"),
]
4 changes: 4 additions & 0 deletions app/domain/task/afd/result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from app.domain.task.fd.result import FdAlgoResult, FdModel

# AFD results reuse the FD result model: AfdAlgoResult is an alias of it.
AfdAlgoResult = FdAlgoResult
# Self-assignment of the imported name; presumably marks FdModel as an
# intentional re-export of this module — TODO confirm.
FdModel = FdModel
Loading

0 comments on commit 371ab9b

Please sign in to comment.