Skip to content

Commit

Permalink
swap multihead test to float64
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Sep 2, 2024
1 parent 4e0cc27 commit 3fb8b60
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions tests/test_run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def test_run_train_multihead(tmp_path, fitting_configs):
mace_params["loss"] = "weighted"
mace_params["hidden_irreps"] = "128x0e"
mace_params["r_max"] = 6.0
mace_params["default_dtype"] = "float32"
mace_params["default_dtype"] = "float64"
mace_params["num_radial_basis"] = 10
mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock"
mace_params["config"] = tmp_path / "config.yaml"
Expand Down Expand Up @@ -349,7 +349,7 @@ def test_run_train_multihead(tmp_path, fitting_configs):
assert p.returncode == 0

calc = MACECalculator(
tmp_path / "MACE.model", device="cpu", default_dtype="float32"
tmp_path / "MACE.model", device="cpu", default_dtype="float64"
)

Es = []
Expand All @@ -358,30 +358,30 @@ def test_run_train_multihead(tmp_path, fitting_configs):
Es.append(at.get_potential_energy())

print("Es", Es)
# from a run on 22/08/2024 on commit
# from a run on 02/09/2024 on develop branch
ref_Es = [
0.0,
0.0,
0.1492728888988495,
0.12760481238365173,
0.18094804883003235,
0.2017526775598526,
0.09473809599876404,
0.20055484771728516,
0.1673969328403473,
0.1053609699010849,
0.29178786277770996,
0.06670654565095901,
0.09736010432243347,
0.23458734154701233,
0.09877493232488632,
-0.022957436740398407,
0.2738725543022156,
0.13694337010383606,
0.12737643718719482,
-0.07650933414697647,
-0.012938144616782665,
0.061228662729263306,
0.10637113905361611,
-0.012499594026624754,
0.08983077108171753,
0.21071322543112597,
-0.028921849222784398,
-0.02423359575741567,
0.022923252188079057,
-0.02048334610058991,
0.4349711162741364,
-0.04455577015569887,
-0.09765806785570091,
0.16013134616829822,
0.0758442928017698,
-0.05931856557011721,
0.33964473532953265,
0.134338442158641,
0.18024119757783053,
-0.18914740992058765,
-0.06503477155294624,
0.03436649147415213,
]
assert np.allclose(Es, ref_Es)

Expand Down

0 comments on commit 3fb8b60

Please sign in to comment.