diff --git a/app/domain/common/optional_model.py b/app/domain/common/optional_model.py index 355b8f75..39965fe3 100644 --- a/app/domain/common/optional_model.py +++ b/app/domain/common/optional_model.py @@ -4,11 +4,15 @@ class OptionalModel(BaseModel): + __non_optional_fields__ = set() + @classmethod def __pydantic_init_subclass__(cls, **kwargs: Any) -> None: super().__pydantic_init_subclass__(**kwargs) for field in cls.model_fields.values(): + if field in cls.__non_optional_fields__: + continue field.default = None cls.model_rebuild(force=True) diff --git a/app/domain/task/afd/config.py b/app/domain/task/afd/config.py index 3c8f2581..5ced5662 100644 --- a/app/domain/task/afd/config.py +++ b/app/domain/task/afd/config.py @@ -1,10 +1,14 @@ from typing import Annotated, Literal, Union -from pydantic import BaseModel, Field +from pydantic import Field +from app.domain.common.optional_model import OptionalModel from app.domain.task.afd.algo_name import AfdAlgoName -class BaseAfdConfig(BaseModel): ... +class BaseAfdConfig(OptionalModel): + __non_optional_fields__ = { + "algo_name", + } class PyroConfig(BaseAfdConfig): diff --git a/app/domain/task/fd/config.py b/app/domain/task/fd/config.py index f1a70234..ab3f305d 100644 --- a/app/domain/task/fd/config.py +++ b/app/domain/task/fd/config.py @@ -1,11 +1,14 @@ -from pydantic import BaseModel, Field +from pydantic import Field from typing import Annotated, Literal, Union +from app.domain.common.optional_model import OptionalModel from app.domain.task.fd.algo_name import FdAlgoName -# Should be OptionalModel with required field `algo_name` -class BaseFdConfig(BaseModel): ... +class BaseFdConfig(OptionalModel): + __non_optional_fields__ = { + "algo_name", + } class AidConfig(BaseFdConfig): diff --git a/tests/domain/task/test_fd.py b/tests/domain/task/test_fd.py index 868bc676..9e6e5c7e 100644 --- a/tests/domain/task/test_fd.py +++ b/tests/domain/task/test_fd.py @@ -4,14 +4,24 @@ import logging from polyfactory.factories.pydantic_factory import ModelFactory -# TODO: change when optional fields are suported -# @pytest.mark.parametrize( -# "table", [pd.read_csv("tests/datasets/university_fd.csv", sep=",", header=0)] -# ) -# def test_with_default_params(task_cls, table): -# task = FdTask() -# result = task.execute(table) -# logging.info(result) +from app.domain.task.fd.algo_name import FdAlgoName +from app.domain.task.primitive_name import PrimitiveName + + +@pytest.mark.parametrize("algo_name", [algo_name.value for algo_name in FdAlgoName]) +@pytest.mark.parametrize( + "table", + [pd.read_csv("tests/datasets/university_fd.csv", sep=",", header=0)], +) +def test_with_default_params(algo_name, table): + task = FdTask() + config = FdTaskConfig( + primitive_name=PrimitiveName.fd, + config={"algo_name": algo_name}, # type: ignore + ) + logging.info(config) + result = task.execute(table, config) + logging.info(result) @pytest.mark.parametrize(