Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Warn when unused keys are present in the config #154

Draft
wants to merge 15 commits into
base: develop
Choose a base branch
from
4 changes: 2 additions & 2 deletions nequip/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,11 @@ def fresh_start(config):

config = init_n_update(config)

trainer = TrainerWandB(model=None, **dict(config))
trainer = TrainerWandB.from_config(model=None, config=config)
else:
from nequip.train.trainer import Trainer

trainer = Trainer(model=None, **dict(config))
trainer = Trainer.from_config(model=None, config=config)

# what is this
# to update wandb data?
Expand Down
58 changes: 39 additions & 19 deletions nequip/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
atomic_write,
finish_all_writes,
atomic_write_group,
dtype_from_name,
)
from nequip.utils.versions import check_code_version
from nequip.model import model_from_config
Expand Down Expand Up @@ -254,23 +253,22 @@ def __init__(
save_ema_checkpoint_freq: int = -1,
report_init_validation: bool = True,
verbose="INFO",
**kwargs,
config=None,
):
self._initialized = False
self.cumulative_wall = 0
logging.debug("* Initialize Trainer")

# store all init arguments
self.model = model
assert isinstance(config, Config)

_local_kwargs = {}
# set attributes for init arguments
for key in self.init_keys:
setattr(self, key, locals()[key])
_local_kwargs[key] = locals()[key]

self.model = model
self.ema = None

output = Output.get_output(dict(**_local_kwargs, **kwargs))
output = Output.get_output(config)
self.output = output

self.logfile = output.open_logfile("log", propagate=True)
Expand Down Expand Up @@ -306,7 +304,7 @@ def __init__(

# sort out all the other parameters
# for samplers, optimizer and scheduler
self.kwargs = deepcopy(kwargs)
self.config = config
self.optimizer_kwargs = deepcopy(optimizer_kwargs)
self.lr_scheduler_kwargs = deepcopy(lr_scheduler_kwargs)
self.early_stopping_kwargs = deepcopy(early_stopping_kwargs)
Expand All @@ -320,7 +318,7 @@ def __init__(
builder=Loss,
prefix="loss",
positional_args=dict(coeffs=self.loss_coeffs),
all_args=self.kwargs,
all_args=self.config,
)
self.loss_stat = LossStat(self.loss)

Expand All @@ -342,7 +340,7 @@ def __init__(
self._remove_from_model_input = self._remove_from_model_input.union(
AtomicDataDict.ALL_ENERGY_KEYS
)
if kwargs.get("_override_allow_truth_label_inputs", False):
if config.get("_override_allow_truth_label_inputs", False):
# needed for unit testing models
self._remove_from_model_input = set()

Expand All @@ -363,14 +361,25 @@ def __init__(

self.init()

@classmethod
def from_config(cls, model, config):
    """Alternate constructor: build a trainer from a ``Config`` object.

    Keys in ``config`` are matched to the named parameters of ``__init__``
    via ``instantiate``; ``model`` and ``config`` are passed positionally.
    """
    positional = dict(model=model, config=config)
    instance, _ = instantiate(
        cls,
        positional_args=positional,
        all_args=config,
        # For BC, because trainer has some *_kwargs args that aren't strict sub-builders
        # ex. `optimizer_kwargs`
        _strict_kwargs_postfix=False,
    )
    return instance

def init_objects(self):
# initialize optimizer
self.optim, self.optimizer_kwargs = instantiate_from_cls_name(
module=torch.optim,
class_name=self.optimizer_name,
prefix="optimizer",
positional_args=dict(params=self.model.parameters(), lr=self.learning_rate),
all_args=self.kwargs,
all_args=self.config,
optional_args=self.optimizer_kwargs,
)

Expand All @@ -396,15 +405,15 @@ def init_objects(self):
prefix="lr_scheduler",
positional_args=dict(optimizer=self.optim),
optional_args=self.lr_scheduler_kwargs,
all_args=self.kwargs,
all_args=self.config,
)

# initialize early stopping conditions
key_mapping, kwargs = instantiate(
EarlyStopping,
prefix="early_stopping",
optional_args=self.early_stopping_kwargs,
all_args=self.kwargs,
all_args=self.config,
return_args_only=True,
)
n_args = 0
Expand Down Expand Up @@ -444,15 +453,15 @@ def init_keys(self):
return [
key
for key in list(inspect.signature(Trainer.__init__).parameters.keys())
if key not in (["self", "kwargs", "model"] + Trainer.object_keys)
if key not in (["self", "kwargs", "model", "config"] + Trainer.object_keys)
]

@property
def params(self):
    """Init-argument values of this trainer, without state, progress, or extra config keys."""
    options = dict(state_dict=False, training_progress=False, kwargs=False)
    return self.as_dict(**options)

def update_kwargs(self, config):
    """Merge into ``self.config`` every entry of ``config`` that is not a named init argument.

    Keys matching ``self.init_keys`` are skipped because those are stored as
    attributes on the trainer itself, not in the catch-all config.
    """
    # The span as rendered contained both the stale pre-refactor line
    # (`self.kwargs.update(`) and the new one; only the new target is kept.
    self.config.update(
        {key: value for key, value in config.items() if key not in self.init_keys}
    )

Expand Down Expand Up @@ -486,7 +495,15 @@ def as_dict(
dictionary[key] = getattr(self, key, None)

if kwargs:
dictionary.update(getattr(self, "kwargs", {}))
dictionary.update(
{
k: v
for k, v in getattr(self, "config", {}).items()
# config could have keys that already got taken for the named parameters of the trainer
# those are already handled in the loop above over init_keys
if k not in inspect.signature(Trainer.__init__).parameters.keys()
}
)

if state_dict:
dictionary["state_dict"] = {}
Expand Down Expand Up @@ -627,7 +644,7 @@ def from_dict(cls, dictionary, append: Optional[bool] = None):

state_dict = dictionary.pop("state_dict", None)

trainer = cls(model=model, **dictionary)
trainer = cls.from_config(model=model, config=Config.from_dict(dictionary))

if state_dict is not None and trainer.model is not None:
logging.debug("Reload optimizer and scheduler states")
Expand Down Expand Up @@ -686,7 +703,7 @@ def load_model_from_training_session(
# this set as default dtype... does it matter?
model.to(
device=torch.device(device),
dtype=dtype_from_name(config.default_dtype),
dtype=torch.get_default_dtype(),
)
model_state_dict = torch.load(
traindir + "/" + model_name, map_location=device
Expand Down Expand Up @@ -730,7 +747,7 @@ def init_metrics(self):
builder=Metrics,
prefix="metrics",
positional_args=dict(components=self.metrics_components),
all_args=self.kwargs,
all_args=self.config,
)

if not (
Expand Down Expand Up @@ -764,6 +781,9 @@ def train(self):

self.init_metrics()

# we're done initializing things, so check:
self.config.warn_unused()

while not self.stop_cond:

self.epoch_step()
Expand Down
18 changes: 10 additions & 8 deletions nequip/utils/auto_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def instantiate(
remove_kwargs: bool = True,
return_args_only: bool = False,
parent_builders: list = [],
_strict_kwargs_postfix: bool = True,
):
"""Automatic initializing class instance by matching keys in the parameter dictionary to the constructor function.

Expand Down Expand Up @@ -98,14 +99,15 @@ def instantiate(
config = Config.from_class(builder, remove_kwargs=remove_kwargs)

# be strict about _kwargs keys:
allow = config.allow_list()
for key in allow:
bname = key[:-7]
if key.endswith("_kwargs") and bname not in allow:
raise KeyError(
f"Instantiating {builder.__name__}: found kwargs argument `{key}`, but no parameter `{bname}` for the corresponding builder. (Did you rename `{bname}` but forget to change `{bname}_kwargs`?) Either add a parameter for `{bname}` if you are trying to allow construction of a submodule, or, if `{bname}_kwargs` is just supposed to be a dictionary, rename it without `_kwargs`."
)
del allow
if _strict_kwargs_postfix:
allow = config.allow_list()
for key in allow:
bname = key[:-7]
if key.endswith("_kwargs") and bname not in allow:
raise KeyError(
f"Instantiating {builder.__name__}: found kwargs argument `{key}`, but no parameter `{bname}` for the corresponding builder. (Did you rename `{bname}` but forget to change `{bname}_kwargs`?) Either add a parameter for `{bname}` if you are trying to allow construction of a submodule, or, if `{bname}_kwargs` is just supposed to be a dictionary, rename it without `_kwargs`."
)
del allow

key_mapping = {}
if all_args is not None:
Expand Down
29 changes: 27 additions & 2 deletions nequip/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

"""
import inspect
import logging

from copy import deepcopy
from typing import Optional
Expand All @@ -48,12 +49,14 @@ def __init__(
config: Optional[dict] = None,
allow_list: Optional[list] = None,
exclude_keys: Optional[list] = None,
defaults: Optional[dict] = None,
):

object.__setattr__(self, "_items", dict())
object.__setattr__(self, "_item_types", dict())
object.__setattr__(self, "_allow_list", list())
object.__setattr__(self, "_allow_all", True)
object.__setattr__(self, "_accessed_keys", set())

if allow_list is not None:
self.add_allow_list(allow_list, default_values={})
Expand All @@ -62,6 +65,15 @@ def __init__(
config = {
key: value for key, value in config.items() if key not in exclude_keys
}

object.__setattr__(self, "_initial_keys", frozenset(config.keys()))

if defaults is not None:
tmp = config
config = defaults.copy()
if tmp is not None:
config.update(tmp)
del tmp
if config is not None:
self.update(config)

Expand All @@ -80,6 +92,7 @@ def as_dict(self):
return dict(self)

def __getitem__(self, key):
    """Return the stored value for ``key``, recording the access for unused-key tracking."""
    self._accessed_keys.add(key)
    value = self._items[key]
    return value

def get_type(self, key):
Expand Down Expand Up @@ -151,12 +164,14 @@ def items(self):
__setattr__ = __setitem__

def __getattr__(self, key):
    """Attribute-style access delegates to item access; the key is marked as accessed."""
    self._accessed_keys.add(key)
    return self[key]

def __contains__(self, key):
return key in self._items

def pop(self, *args):
    """Remove and return an item (with optional default); the popped key counts as accessed."""
    key = args[0]
    self._accessed_keys.add(key)
    return self._items.pop(*args)

def update_w_prefix(
Expand Down Expand Up @@ -211,6 +226,8 @@ def update(self, dictionary: dict, allow_val_change=None):
keys (set): set of keys being updated

"""
if dictionary is None:
dictionary = {}

keys = []

Expand All @@ -227,6 +244,7 @@ def update(self, dictionary: dict, allow_val_change=None):
return set(keys) - set([None])

def get(self, *args):
    """dict-style ``get`` (with optional default); the looked-up key counts as accessed."""
    key = args[0]
    self._accessed_keys.add(key)
    return self._items.get(*args)

def persist(self):
Expand Down Expand Up @@ -266,8 +284,7 @@ def from_file(filename: str, format: Optional[str] = None, defaults: dict = {}):

@staticmethod
def from_dict(dictionary: dict, defaults: Optional[dict] = None):
    """Build a ``Config`` from ``dictionary``, layered over ``defaults``.

    Entries in ``dictionary`` take precedence over ``defaults``.
    """
    # Fixed: mutable default argument `defaults: dict = {}` replaced with a
    # None sentinel (Config.__init__ already treats None as "no defaults"),
    # and the stale pre-refactor two-step construction removed.
    c = Config(dictionary, defaults=defaults)
    return c

@staticmethod
Expand Down Expand Up @@ -338,3 +355,11 @@ def from_function(function, remove_kwargs=False):
return Config(config=default_params, allow_list=param_keys)

load = from_file

def warn_unused(self):
    """Warn about keys that were present at construction but never read.

    Compares the keys the config was created with against every key that was
    accessed via ``__getitem__``/``__getattr__``/``get``/``pop``; anything left
    over is likely a typo or dead configuration.
    """
    unused_keys = self._initial_keys - self._accessed_keys
    if len(unused_keys) > 0:
        # logging.warn is a deprecated alias; logging.warning is the supported API
        logging.warning(
            f"!!! Keys {', '.join('`%s`' % k for k in unused_keys)} appeared in the config but were not used. Please check if any of them have a typo or should have been used!!!"
        )
14 changes: 3 additions & 11 deletions nequip/utils/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Optional

from .config import Config
from .auto_init import instantiate


class Output:
Expand Down Expand Up @@ -144,20 +145,11 @@ def as_dict(self):

@classmethod
def get_output(cls, kwargs: Optional[dict] = None):
    """Construct an ``Output`` from a plain dict of settings.

    The dict is wrapped in a ``Config`` and dispatched through
    ``from_config`` so parameter matching happens in one place.
    """
    # Fixed: mutable default argument `kwargs: dict = {}` replaced with a None
    # sentinel; `kwargs or {}` preserves the old empty-dict behavior (passing
    # None through would crash Config.__init__). Stale pre-refactor
    # inspect.signature plumbing removed.
    return cls.from_config(Config.from_dict(kwargs or {}))

@classmethod
def from_config(cls, config):
    """Instantiate this ``Output`` class from a ``Config``.

    ``instantiate`` matches keys in ``config`` to the named parameters of
    ``cls.__init__`` and returns ``(instance, used_args)``; only the instance
    is needed here. The stale pre-refactor manual Config merge is removed.
    """
    return instantiate(cls, all_args=config)[0]


def set_if_none(x, y):
Expand Down
Loading