-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: replace factories with tagged unions
- Loading branch information
1 parent
d238f24
commit 371ab9b
Showing
19 changed files
with
319 additions
and
324 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,58 +1,21 @@ | ||
from uuid import UUID | ||
from fastapi import APIRouter, HTTPException | ||
from pydantic import UUID4 | ||
from typing import Type | ||
from app.domain.task.primitive_factory import PrimitiveName, PrimitiveFactory | ||
from app.domain.task.task_factory import AnyAlgoName | ||
from app.domain.task.abstract_task import AnyTask, AnyRes | ||
from app.domain.worker.task.data_profiling_task import data_profiling_task | ||
from app.domain.task import OneOfTaskConfig | ||
|
||
router = APIRouter(prefix="/task") | ||
|
||
|
||
def generate_set_task_endpoint( | ||
primitive_name: PrimitiveName, | ||
algo_name: AnyAlgoName, | ||
task_cls: Type[AnyTask], | ||
): | ||
primitive_router = APIRouter(prefix=f"/{primitive_name}", tags=[primitive_name]) | ||
@router.post("") | ||
def set_task( | ||
file_id: UUID4, | ||
config: OneOfTaskConfig, | ||
) -> UUID4: | ||
async_result = data_profiling_task.delay(file_id, config) | ||
return UUID(async_result.id, version=4) | ||
|
||
@primitive_router.post( | ||
f"/{algo_name}", | ||
name=f"Set {algo_name} task", | ||
tags=["set task"], | ||
) | ||
def _( | ||
file_id: UUID4, | ||
config: task_cls.config_model_cls, | ||
) -> UUID4: | ||
async_result = data_profiling_task.delay( | ||
primitive_name, algo_name, file_id, config | ||
) | ||
return async_result.id | ||
|
||
router.include_router(primitive_router) | ||
|
||
|
||
def generate_get_task_result_endpoint( | ||
primitive_name: PrimitiveName, result_cls: Type[AnyRes] | ||
): | ||
primitive_router = APIRouter(prefix=f"/{primitive_name}", tags=[primitive_name]) | ||
|
||
@primitive_router.get("", name=f"Get {primitive_name} result", tags=["get result"]) | ||
def _(task_id: UUID4) -> result_cls: | ||
raise HTTPException(418, "Not implemented yet") | ||
|
||
router.include_router(primitive_router) | ||
|
||
|
||
for primitive_name in PrimitiveFactory.get_names(): | ||
task_factory = PrimitiveFactory.get_by_name(primitive_name) | ||
for algo_name in task_factory.get_names(): | ||
task_cls = task_factory.get_by_name(algo_name) | ||
generate_set_task_endpoint(primitive_name, algo_name, task_cls) | ||
|
||
for primitive_name in PrimitiveFactory.get_names(): | ||
task_factory = PrimitiveFactory.get_by_name(primitive_name) | ||
generate_get_task_result_endpoint( | ||
primitive_name, task_factory.general_task_cls.result_model_cls | ||
) | ||
@router.get("/{task_id}") | ||
def retrieve_task(task_id: UUID4) -> None: | ||
raise HTTPException(418, "Not implemented yet") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from app.domain.task.afd import AfdTask, AfdTaskConfig, AfdTaskResult | ||
from app.domain.task.fd import FdTaskConfig, FdTaskResult | ||
from typing import Annotated, Union, assert_never | ||
from pydantic import Field | ||
from app.domain.task.fd import FdTask | ||
from app.domain.task.primitive_name import PrimitiveName | ||
|
||
|
||
OneOfTaskConfig = Annotated[ | ||
Union[ | ||
FdTaskConfig, | ||
AfdTaskConfig, | ||
], | ||
Field(discriminator="primitive_name"), | ||
] | ||
|
||
OneOfTaskResult = Annotated[ | ||
Union[ | ||
FdTaskResult, | ||
AfdTaskResult, | ||
], | ||
Field(discriminator="primitive_name"), | ||
] | ||
|
||
|
||
def match_task_by_primitive_name(primitive_name: PrimitiveName): | ||
match primitive_name: | ||
case PrimitiveName.fd: | ||
return FdTask() | ||
case PrimitiveName.afd: | ||
return AfdTask() | ||
assert_never(primitive_name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,40 +1,51 @@ | ||
from typing import Type | ||
from abc import ABC, abstractmethod | ||
from pydantic import BaseModel | ||
from desbordante import Algorithm | ||
from enum import StrEnum | ||
from typing import Any, Protocol | ||
import desbordante | ||
import pandas | ||
from pydantic import BaseModel | ||
|
||
|
||
type AnyAlgo = Algorithm | ||
type AnyConf = BaseModel | ||
type AnyRes = BaseModel | ||
class AlgoConfig(Protocol): | ||
@property | ||
def algo_name(self) -> StrEnum: ... | ||
|
||
# forces to use pydantic classes there | ||
model_dump = BaseModel.model_dump | ||
|
||
class AbstractTask[Algo: AnyAlgo, Conf: AnyConf, Res: AnyRes](ABC): | ||
algo: Algo | ||
config_model_cls: Type[Conf] | ||
result_model_cls: Type[Res] | ||
|
||
def __init__(self, table: pandas.DataFrame) -> None: | ||
try: | ||
self.algo | ||
self.config_model_cls | ||
self.result_model_cls | ||
except AttributeError: | ||
raise NotImplementedError( | ||
"Attributes `algo`, `config_model_cls` and `result_model_cls` must be implemented in non-abstract class" | ||
) | ||
class TaskConfig(Protocol): | ||
@property | ||
def primitive_name(self) -> StrEnum: ... | ||
|
||
self.table = table | ||
self.algo.load_data(table=table) | ||
@property | ||
def config(self) -> AlgoConfig: ... | ||
|
||
def execute(self, config: Conf | None = None) -> Res: | ||
options = config.model_dump(exclude_unset=True) if config else {} | ||
self.algo.execute(**options) | ||
return self.collect_result() | ||
# forces to use pydantic classes there | ||
model_dump = BaseModel.model_dump | ||
|
||
@abstractmethod | ||
def collect_result(self) -> Res: ... | ||
|
||
class TaskResult(Protocol): | ||
@property | ||
def primitive_name(self) -> StrEnum: ... | ||
|
||
result: Any | ||
|
||
# forces to use pydantic classes there | ||
model_dump = BaseModel.model_dump | ||
|
||
type AnyTask = AbstractTask[AnyAlgo, AnyConf, AnyRes] | ||
|
||
class Task[C: TaskConfig, R: TaskResult](ABC): | ||
@abstractmethod | ||
def match_algo_by_name(self, algo_name) -> desbordante.Algorithm: ... | ||
|
||
@abstractmethod | ||
def collect_result(self, algo) -> R: ... | ||
|
||
def execute(self, table: pandas.DataFrame, task_config: C) -> R: | ||
algo_config = task_config.config | ||
options = algo_config.model_dump(exclude_unset=True, exclude={"algo_name"}) | ||
algo = self.match_algo_by_name(algo_config.algo_name) | ||
algo.load_data(table=table) | ||
algo.execute(**options) | ||
return self.collect_result(algo) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from typing import Literal, assert_never | ||
from pydantic import BaseModel | ||
from desbordante.fd import FdAlgorithm # This is not a typo | ||
from app.domain.task.abstract_task import Task | ||
from app.domain.task.primitive_name import PrimitiveName | ||
from .config import OneOfAfdConfig | ||
from .result import AfdAlgoResult, FdModel | ||
from .algo_name import AfdAlgoName | ||
from desbordante.afd.algorithms import Pyro, Tane | ||
|
||
|
||
class BaseAfdTaskModel(BaseModel): | ||
primitive_name: Literal[PrimitiveName.afd] | ||
|
||
|
||
class AfdTaskConfig(BaseAfdTaskModel): | ||
config: OneOfAfdConfig | ||
|
||
|
||
class AfdTaskResult(BaseAfdTaskModel): | ||
result: AfdAlgoResult | ||
|
||
|
||
class AfdTask(Task[AfdTaskConfig, AfdTaskResult]): | ||
def collect_result(self, algo: FdAlgorithm) -> AfdTaskResult: | ||
fds = algo.get_fds() | ||
algo_result = AfdAlgoResult(fds=list(map(FdModel.from_fd, fds))) | ||
return AfdTaskResult(primitive_name=PrimitiveName.afd, result=algo_result) | ||
|
||
def match_algo_by_name(self, algo_name: AfdAlgoName) -> FdAlgorithm: | ||
match algo_name: | ||
case AfdAlgoName.Pyro: | ||
return Pyro() | ||
case AfdAlgoName.Tane: | ||
return Tane() | ||
assert_never(algo_name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from enum import StrEnum, auto | ||
|
||
|
||
class AfdAlgoName(StrEnum): | ||
Pyro = auto() | ||
Tane = auto() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from typing import Annotated, Literal, Union | ||
from pydantic import BaseModel, Field | ||
|
||
from app.domain.task.afd.algo_name import AfdAlgoName | ||
|
||
|
||
class BaseAfdConfig(BaseModel): ... | ||
|
||
|
||
class PyroConfig(BaseAfdConfig): | ||
algo_name: Literal[AfdAlgoName.Pyro] | ||
|
||
is_null_equal_null: bool | ||
error: Annotated[float, Field(ge=0, le=1)] | ||
max_lhs: Annotated[int, Field(ge=1, le=10)] | ||
threads: Annotated[int, Field(ge=1, le=8)] | ||
seed: int | ||
|
||
|
||
class TaneConfig(BaseAfdConfig): | ||
algo_name: Literal[AfdAlgoName.Tane] | ||
|
||
is_null_equal_null: bool | ||
error: Annotated[float, Field(ge=0, le=1)] | ||
max_lhs: Annotated[int, Field(ge=1, le=10)] | ||
|
||
|
||
OneOfAfdConfig = Annotated[ | ||
Union[ | ||
PyroConfig, | ||
TaneConfig, | ||
], | ||
Field(discriminator="algo_name"), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from app.domain.task.fd.result import FdAlgoResult, FdModel | ||
|
||
AfdAlgoResult = FdAlgoResult | ||
FdModel = FdModel |
Oops, something went wrong.