Skip to content

Commit

Permalink
feat: make config fields optional
Browse files Browse the repository at this point in the history
  • Loading branch information
toadharvard committed Mar 25, 2024
1 parent 371ab9b commit 00518cd
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 13 deletions.
4 changes: 4 additions & 0 deletions app/domain/common/optional_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@


class OptionalModel(BaseModel):
__non_optional_fields__ = set()

@classmethod
def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
super().__pydantic_init_subclass__(**kwargs)

for field in cls.model_fields.values():
if field in cls.__non_optional_fields__:
continue
field.default = None

cls.model_rebuild(force=True)
8 changes: 6 additions & 2 deletions app/domain/task/afd/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import Annotated, Literal, Union
from pydantic import BaseModel, Field
from pydantic import Field

from app.domain.common.optional_model import OptionalModel
from app.domain.task.afd.algo_name import AfdAlgoName


class BaseAfdConfig(BaseModel): ...
class BaseAfdConfig(OptionalModel):
__non_optional_fields__ = {
"algo_name",
}


class PyroConfig(BaseAfdConfig):
Expand Down
9 changes: 6 additions & 3 deletions app/domain/task/fd/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from pydantic import BaseModel, Field
from pydantic import Field
from typing import Annotated, Literal, Union

from app.domain.common.optional_model import OptionalModel
from app.domain.task.fd.algo_name import FdAlgoName


# Should be OptionalModel with required field `algo_name`
class BaseFdConfig(BaseModel): ...
class BaseFdConfig(OptionalModel):
__non_optional_fields__ = {
"algo_name",
}


class AidConfig(BaseFdConfig):
Expand Down
26 changes: 18 additions & 8 deletions tests/domain/task/test_fd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,24 @@
import logging
from polyfactory.factories.pydantic_factory import ModelFactory

# TODO: change when optional fields are suported
# @pytest.mark.parametrize(
# "table", [pd.read_csv("tests/datasets/university_fd.csv", sep=",", header=0)]
# )
# def test_with_default_params(task_cls, table):
# task = FdTask()
# result = task.execute(table)
# logging.info(result)
from app.domain.task.fd.algo_name import FdAlgoName
from app.domain.task.primitive_name import PrimitiveName


@pytest.mark.parametrize("algo_name", [algo_name.value for algo_name in FdAlgoName])
@pytest.mark.parametrize(
"table",
[pd.read_csv("tests/datasets/university_fd.csv", sep=",", header=0)],
)
def test_with_default_params(algo_name, table):
task = FdTask()
config = FdTaskConfig(
primitive_name=PrimitiveName.fd,
config={"algo_name": algo_name}, # type: ignore
)
logging.info(config)
result = task.execute(table, config)
logging.info(result)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 00518cd

Please sign in to comment.