Evaluate test sets separately for different heads #681

Merged: 1 commit merged into ACEsuit:develop on Dec 5, 2024

Conversation


@ThomasWarford commented Nov 10, 2024

See issue.

I am not sure what the desired behaviour is here. Do we want the user to be able to specify 'head_key'='head_name' in the input xyz files? This fix would override that.

Using the main branch, adding head=MP2 to a configuration's info dict does indeed add a row 'Default_MP2' to the output table.

2024-11-25 11:54:00.502 INFO: Error-table on TEST:
+----------------+---------------------+------------------+-------------------+
|  config_type   | RMSE E / meV / atom | RMSE F / meV / A | relative F RMSE % |
+----------------+---------------------+------------------+-------------------+
| NaCl_8_Default |          425.1      |         59.0     |          7.43     |
|   NaCl_8_MP2   |          337.9      |         91.3     |         10.09     |
+----------------+---------------------+------------------+-------------------+
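For context, the head label lives in a frame's info dict; in an extended-XYZ file it would appear as a key=value pair on the comment line. An illustrative two-atom frame (not taken from this PR; the energy value is made up):

```
2
Properties=species:S:1:pos:R:3 energy=-3.5 head=MP2
Na 0.00000 0.00000 0.00000
Cl 2.80000 0.00000 0.00000
```

When read with ase.io.read, the head=MP2 pair on the comment line ends up in atoms.info["head"].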

@ThomasWarford ThomasWarford changed the title add head_key=head_name to info dict of atoms regardless of isolated_a… Evaluate test sets separately for different heads Nov 10, 2024
@ThomasWarford ThomasWarford reopened this Nov 25, 2024
@ThomasWarford ThomasWarford marked this pull request as draft November 25, 2024 11:05
@ThomasWarford (Author)

In the main branch, get_dataset_from_xyz is called for each head:

    test_configs = []
    if test_path is not None:
        _, all_test_configs = data.load_from_xyz(
            file_path=test_path,
            config_type_weights=config_type_weights,
            energy_key=energy_key,
            forces_key=forces_key,
            dipole_key=dipole_key,
            stress_key=stress_key,
            virials_key=virials_key,
            charges_key=charges_key,
            head_key=head_key,
            extract_atomic_energies=False,
            head_name=head_name,
        )
        # create list of tuples (config_type, list(Atoms))
        test_configs = data.test_config_types(all_test_configs)
        logging.info(
            f"Test set ({len(all_test_configs)} configs) loaded from '{test_path}':"
        )
        for name, tmp_configs in test_configs:
            logging.info(
                f"{name}: {len(tmp_configs)} configs, {np.sum([1 if config.energy else 0 for config in tmp_configs])} energy, {np.sum([config.forces.size for config in tmp_configs])} forces"
            )

    return (
        SubsetCollection(train=train_configs, valid=valid_configs, tests=test_configs),
        atomic_energies_dict,
    )

The test_configs are returned by data.test_config_types as a list of (name, configs) tuples.

The name of each group is conf.config_type + "_" + conf.head, and the corresponding value is the list of configurations in that group.

The PR currently sets conf.head for each configuration. I'm going to look into how the train and valid tables deal with this.
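The grouping described above can be sketched as follows (a minimal illustration, not the actual MACE code; Config is a hypothetical stand-in for the real configuration objects, which carry config_type and head attributes):

```python
from collections import defaultdict

# Hypothetical stand-in for a loaded configuration.
class Config:
    def __init__(self, config_type, head):
        self.config_type = config_type
        self.head = head

def group_test_configs(configs):
    # Key each configuration by config_type + "_" + head, as
    # data.test_config_types is described to do above.
    groups = defaultdict(list)
    for conf in configs:
        groups[conf.config_type + "_" + conf.head].append(conf)
    return list(groups.items())

configs = [
    Config("NaCl_8", "Default"),
    Config("NaCl_8", "MP2"),
    Config("NaCl_8", "MP2"),
]
grouped = group_test_configs(configs)
print([name for name, _ in grouped])  # ['NaCl_8_Default', 'NaCl_8_MP2']
```

This reproduces the group names seen in the error table above, one per (config_type, head) pair.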

@ThomasWarford (Author) commented Nov 25, 2024

In run_train.run:

    train_valid_data_loader = {}
    for head_config in head_configs:
        data_loader_name = "train_" + head_config.head_name
        train_valid_data_loader[data_loader_name] = head_config.train_loader
    for head, valid_loader in valid_loaders.items():
        data_load_name = "valid_" + head
        train_valid_data_loader[data_load_name] = valid_loader

    test_sets = {}
    stop_first_test = False
    test_data_loader = {}
    if all(
        head_config.test_file == head_configs[0].test_file
        for head_config in head_configs
    ) and head_configs[0].test_file is not None:
        stop_first_test = True
    if all(
        head_config.test_dir == head_configs[0].test_dir
        for head_config in head_configs
    ) and head_configs[0].test_dir is not None:
        stop_first_test = True
    for head_config in head_configs:
        if check_path_ase_read(head_config.train_file):
            for name, subset in head_config.collections.tests:
                test_sets[name] = [
                    data.AtomicData.from_config(
                        config, z_table=z_table, cutoff=args.r_max, heads=heads
                    )
                    for config in subset
                ]
        ...      
        ...
        for test_name, test_set in test_sets.items():
            test_sampler = None
            if args.distributed:
                test_sampler = torch.utils.data.distributed.DistributedSampler(
                    test_set,
                    num_replicas=world_size,
                    rank=rank,
                    shuffle=True,
                    drop_last=True,
                    seed=args.seed,
                )
            try:
                drop_last = test_set.drop_last
            except AttributeError as e:  # pylint: disable=W0612
                drop_last = False
            test_loader = torch_geometric.dataloader.DataLoader(
                test_set,
                batch_size=args.valid_batch_size,
                shuffle=(test_sampler is None),
                drop_last=drop_last,
                num_workers=args.num_workers,
                pin_memory=args.pin_memory,
            )
            test_data_loader[test_name] = test_loader

Key difference:

For the train and valid sets, the dataloader keys are "train_" + head and "valid_" + head for each head.

For the test dataloader, the keys are the same as the keys of head_config.collections.tests, which are set by data.test_config_types and default to "Default_Default".
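The two naming schemes can be contrasted with a small sketch (illustrative only; loader_key and test_key are hypothetical helpers mirroring the behaviour described above, not functions from the codebase):

```python
def loader_key(split, head):
    # Train/valid loaders are keyed by split prefix plus head name.
    return f"{split}_{head}"

def test_key(config_type=None, head=None):
    # Test sets are keyed by config_type + "_" + head, each falling
    # back to "Default" when unset, hence the "Default_Default" key.
    return f"{config_type or 'Default'}_{head or 'Default'}"

print(loader_key("train", "MP2"))  # 'train_MP2'
print(test_key())                  # 'Default_Default'
print(test_key("NaCl_8", "MP2"))   # 'NaCl_8_MP2'
```

So test-set keys carry no head prefix and are not guaranteed to align with the "train_"/"valid_" naming used elsewhere.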

@ThomasWarford ThomasWarford marked this pull request as ready for review November 28, 2024 16:42
@ilyes319 ilyes319 merged commit 01e0352 into ACEsuit:develop Dec 5, 2024
4 checks passed