From 6c58d3a6c64771a2741069ecc3bdbdb0c6433c4c Mon Sep 17 00:00:00 2001 From: andnp Date: Thu, 12 Sep 2024 20:20:30 -0600 Subject: [PATCH 1/3] feat: use default values for new configs --- ml_experiment/DefinitionPart.py | 44 ++++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/ml_experiment/DefinitionPart.py b/ml_experiment/DefinitionPart.py index ca101b8..a4cc299 100644 --- a/ml_experiment/DefinitionPart.py +++ b/ml_experiment/DefinitionPart.py @@ -15,13 +15,30 @@ def __init__(self, name: str, base: str | None = None): self.base_path = base or os.getcwd() self._properties: Dict[str, Set[ValueType]] = defaultdict(set) - - def add_property(self, key: str, value: ValueType): + self._prior_values: Dict[str, ValueType] = {} + + def add_property( + self, + key: str, + value: ValueType, + assume_prior_value: ValueType | None = None, + ): self._properties[key].add(value) - def add_sweepable_property(self, key: str, values: Iterable[ValueType]): + if assume_prior_value is not None: + self._prior_values[key] = assume_prior_value + + def add_sweepable_property( + self, + key: str, + values: Iterable[ValueType], + assume_prior_value: ValueType | None = None, + ): self._properties[key] |= set(values) + if assume_prior_value is not None: + self._prior_values[key] = assume_prior_value + def get_results_path(self) -> str: import __main__ experiment_name = __main__.__file__.split('/')[-2] @@ -41,7 +58,8 @@ def commit(self): # grabbing from prior tables where possible, or generating a unique id for new configs next_config_id = table_registry.get_max_configuration_id(cur, self.name) + 1 for configuration in configurations: - existing_id = table_registry.get_configuration_id(cur, self.name, configuration) + config_query = self._get_configuration_without_priors(configuration) + existing_id = table_registry.get_configuration_id(cur, self.name, config_query) if existing_id is not None: configuration['id'] = existing_id @@ -73,6 +91,24 @@ def commit(self): con.commit() con.close() + def _get_configuration_without_priors(self, configuration: Dict[str, ValueType]): + """ + When a new property is introduced that has an assumed prior value, + then we need to search for configuration ids without the new + property and associated those with the new config. + + This function gives back the configuration to search for to + obtain an id. + """ + out = {} + for k, v in configuration.items(): + if k in self._prior_values and v == self._prior_values[k]: + continue + + out[k] = v + + return out + def generate_configurations(properties: Dict[str, Set[ValueType]]): for configuration in product(*properties.values()): From 170e8ec53ecae8e8568557ac2bfb2d5c51e1a243 Mon Sep 17 00:00:00 2001 From: andnp Date: Thu, 12 Sep 2024 20:52:02 -0600 Subject: [PATCH 2/3] refactor: add maybe monad implementation Simplifies the logic in several functions by allowing chaining Nullable codepaths. Helps keep logic relatively flat instead of nested ifs. ``` if x is None: if y is None: # do a else: # do b else: # do c ``` becomes (something like) ``` outcome = ( Maybe(x) .or(y) ) ``` --- ml_experiment/_utils/maybe.py | 54 +++++++++++++++++++++++++++++++++++ pyproject.toml | 2 +- 2 files changed, 55 insertions(+), 1 deletion(-) create mode 100644 ml_experiment/_utils/maybe.py diff --git a/ml_experiment/_utils/maybe.py b/ml_experiment/_utils/maybe.py new file mode 100644 index 0000000..9550531 --- /dev/null +++ b/ml_experiment/_utils/maybe.py @@ -0,0 +1,54 @@ +from __future__ import annotations +from typing import Callable, Generic, TypeVar + + +T = TypeVar('T') +U = TypeVar('U') + +class Maybe(Generic[T]): + def __init__(self, v: T | None): + self._v: T | None = v + + + def map(self, f: Callable[[T], U | None]) -> Maybe[U]: + if self._v is None: + return Maybe[U](None) + + u = f(self._v) + return Maybe(u) + + + def flat_map(self, f: Callable[[T], Maybe[U]]) -> Maybe[U]: + if self._v is None: + return Maybe[U](None) + + return f(self._v) + + + def flat_otherwise(self, f: Callable[[], Maybe[T]]) -> Maybe[T]: + if self._v is None: + return f() + + return self + + + def or_else(self, t: T) -> T: + if self._v is None: + return t + + return self._v + + + def expect(self, msg: str = '') -> T: + if self._v is None: + raise Exception(msg) + + return self._v + + + def is_none(self): + return self._v is None + + + def is_some(self): + return self._v is not None diff --git a/pyproject.toml b/pyproject.toml index 076530b..0f2748d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ version_files = ["pyproject.toml"] [tool.ruff.lint] select = ['F', 'E', 'W', 'B'] -ignore = ['E501', 'E701'] +ignore = ['E501', 'E701', 'B023'] [tool.pyright] include = ['ml_experiment'] From 571d1dece8d040c174e1d7b069e4ecfa53fae946 Mon Sep 17 00:00:00 2001 From: andnp Date: Thu, 12 Sep 2024 20:56:10 -0600 Subject: [PATCH 3/3] fix: search both partial and complete config for priors --- ml_experiment/DefinitionPart.py | 12 ++++++------ ml_experiment/metadata/MetadataTableRegistry.py | 10 +++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/ml_experiment/DefinitionPart.py b/ml_experiment/DefinitionPart.py index a4cc299..cf310fd 100644 --- a/ml_experiment/DefinitionPart.py +++ b/ml_experiment/DefinitionPart.py @@ -59,16 +59,16 @@ def commit(self): next_config_id = table_registry.get_max_configuration_id(cur, self.name) + 1 for configuration in configurations: config_query = self._get_configuration_without_priors(configuration) - existing_id = table_registry.get_configuration_id(cur, self.name, config_query) - if existing_id is not None: - configuration['id'] = existing_id + configuration['id'] = ( + table_registry.get_configuration_id(cur, self.name, configuration) + .flat_otherwise(lambda: table_registry.get_configuration_id(cur, self.name, config_query)) + .or_else(next_config_id) + ) - else: - configuration['id'] = next_config_id + if configuration['id'] == next_config_id: next_config_id += 1 - # determine whether we should build a new table # and what version to call that table latest_table = table_registry.get_latest_version(cur, self.name) diff --git a/ml_experiment/metadata/MetadataTableRegistry.py b/ml_experiment/metadata/MetadataTableRegistry.py index 08128de..2f073da 100644 --- a/ml_experiment/metadata/MetadataTableRegistry.py +++ b/ml_experiment/metadata/MetadataTableRegistry.py @@ -2,9 +2,9 @@ import ml_experiment._utils.sqlite as sqlu from typing import Dict, Iterable +from ml_experiment._utils.maybe import Maybe from ml_experiment.metadata.MetadataTable import MetadataTable, ValueType - class MetadataTableRegistry: def __init__(self): # cached results @@ -73,11 +73,11 @@ def get_max_configuration_id(self, cur: sqlite3.Cursor, part_name: str) -> int: return max(all_ids) - def get_configuration_id(self, cur: sqlite3.Cursor, part_name: str, configuration: Dict[str, ValueType]) -> int | None: + def get_configuration_id(self, cur: sqlite3.Cursor, part_name: str, configuration: Dict[str, ValueType]) -> Maybe[int]: latest = self.get_latest_version(cur, part_name) if latest is None: - return None + return Maybe(None) # walk backwards starting from the latest version # if any table contains an id, then stop @@ -89,9 +89,9 @@ def get_configuration_id(self, cur: sqlite3.Cursor, part_name: str, configuratio conf_id = table.get_configuration_id(cur, configuration) if conf_id is not None: - return conf_id + return Maybe(conf_id) - return None + return Maybe(None) def create_new_table(self, cur: sqlite3.Cursor, part_name: str, version: int, config_params: Iterable[str]) -> MetadataTable: