Skip to content

Commit

Permalink
Merge pull request #664 from ElliottKasoar/add-preprocess-config
Browse files Browse the repository at this point in the history
Add pre-processing config file option
  • Loading branch information
ilyes319 authored Dec 5, 2024
2 parents 01e0352 + 5e1c4dd commit 690fd47
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 4 deletions.
23 changes: 19 additions & 4 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
"--config",
type=str,
is_config_file=True,
help="config file to agregate options",
help="config file to aggregate options",
)
except ImportError:
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -727,9 +727,24 @@ def build_default_arg_parser() -> argparse.ArgumentParser:


def build_preprocess_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
try:
import configargparse

parser = configargparse.ArgumentParser(
config_file_parser_class=configargparse.YAMLConfigFileParser,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add(
"--config",
type=str,
is_config_file=True,
help="config file to aggregate options",
)
except ImportError:
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument(
"--train_file",
help="Training set h5 file",
Expand Down
40 changes: 40 additions & 0 deletions tests/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import ase.io
import numpy as np
import pytest
import yaml
from ase.atoms import Atoms

pytest_mace_dir = Path(__file__).parent.parent
Expand Down Expand Up @@ -164,3 +165,42 @@ def test_preprocess_data(tmp_path, sample_configs):
np.testing.assert_allclose(original_forces, h5_forces, rtol=1e-5, atol=1e-8)

print("All checks passed successfully!")


def test_preprocess_config(tmp_path, sample_configs):
    """Run the preprocessing script with all options supplied via a YAML config.

    The sample configurations are written to an xyz file, the preprocessing
    options are dumped to ``config.yaml``, and the script referenced by the
    module-level ``preprocess_data`` is launched in a subprocess with only
    ``--config <file>`` on the command line. The test passes when the
    subprocess exits with status 0.

    Args:
        tmp_path: pytest temporary-directory fixture; holds the xyz input,
            the YAML config and the prefix used for the h5 output.
        sample_configs: fixture whose value is written out with
            ``ase.io.write`` as the training data.
    """
    ase.io.write(tmp_path / "sample.xyz", sample_configs)

    preprocess_params = {
        "train_file": str(tmp_path / "sample.xyz"),
        "r_max": 5.0,
        "config_type_weights": "{'Default':1.0}",
        "num_process": 2,
        "valid_fraction": 0.1,
        "h5_prefix": str(tmp_path / "preprocessed_"),
        "compute_statistics": None,
        "seed": 42,
        "energy_key": "REF_energy",
        "forces_key": "REF_forces",
        "stress_key": "REF_stress",
    }
    filename = tmp_path / "config.yaml"
    with open(filename, "w", encoding="utf-8") as file:
        yaml.dump(preprocess_params, file)

    run_env = os.environ.copy()
    # Put the repository root first on the child's import path. The list is
    # built locally so the test process's own sys.path is left untouched, and
    # os.pathsep keeps the separator correct on every platform.
    repo_root = str(Path(__file__).parent.parent)
    run_env["PYTHONPATH"] = os.pathsep.join([repo_root, *sys.path])
    print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"])

    # Pass the arguments as a list: joining into one string and splitting on
    # whitespace would corrupt any path that contains a space.
    cmd = [sys.executable, str(preprocess_data), "--config", str(filename)]

    # check=True already raises on a non-zero exit; the assert documents intent.
    p = subprocess.run(cmd, env=run_env, check=True)
    assert p.returncode == 0

0 comments on commit 690fd47

Please sign in to comment.