Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor data module #674

Open
wants to merge 36 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
d641a4b
WIP keyspec implementation
WillBaldwin0 Sep 2, 2024
d61f917
WIP2
WillBaldwin0 Sep 2, 2024
f8a3c51
Merge branch 'multihead-merge' into refactor-data
WillBaldwin0 Sep 2, 2024
cefb4f7
Merge branch 'develop' into refactor-data
WillBaldwin0 Sep 2, 2024
a7cead1
WIP3
WillBaldwin0 Sep 2, 2024
79bd1df
Merge branch 'develop' into refactor-data
WillBaldwin0 Sep 2, 2024
818bcc8
fixed some tests
WillBaldwin0 Sep 2, 2024
bfae06a
Merge branch 'develop' into refactor-data
WillBaldwin0 Sep 2, 2024
b52d20b
new interface passing old tests
WillBaldwin0 Sep 2, 2024
a742297
linting and fixed preprocess data
WillBaldwin0 Sep 2, 2024
ddfb31d
more linting
WillBaldwin0 Sep 2, 2024
4253216
Update unittest.yaml
WillBaldwin0 Sep 3, 2024
8803c48
fix key overwriting and unittests
WillBaldwin0 Sep 3, 2024
6e31995
small bug in settings REF_forces
WillBaldwin0 Sep 3, 2024
dfefca0
remove head key and some minor fixes
WillBaldwin0 Sep 4, 2024
2e4d524
default to Default for heads
WillBaldwin0 Sep 5, 2024
7d5c70d
Merge branch 'develop' into refactor-data
WillBaldwin0 Sep 5, 2024
1390a0b
added new test and fix calculator
WillBaldwin0 Sep 6, 2024
339231e
fix tests seed
WillBaldwin0 Sep 6, 2024
70341b1
fix average e0s method
WillBaldwin0 Sep 9, 2024
92d39eb
added missing charges and dipoles weights
WillBaldwin0 Sep 9, 2024
e9e2779
linting
WillBaldwin0 Sep 9, 2024
c0b65e2
moved keyspec construction into run_train
WillBaldwin0 Sep 9, 2024
3b0c34f
pass copies to neighborhood
WillBaldwin0 Sep 18, 2024
b78d1ec
convience function for logging dataset stats
Oct 22, 2024
eed2a41
formatting
Oct 22, 2024
b6bf6c0
fix type hint
RokasEl Oct 24, 2024
5b44fce
minor fixes from review
Oct 29, 2024
503e410
Merge branch 'develop' into refactor-data
WillBaldwin0 Oct 29, 2024
c00410c
fixes for new tests and linting
Oct 29, 2024
908acd1
head key in preprocessor
Oct 29, 2024
808f794
formatting
Oct 29, 2024
8d9f6cb
Merge branch 'develop' into refactor-data
WillBaldwin0 Oct 29, 2024
7a19ed6
new calculator syntax in test_run_train
WillBaldwin0 Oct 29, 2024
b8ef3ab
Merge branch 'develop' into refactor-data
ilyes319 Nov 6, 2024
8579c26
Merge pull request #663 from WillBaldwin0/refactor-data
ilyes319 Nov 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,18 @@ def __init__(
[int(z) for z in self.models[0].atomic_numbers]
)
self.charges_key = charges_key

try:
self.heads = self.models[0].heads
self.available_heads = self.models[0].heads
except AttributeError:
self.heads = ["Default"]
self.available_heads = ["Default"]
self.head = kwargs.get("head", "Default")
assert (
self.head in self.available_heads
), f"specified head {self.head}, but model available model heads are {self.available_heads}"

print("using head", self.head, "out of", self.available_heads)

model_dtype = get_model_dtype(self.models[0])
if default_dtype == "":
print(
Expand Down Expand Up @@ -235,11 +243,19 @@ def _create_result_tensors(
return dict_of_tensors

def _atoms_to_batch(self, atoms):
config = data.config_from_atoms(atoms, charges_key=self.charges_key)
keyspec = data.KeySpecification(
info_keys={}, arrays_keys={"charges": self.charges_key}
)
config = data.config_from_atoms(
atoms, key_specification=keyspec, head_name=self.head
)
data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(
config, z_table=self.z_table, cutoff=self.r_max, heads=self.heads
config,
z_table=self.z_table,
cutoff=self.r_max,
heads=self.available_heads,
)
],
batch_size=1,
Expand Down
13 changes: 7 additions & 6 deletions mace/cli/preprocess_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import tqdm

from mace import data, tools
from mace.data import KeySpecification, update_keyspec_from_kwargs
from mace.data.utils import save_configurations_as_HDF5
from mace.modules import compute_statistics
from mace.tools import torch_geometric
Expand Down Expand Up @@ -144,6 +145,10 @@ def run(args: argparse.Namespace):
new hdf5 file that is ready for training with on-the-fly dataloading
"""

# currently support only command line property_key syntax
args.key_specification = KeySpecification()
update_keyspec_from_kwargs(args.key_specification, vars(args))

# Setup
tools.set_seeds(args.seed)
random.seed(args.seed)
Expand Down Expand Up @@ -177,12 +182,8 @@ def run(args: argparse.Namespace):
config_type_weights=config_type_weights,
test_path=args.test_file,
seed=args.seed,
energy_key=args.energy_key,
forces_key=args.forces_key,
stress_key=args.stress_key,
virials_key=args.virials_key,
dipole_key=args.dipole_key,
charges_key=args.charges_key,
key_specification=args.key_specification,
head_name=None,
)

# Atomic number table
Expand Down
50 changes: 30 additions & 20 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import mace
from mace import data, tools
from mace.calculators.foundations_models import mace_mp, mace_off
from mace.data import KeySpecification, update_keyspec_from_kwargs
from mace.tools import torch_geometric
from mace.tools.model_script_utils import configure_model
from mace.tools.multihead_tools import (
Expand Down Expand Up @@ -70,6 +71,10 @@ def run(args: argparse.Namespace) -> None:
tag = tools.get_tag(name=args.name, seed=args.seed)
args, input_log_messages = tools.check_args(args)

# default keyspec to update using heads dictionary
args.key_specification = KeySpecification()
update_keyspec_from_kwargs(args.key_specification, vars(args))

if args.device == "xpu":
try:
import intel_extension_for_pytorch as ipex
Expand Down Expand Up @@ -153,12 +158,26 @@ def run(args: argparse.Namespace) -> None:

if args.heads is not None:
args.heads = ast.literal_eval(args.heads)
for _, head_dict in args.heads.items():
# priority is global args < head property_key values < head info_keys+arrays_keys
head_keyspec = deepcopy(args.key_specification)
update_keyspec_from_kwargs(head_keyspec, head_dict)
head_keyspec.update(
info_keys=head_dict.get("info_keys", {}),
arrays_keys=head_dict.get("arrays_keys", {}),
)
head_dict["key_specification"] = head_keyspec
else:
args.heads = prepare_default_head(args)

logging.info("===========LOADING INPUT DATA===========")
heads = list(args.heads.keys())
logging.info(f"Using heads: {heads}")
logging.info("Using the key specifications to parse data:")
for name, head_dict in args.heads.items():
head_keyspec = head_dict["key_specification"]
logging.info(f"{name}: {head_keyspec}")

head_configs: List[HeadConfig] = []
for head, head_args in args.heads.items():
logging.info(f"============= Processing head {head} ===========")
Expand Down Expand Up @@ -187,7 +206,6 @@ def run(args: argparse.Namespace) -> None:
head_config.atomic_energies_dict = ast.literal_eval(
statistics["atomic_energies"]
)

# Data preparation
if check_path_ase_read(head_config.train_file):
if head_config.valid_file is not None:
Expand All @@ -205,12 +223,7 @@ def run(args: argparse.Namespace) -> None:
config_type_weights=config_type_weights,
test_path=head_config.test_file,
seed=args.seed,
energy_key=head_config.energy_key,
forces_key=head_config.forces_key,
stress_key=head_config.stress_key,
virials_key=head_config.virials_key,
dipole_key=head_config.dipole_key,
charges_key=head_config.charges_key,
key_specification=head_config.key_specification,
head_name=head_config.head_name,
keep_isolated_atoms=head_config.keep_isolated_atoms,
)
Expand Down Expand Up @@ -251,14 +264,21 @@ def run(args: argparse.Namespace) -> None:
"Using foundation model for multiheads finetuning with Materials Project data"
)
heads = list(dict.fromkeys(["pt_head"] + heads))
mp_keyspec = KeySpecification()
update_keyspec_from_kwargs(mp_keyspec, vars(args))
mp_keyspec.update(
info_keys={"energy": "energy", "stress": "stress"},
arrays_keys={"forces": "forces"},
)
head_config_pt = HeadConfig(
head_name="pt_head",
E0s="foundation",
statistics_file=args.statistics_file,
key_specification=mp_keyspec,
compute_avg_num_neighbors=False,
avg_num_neighbors=model_foundation.interactions[0].avg_num_neighbors,
)
collections = assemble_mp_data(args, tag, head_configs)
collections = assemble_mp_data(args, tag, head_configs, head_config_pt)
head_config_pt.collections = collections
head_config_pt.train_file = f"mp_finetuning-{tag}.xyz"
head_configs.append(head_config_pt)
Expand All @@ -275,12 +295,7 @@ def run(args: argparse.Namespace) -> None:
config_type_weights=None,
test_path=None,
seed=args.seed,
energy_key=args.energy_key,
forces_key=args.forces_key,
stress_key=args.stress_key,
virials_key=args.virials_key,
dipole_key=args.dipole_key,
charges_key=args.charges_key,
key_specification=args.key_specification,
head_name="pt_head",
keep_isolated_atoms=args.keep_isolated_atoms,
)
Expand All @@ -292,12 +307,7 @@ def run(args: argparse.Namespace) -> None:
statistics_file=args.statistics_file,
valid_fraction=args.valid_fraction,
config_type_weights=None,
energy_key=args.energy_key,
forces_key=args.forces_key,
stress_key=args.stress_key,
virials_key=args.virials_key,
dipole_key=args.dipole_key,
charges_key=args.charges_key,
key_specification=args.key_specification,
keep_isolated_atoms=args.keep_isolated_atoms,
collections=collections,
avg_num_neighbors=model_foundation.interactions[0].avg_num_neighbors,
Expand Down
4 changes: 4 additions & 0 deletions mace/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .utils import (
Configuration,
Configurations,
KeySpecification,
compute_average_E0s,
config_from_atoms,
config_from_atoms_list,
Expand All @@ -13,6 +14,7 @@
save_configurations_as_HDF5,
save_dataset_as_HDF5,
test_config_types,
update_keyspec_from_kwargs,
)

__all__ = [
Expand All @@ -31,4 +33,6 @@
"dataset_from_sharded_hdf5",
"save_AtomicData_to_HDF5",
"save_configurations_as_HDF5",
"KeySpecification",
"update_keyspec_from_kwargs",
]
Loading
Loading