diff --git a/nequip/scripts/train.py b/nequip/scripts/train.py
index a4304e5b..8f642677 100644
--- a/nequip/scripts/train.py
+++ b/nequip/scripts/train.py
@@ -149,11 +149,11 @@ def fresh_start(config):
         config = init_n_update(config)
 
-        trainer = TrainerWandB(model=None, **dict(config))
+        trainer = TrainerWandB.from_config(model=None, config=config)
     else:
         from nequip.train.trainer import Trainer
 
-        trainer = Trainer(model=None, **dict(config))
+        trainer = Trainer.from_config(model=None, config=config)
 
     # what is this
     # to update wandb data?
diff --git a/nequip/train/trainer.py b/nequip/train/trainer.py
index d92094d0..055e7eb2 100644
--- a/nequip/train/trainer.py
+++ b/nequip/train/trainer.py
@@ -38,7 +38,6 @@
     atomic_write,
     finish_all_writes,
     atomic_write_group,
-    dtype_from_name,
 )
 from nequip.utils.versions import check_code_version
 from nequip.model import model_from_config
@@ -254,23 +253,22 @@ def __init__(
         save_ema_checkpoint_freq: int = -1,
         report_init_validation: bool = True,
         verbose="INFO",
-        **kwargs,
+        config=None,
     ):
         self._initialized = False
         self.cumulative_wall = 0
         logging.debug("* Initialize Trainer")
 
-        # store all init arguments
-        self.model = model
+        assert isinstance(config, Config)
 
-        _local_kwargs = {}
+        # set attributes for init arguments
         for key in self.init_keys:
             setattr(self, key, locals()[key])
-            _local_kwargs[key] = locals()[key]
+        self.model = model
 
         self.ema = None
 
-        output = Output.get_output(dict(**_local_kwargs, **kwargs))
+        output = Output.get_output(config)
         self.output = output
 
         self.logfile = output.open_logfile("log", propagate=True)
@@ -306,7 +304,7 @@ def __init__(
         # sort out all the other parameters
         # for samplers, optimizer and scheduler
-        self.kwargs = deepcopy(kwargs)
+        self.config = config
         self.optimizer_kwargs = deepcopy(optimizer_kwargs)
         self.lr_scheduler_kwargs = deepcopy(lr_scheduler_kwargs)
         self.early_stopping_kwargs = deepcopy(early_stopping_kwargs)
@@ -320,7 +318,7 @@ def __init__(
             builder=Loss,
             prefix="loss",
             positional_args=dict(coeffs=self.loss_coeffs),
-            all_args=self.kwargs,
+            all_args=self.config,
         )
 
         self.loss_stat = LossStat(self.loss)
@@ -342,7 +340,7 @@ def __init__(
             self._remove_from_model_input = self._remove_from_model_input.union(
                 AtomicDataDict.ALL_ENERGY_KEYS
             )
-        if kwargs.get("_override_allow_truth_label_inputs", False):
+        if config.get("_override_allow_truth_label_inputs", False):
             # needed for unit testing models
             self._remove_from_model_input = set()
@@ -363,6 +361,17 @@ def __init__(
 
         self.init()
 
+    @classmethod
+    def from_config(kls, model, config):
+        return instantiate(
+            kls,
+            positional_args=dict(model=model, config=config),
+            all_args=config,
+            # For BC, because trainer has some *_kwargs args that aren't strict sub-builders
+            # ex. `optimizer_kwargs`
+            _strict_kwargs_postfix=False,
+        )[0]
+
     def init_objects(self):
         # initialize optimizer
         self.optim, self.optimizer_kwargs = instantiate_from_cls_name(
@@ -370,7 +379,7 @@ def init_objects(self):
             class_name=self.optimizer_name,
             prefix="optimizer",
             positional_args=dict(params=self.model.parameters(), lr=self.learning_rate),
-            all_args=self.kwargs,
+            all_args=self.config,
             optional_args=self.optimizer_kwargs,
         )
@@ -396,7 +405,7 @@ def init_objects(self):
             prefix="lr_scheduler",
             positional_args=dict(optimizer=self.optim),
             optional_args=self.lr_scheduler_kwargs,
-            all_args=self.kwargs,
+            all_args=self.config,
         )
 
         # initialize early stopping conditions
@@ -404,7 +413,7 @@ def init_objects(self):
             EarlyStopping,
             prefix="early_stopping",
             optional_args=self.early_stopping_kwargs,
-            all_args=self.kwargs,
+            all_args=self.config,
             return_args_only=True,
         )
         n_args = 0
@@ -444,7 +453,7 @@ def init_keys(self):
         return [
             key
             for key in list(inspect.signature(Trainer.__init__).parameters.keys())
-            if key not in (["self", "kwargs", "model"] + Trainer.object_keys)
+            if key not in (["self", "kwargs", "model", "config"] + Trainer.object_keys)
         ]
 
     @property
@@ -452,7 +461,7 @@ def params(self):
         return self.as_dict(state_dict=False, training_progress=False, kwargs=False)
 
     def update_kwargs(self, config):
-        self.kwargs.update(
+        self.config.update(
             {key: value for key, value in config.items() if key not in self.init_keys}
         )
@@ -486,7 +495,15 @@ def as_dict(
                 dictionary[key] = getattr(self, key, None)
 
         if kwargs:
-            dictionary.update(getattr(self, "kwargs", {}))
+            dictionary.update(
+                {
+                    k: v
+                    for k, v in getattr(self, "config", {}).items()
+                    # config could have keys that already got taken for the named parameters of the trainer
+                    # those are already handled in the loop above over init_keys
+                    if k not in inspect.signature(Trainer.__init__).parameters.keys()
+                }
+            )
 
         if state_dict:
             dictionary["state_dict"] = {}
@@ -627,7 +644,7 @@ def from_dict(cls, dictionary, append: Optional[bool] = None):
 
         state_dict = dictionary.pop("state_dict", None)
 
-        trainer = cls(model=model, **dictionary)
+        trainer = cls.from_config(model=model, config=Config.from_dict(dictionary))
 
         if state_dict is not None and trainer.model is not None:
             logging.debug("Reload optimizer and scheduler states")
@@ -686,7 +703,7 @@ def load_model_from_training_session(
         # this set as default dtype... does it matter?
         model.to(
             device=torch.device(device),
-            dtype=dtype_from_name(config.default_dtype),
+            dtype=torch.get_default_dtype(),
         )
         model_state_dict = torch.load(
             traindir + "/" + model_name, map_location=device
         )
@@ -730,7 +747,7 @@ def init_metrics(self):
             builder=Metrics,
             prefix="metrics",
             positional_args=dict(components=self.metrics_components),
-            all_args=self.kwargs,
+            all_args=self.config,
         )
 
         if not (
@@ -764,6 +781,9 @@ def train(self):
 
         self.init_metrics()
 
+        # we're done initializing things, so check:
+        self.config.warn_unused()
+
         while not self.stop_cond:
             self.epoch_step()
diff --git a/nequip/utils/auto_init.py b/nequip/utils/auto_init.py
index 8a9a9917..321121a9 100644
--- a/nequip/utils/auto_init.py
+++ b/nequip/utils/auto_init.py
@@ -67,6 +67,7 @@ def instantiate(
     remove_kwargs: bool = True,
     return_args_only: bool = False,
     parent_builders: list = [],
+    _strict_kwargs_postfix: bool = True,
 ):
     """Automatic initializing class instance by matching keys in the parameter dictionary to the constructor function.
@@ -98,14 +99,15 @@ def instantiate(
     config = Config.from_class(builder, remove_kwargs=remove_kwargs)
 
     # be strict about _kwargs keys:
-    allow = config.allow_list()
-    for key in allow:
-        bname = key[:-7]
-        if key.endswith("_kwargs") and bname not in allow:
-            raise KeyError(
-                f"Instantiating {builder.__name__}: found kwargs argument `{key}`, but no parameter `{bname}` for the corresponding builder. (Did you rename `{bname}` but forget to change `{bname}_kwargs`?) Either add a parameter for `{bname}` if you are trying to allow construction of a submodule, or, if `{bname}_kwargs` is just supposed to be a dictionary, rename it without `_kwargs`."
-            )
-    del allow
+    if _strict_kwargs_postfix:
+        allow = config.allow_list()
+        for key in allow:
+            bname = key[:-7]
+            if key.endswith("_kwargs") and bname not in allow:
+                raise KeyError(
+                    f"Instantiating {builder.__name__}: found kwargs argument `{key}`, but no parameter `{bname}` for the corresponding builder. (Did you rename `{bname}` but forget to change `{bname}_kwargs`?) Either add a parameter for `{bname}` if you are trying to allow construction of a submodule, or, if `{bname}_kwargs` is just supposed to be a dictionary, rename it without `_kwargs`."
+                )
+        del allow
 
     key_mapping = {}
     if all_args is not None:
diff --git a/nequip/utils/config.py b/nequip/utils/config.py
index d13e0546..00fb1350 100644
--- a/nequip/utils/config.py
+++ b/nequip/utils/config.py
@@ -35,6 +35,7 @@
 """
 
 import inspect
+import logging
 from copy import deepcopy
 from typing import Optional
@@ -48,12 +49,14 @@ def __init__(
         config: Optional[dict] = None,
         allow_list: Optional[list] = None,
         exclude_keys: Optional[list] = None,
+        defaults: Optional[dict] = None,
     ):
         object.__setattr__(self, "_items", dict())
         object.__setattr__(self, "_item_types", dict())
         object.__setattr__(self, "_allow_list", list())
         object.__setattr__(self, "_allow_all", True)
+        object.__setattr__(self, "_accessed_keys", set())
 
         if allow_list is not None:
             self.add_allow_list(allow_list, default_values={})
@@ -62,6 +65,15 @@ def __init__(
             config = {
                 key: value for key, value in config.items() if key not in exclude_keys
             }
+
+        object.__setattr__(self, "_initial_keys", frozenset(config.keys()))
+
+        if defaults is not None:
+            tmp = config
+            config = defaults.copy()
+            if tmp is not None:
+                config.update(tmp)
+            del tmp
 
         if config is not None:
             self.update(config)
@@ -80,6 +92,7 @@ def as_dict(self):
         return dict(self)
 
     def __getitem__(self, key):
+        self._accessed_keys.add(key)
         return self._items[key]
 
     def get_type(self, key):
@@ -151,12 +164,14 @@ def items(self):
 
     __setattr__ = __setitem__
 
     def __getattr__(self, key):
+        self._accessed_keys.add(key)
         return self.__getitem__(key)
 
     def __contains__(self, key):
         return key in self._items
 
     def pop(self, *args):
+        self._accessed_keys.add(args[0])
         return self._items.pop(*args)
 
     def update_w_prefix(
@@ -211,6 +226,8 @@ def update(self, dictionary: dict, allow_val_change=None):
             keys (set): set of keys being udpated
 
         """
+        if dictionary is None:
+            dictionary = {}
 
         keys = []
@@ -227,6 +244,7 @@ def update(self, dictionary: dict, allow_val_change=None):
         return set(keys) - set([None])
 
     def get(self, *args):
+        self._accessed_keys.add(args[0])
         return self._items.get(*args)
 
     def persist(self):
@@ -266,8 +284,7 @@ def from_file(filename: str, format: Optional[str] = None, defaults: dict = {}):
 
     @staticmethod
     def from_dict(dictionary: dict, defaults: dict = {}):
-        c = Config(defaults)
-        c.update(dictionary)
+        c = Config(dictionary, defaults=defaults)
         return c
 
     @staticmethod
@@ -338,3 +355,11 @@ def from_function(function, remove_kwargs=False):
         return Config(config=default_params, allow_list=param_keys)
 
     load = from_file
+
+    def warn_unused(self):
+        # = Warn about unused keys =
+        unused_keys = self._initial_keys - self._accessed_keys
+        if len(unused_keys) > 0:
+            logging.warn(
+                f"!!! Keys {', '.join('`%s`' % k for k in unused_keys)} appeared in the config but were not used. Please check if any of them have a typo or should have been used!!!"
+            )
diff --git a/nequip/utils/output.py b/nequip/utils/output.py
index a8dbf760..2468919a 100644
--- a/nequip/utils/output.py
+++ b/nequip/utils/output.py
@@ -8,6 +8,7 @@
 from typing import Optional
 
 from .config import Config
+from .auto_init import instantiate
 
 
 class Output:
@@ -144,20 +145,11 @@ def as_dict(self):
 
     @classmethod
     def get_output(cls, kwargs: dict = {}):
-
-        d = inspect.signature(cls.__init__)
-        _kwargs = {
-            key: kwargs.get(key, None)
-            for key in list(d.parameters.keys())
-            if key not in ["self", "kwargs"]
-        }
-        return cls(**_kwargs)
+        return cls.from_config(Config.from_dict(kwargs))
 
     @classmethod
     def from_config(cls, config):
-        c = Config.from_class(cls)
-        c.update(config)
-        return cls(**dict(c))
+        return instantiate(cls, all_args=config)[0]
 
 
 def set_if_none(x, y):
diff --git a/tests/unit/trainer/test_trainer.py b/tests/unit/trainer/test_trainer.py
index c8169fda..ff76eb29 100644
--- a/tests/unit/trainer/test_trainer.py
+++ b/tests/unit/trainer/test_trainer.py
@@ -15,6 +15,7 @@
 from nequip.train.trainer import Trainer
 from nequip.utils.savenload import load_file
 from nequip.nn import GraphModuleMixin
+from nequip.utils import Config
 
 
 def dummy_builder():
@@ -22,26 +23,28 @@ def dummy_builder():
 
 
 # set up two config to test
-DEBUG = False
 NATOMS = 3
 NFRAMES = 10
-minimal_config = dict(
-    run_name="test",
-    n_train=4,
-    n_val=4,
-    exclude_keys=["sth"],
-    max_epochs=2,
-    batch_size=2,
-    learning_rate=1e-2,
-    optimizer="Adam",
-    seed=0,
-    append=False,
-    T_0=50,
-    T_mult=2,
-    loss_coeffs={"forces": 2},
-    early_stopping_patiences={"loss": 50},
-    early_stopping_lower_bounds={"LR": 1e-10},
-    model_builders=[dummy_builder],
+minimal_config = Config(
+    config=dict(
+        run_name="test",
+        n_train=4,
+        n_val=4,
+        exclude_keys=["sth"],
+        max_epochs=2,
+        batch_size=2,
+        learning_rate=1e-2,
+        optimizer="Adam",
+        seed=0,
+        append=False,
+        T_0=50,
+        T_mult=2,
+        loss_coeffs={"forces": 2},
+        early_stopping_patiences={"loss": 50},
+        early_stopping_lower_bounds={"LR": 1e-10},
+        model_builders=[dummy_builder],
+        verbose="debug",
+    )
 )
@@ -54,7 +57,7 @@ def trainer():
     model = model_from_config(minimal_config)
     with tempfile.TemporaryDirectory(prefix="output") as path:
         minimal_config["root"] = path
-        c = Trainer(model=model, **minimal_config)
+        c = Trainer.from_config(model=model, config=minimal_config)
         yield c
@@ -77,10 +80,10 @@ def test_duplicate_id_2(self, temp_data):
         minimal_config["root"] = temp_data
 
         model = DummyNet(3)
-        Trainer(model=model, **minimal_config)
+        Trainer.from_config(model=model, config=minimal_config)
 
         with pytest.raises(RuntimeError):
-            Trainer(model=model, **minimal_config)
+            Trainer.from_config(model=model, config=minimal_config)
 
 
 class TestSaveLoad:
@@ -336,16 +339,19 @@ def unscale(self, data, force_process=False):
 @pytest.fixture(scope="class")
 def scale_train(nequip_dataset):
     with tempfile.TemporaryDirectory(prefix="output") as path:
-        trainer = Trainer(
-            model=DummyScale(AtomicDataDict.FORCE_KEY, scale=1.3, shift=1),
-            n_train=4,
-            n_val=4,
-            max_epochs=0,
-            batch_size=2,
-            loss_coeffs=AtomicDataDict.FORCE_KEY,
-            root=path,
run_name="test_scale", + model = DummyScale(AtomicDataDict.FORCE_KEY, scale=1.3, shift=1) + config = Config.from_dict( + dict( + n_train=4, + n_val=4, + max_epochs=0, + batch_size=2, + loss_coeffs=AtomicDataDict.FORCE_KEY, + root=path, + run_name="test_scale", + ) ) + trainer = Trainer.from_config(model=model, config=config) trainer.set_dataset(nequip_dataset) trainer.train() trainer.scale = 1.3