Evaluate test sets separately for different heads #681
Conversation
In the PR, the train, valid, and test data loaders are currently set up as follows:
In train_valid_data_loader = {}
for head_config in head_configs:
data_loader_name = "train_" + head_config.head_name
train_valid_data_loader[data_loader_name] = head_config.train_loader
for head, valid_loader in valid_loaders.items():
data_load_name = "valid_" + head
train_valid_data_loader[data_load_name] = valid_loader
test_sets = {}
stop_first_test = False
test_data_loader = {}
if all(
head_config.test_file == head_configs[0].test_file
for head_config in head_configs
) and head_configs[0].test_file is not None:
stop_first_test = True
if all(
head_config.test_dir == head_configs[0].test_dir
for head_config in head_configs
) and head_configs[0].test_dir is not None:
stop_first_test = True
for head_config in head_configs:
if check_path_ase_read(head_config.train_file):
for name, subset in head_config.collections.tests:
test_sets[name] = [
data.AtomicData.from_config(
config, z_table=z_table, cutoff=args.r_max, heads=heads
)
for config in subset
]
...
...
for test_name, test_set in test_sets.items():
test_sampler = None
if args.distributed:
test_sampler = torch.utils.data.distributed.DistributedSampler(
test_set,
num_replicas=world_size,
rank=rank,
shuffle=True,
drop_last=True,
seed=args.seed,
)
try:
drop_last = test_set.drop_last
except AttributeError as e: # pylint: disable=W0612
drop_last = False
test_loader = torch_geometric.dataloader.DataLoader(
test_set,
batch_size=args.valid_batch_size,
shuffle=(test_sampler is None),
drop_last=drop_last,
num_workers=args.num_workers,
pin_memory=args.pin_memory,
)
    test_data_loader[test_name] = test_loader

Key difference: for the train and valid sets, the keys of the data loader dict are "train_" or "valid_" plus the head name, whereas for the test data loader the keys are simply the keys of test_sets.
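To make the key mismatch concrete, here is a minimal sketch of the resulting dictionary keys; the head names "Default" and "MP2" and the test-set names are hypothetical and only for illustration:

head_names = ["Default", "MP2"]  # hypothetical head names

# Train/valid loaders: keys carry a "train_"/"valid_" prefix plus the head name.
train_valid_keys = [f"train_{h}" for h in head_names] + [f"valid_{h}" for h in head_names]
print(train_valid_keys)  # ['train_Default', 'train_MP2', 'valid_Default', 'valid_MP2']

# Test loaders: keys are whatever names appear in test_sets
# (the names coming from head_config.collections.tests), with no fixed prefix.
test_keys = ["Default_test", "MP2_test"]  # made-up example names
print(test_keys)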
See issue.
I am not sure what the desired behaviour is here. Do we want the user to be able to specify 'head_key'='head_name' in the input xyz files? This fix would override that.
Using the main branch, adding head=MP2 to a configuration's info dict does indeed add a row 'Default_MP2' to the output table.
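For reference, a minimal sketch of what adding head=MP2 to a configuration's info dict can look like with ASE; the molecule and file name here are made up:

from ase.build import molecule
from ase.io import write

atoms = molecule("H2O")       # toy configuration
atoms.info["head"] = "MP2"    # the head label discussed above

# The extxyz writer keeps info entries in the comment line of the written file.
write("mp2_configs.xyz", atoms)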