diff --git a/internal/domain/task/__init__.py b/internal/domain/task/__init__.py index 2e40fd9..e1b1be4 100644 --- a/internal/domain/task/__init__.py +++ b/internal/domain/task/__init__.py @@ -2,3 +2,5 @@ from internal.domain.task.entities import AfdTask # noqa: F401 from internal.domain.task.entities import AcTask # noqa: F401 from internal.domain.task.entities import IndTask # noqa: F401 +from internal.domain.task.entities import AindTask # noqa: F401 +from internal.domain.task.entities import ArTask # noqa: F401 diff --git a/internal/domain/task/entities/__init__.py b/internal/domain/task/entities/__init__.py index 9f13224..28856b2 100644 --- a/internal/domain/task/entities/__init__.py +++ b/internal/domain/task/entities/__init__.py @@ -5,6 +5,7 @@ from internal.domain.task.entities.ac import AcTask from internal.domain.task.entities.ind import IndTask from internal.domain.task.entities.aind import AindTask +from internal.domain.task.entities.ar import ArTask from internal.domain.task.value_objects import PrimitiveName @@ -32,4 +33,6 @@ def match_task_by_primitive_name(primitive_name: PrimitiveName): return IndTask() case PrimitiveName.aind: return AindTask() + case PrimitiveName.ar: + return ArTask() assert_never(primitive_name) diff --git a/internal/domain/task/entities/ar/__init__.py b/internal/domain/task/entities/ar/__init__.py new file mode 100644 index 0000000..f9a4dec --- /dev/null +++ b/internal/domain/task/entities/ar/__init__.py @@ -0,0 +1 @@ +from internal.domain.task.entities.ar.ar_task import ArTask # noqa: F401 diff --git a/internal/domain/task/entities/ar/ar_task.py b/internal/domain/task/entities/ar/ar_task.py new file mode 100644 index 0000000..715f706 --- /dev/null +++ b/internal/domain/task/entities/ar/ar_task.py @@ -0,0 +1,58 @@ +from desbordante.ar import ArAlgorithm +from desbordante.ar.algorithms import Apriori +from internal.domain.task.entities.task import Task +from internal.domain.task.value_objects import PrimitiveName, IncorrectAlgorithmName +from internal.domain.task.value_objects.ar import ArTaskConfig, ArTaskResult +from internal.domain.task.value_objects.ar import ( + ArAlgoName, + ArModel, + ArAlgoResult, +) + + +class ArTask(Task[ArAlgorithm, ArTaskConfig, ArTaskResult]): + """ + Task class for Association Rule (AR) mining. + + This class handles the execution of different AR algorithms and processes + the results into the appropriate format. It implements the abstract methods + defined in the Task base class. + + Methods: + - _match_algo_by_name(algo_name: ArAlgoName) -> ArAlgorithm: + Match AR algorithm by its name. + - _collect_result(algo: ArAlgorithm) -> ArTaskResult: + Process the output of the AR algorithm and return the result. + """ + + def _collect_result(self, algo: ArAlgorithm) -> ArTaskResult: + """ + Collect and process the AR result. + + Args: + algo (ArAlgorithm): AR algorithm to process. + Returns: + ArTaskResult: The processed result containing association rules. + """ + ar_ids = algo.get_ar_ids() + ar_strings = algo.get_ars() + algo_result = ArAlgoResult( + ars=list(map(ArModel.from_ar, ar_strings)), + ar_ids=list(map(ArModel.from_ar_ids, ar_ids)), + ) + return ArTaskResult(primitive_name=PrimitiveName.ar, result=algo_result) + + def _match_algo_by_name(self, algo_name: str) -> ArAlgorithm: + """ + Match the association rule algorithm by name. + + Args: + algo_name (ArAlgoName): The name of the AR algorithm. + Returns: + ArAlgorithm: The corresponding algorithm instance. + """ + match algo_name: + case ArAlgoName.Apriori: + return Apriori() + case _: + raise IncorrectAlgorithmName(algo_name, "AR") diff --git a/internal/domain/task/entities/task.py b/internal/domain/task/entities/task.py index 9d8e04f..3df5433 100644 --- a/internal/domain/task/entities/task.py +++ b/internal/domain/task/entities/task.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod import desbordante import pandas -from internal.domain.task.value_objects import TaskConfig, TaskResult +from internal.domain.task.value_objects import PrimitiveName, TaskConfig, TaskResult class Task[A: desbordante.Algorithm, C: TaskConfig, R: TaskResult](ABC): @@ -60,10 +60,15 @@ def execute(self, table: pandas.DataFrame, task_config: C) -> R: algo_config = task_config.config options = algo_config.model_dump(exclude_unset=True, exclude={"algo_name"}) algo = self._match_algo_by_name(algo_config.algo_name) - # TODO: IND, AIND requires multiple tables - try: - algo.load_data(table=table) - except desbordante.ConfigurationError: - algo.load_data(tables=[table]) + + # TODO: FIX THIS PLS!!! + match task_config.primitive_name: + case PrimitiveName.ind | PrimitiveName.aind: + algo.load_data(tables=[table]) + case PrimitiveName.ar: + algo.load_data(table=table, input_format=options["input_format"]) + case _: + algo.load_data(table=table) + algo.execute(**options) return self._collect_result(algo) diff --git a/internal/domain/task/value_objects/__init__.py b/internal/domain/task/value_objects/__init__.py index 4432e32..640dfe5 100644 --- a/internal/domain/task/value_objects/__init__.py +++ b/internal/domain/task/value_objects/__init__.py @@ -6,6 +6,7 @@ from internal.domain.task.value_objects.ac import AcTaskConfig, AcTaskResult from internal.domain.task.value_objects.ind import IndTaskConfig, IndTaskResult from internal.domain.task.value_objects.aind import AindTaskConfig, AindTaskResult +from internal.domain.task.value_objects.ar import ArTaskConfig, ArTaskResult from internal.domain.task.value_objects.config import TaskConfig # noqa: F401 from internal.domain.task.value_objects.result import TaskResult # noqa: F401 @@ -24,11 +25,25 @@ ) OneOfTaskConfig = Annotated[ - Union[FdTaskConfig, AfdTaskConfig, AcTaskConfig, IndTaskConfig, AindTaskConfig], + Union[ + FdTaskConfig, + AfdTaskConfig, + AcTaskConfig, + IndTaskConfig, + AindTaskConfig, + ArTaskConfig, + ], Field(discriminator="primitive_name"), ] OneOfTaskResult = Annotated[ - Union[FdTaskResult, AfdTaskResult, AcTaskResult, IndTaskResult, AindTaskResult], + Union[ + FdTaskResult, + AfdTaskResult, + AcTaskResult, + IndTaskResult, + AindTaskResult, + ArTaskResult, + ], Field(discriminator="primitive_name"), ] diff --git a/internal/domain/task/value_objects/ar/__init__.py b/internal/domain/task/value_objects/ar/__init__.py new file mode 100644 index 0000000..3cba82f --- /dev/null +++ b/internal/domain/task/value_objects/ar/__init__.py @@ -0,0 +1,23 @@ +from typing import Literal + +from pydantic import BaseModel + +from internal.domain.task.value_objects.primitive_name import PrimitiveName +from internal.domain.task.value_objects.ar.algo_config import OneOfArAlgoConfig +from internal.domain.task.value_objects.ar.algo_result import ( # noqa: F401 + ArAlgoResult, + ArModel, +) +from internal.domain.task.value_objects.ar.algo_name import ArAlgoName # noqa: F401 + + +class BaseArTaskModel(BaseModel): + primitive_name: Literal[PrimitiveName.ar] + + +class ArTaskConfig(BaseArTaskModel): + config: OneOfArAlgoConfig + + +class ArTaskResult(BaseArTaskModel): + result: ArAlgoResult diff --git a/internal/domain/task/value_objects/ar/algo_config.py b/internal/domain/task/value_objects/ar/algo_config.py new file mode 100644 index 0000000..a381255 --- /dev/null +++ b/internal/domain/task/value_objects/ar/algo_config.py @@ -0,0 +1,36 @@ +from typing import Literal, Annotated +from pydantic import Field +from internal.domain.common import OptionalModel +from internal.domain.task.value_objects.ar.algo_name import ArAlgoName +from internal.domain.task.value_objects.ar.algo_descriptions import descriptions + + +class BaseArConfig(OptionalModel): + __non_optional_fields__ = { + "algo_name", + } + + +class AprioriConfig(BaseArConfig): + algo_name: Literal[ArAlgoName.Apriori] + + has_tid: Annotated[bool, Field(description=descriptions["has_tid"])] + minconf: Annotated[float, Field(ge=0, le=1, description=descriptions["minconf"])] + minsup: Annotated[float, Field(ge=0, le=1, description=descriptions["minsup"])] + input_format: Annotated[ + str, + Literal["singular", "tabular"], + Field(description=descriptions["input_format"]), + ] + item_column_index: Annotated[ + int, Field(ge=0, description=descriptions["item_column_index"]) + ] + tid_column_index: Annotated[ + int, Field(ge=0, description=descriptions["tid_column_index"]) + ] + + +OneOfArAlgoConfig = Annotated[ + AprioriConfig, + Field(discriminator="algo_name"), +] diff --git a/internal/domain/task/value_objects/ar/algo_descriptions.py b/internal/domain/task/value_objects/ar/algo_descriptions.py new file mode 100644 index 0000000..918245b --- /dev/null +++ b/internal/domain/task/value_objects/ar/algo_descriptions.py @@ -0,0 +1,8 @@ +descriptions = { + "has_tid": "Indicates that the first column contains the transaction IDs", + "minconf": "Minimum confidence value (between 0 and 1)", + "input_format": "Format of the input dataset for AR mining", + "item_column_index": "Index of the column where an item name is stored", + "minsup": "Minimum support value (between 0 and 1)", + "tid_column_index": "Index of the column where a TID is stored", +} diff --git a/internal/domain/task/value_objects/ar/algo_name.py b/internal/domain/task/value_objects/ar/algo_name.py new file mode 100644 index 0000000..b8b5b81 --- /dev/null +++ b/internal/domain/task/value_objects/ar/algo_name.py @@ -0,0 +1,5 @@ +from enum import StrEnum, auto + + +class ArAlgoName(StrEnum): + Apriori = auto() diff --git a/internal/domain/task/value_objects/ar/algo_result.py b/internal/domain/task/value_objects/ar/algo_result.py new file mode 100644 index 0000000..e8c2c41 --- /dev/null +++ b/internal/domain/task/value_objects/ar/algo_result.py @@ -0,0 +1,21 @@ +from pydantic import BaseModel +from desbordante.ar import ARStrings, ArIDs + + +class ArModel(BaseModel): + @classmethod + def from_ar(cls, ar: ARStrings): + return cls(confidence=ar.confidence, left=ar.left, right=ar.right) + + @classmethod + def from_ar_ids(cls, ar_id: ArIDs): + return cls(confidence=ar_id.confidence, left=ar_id.left, right=ar_id.right) + + confidence: float + left: list[str] + right: list[str] + + +class ArAlgoResult(BaseModel): + ars: list[ArModel] + ar_ids: list[ArModel] diff --git a/internal/domain/task/value_objects/primitive_name.py b/internal/domain/task/value_objects/primitive_name.py index 6e1f3c4..10163f1 100644 --- a/internal/domain/task/value_objects/primitive_name.py +++ b/internal/domain/task/value_objects/primitive_name.py @@ -4,7 +4,7 @@ class PrimitiveName(StrEnum): fd = auto() afd = auto() - # ar = auto() + ar = auto() ac = auto() ind = auto() aind = auto()