From ceb1a669a448b54fbb0a024669fbed89f3d38f6e Mon Sep 17 00:00:00 2001 From: alin m elena Date: Wed, 21 Aug 2024 15:29:17 +0100 Subject: [PATCH 1/2] add nequip --- janus_core/helpers/janus_types.py | 2 +- janus_core/helpers/mlip_calculators.py | 12 ++++++++++++ pyproject.toml | 2 ++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/janus_core/helpers/janus_types.py b/janus_core/helpers/janus_types.py index 073f9223..0a4241ef 100644 --- a/janus_core/helpers/janus_types.py +++ b/janus_core/helpers/janus_types.py @@ -138,7 +138,7 @@ class CorrelationKwargs(TypedDict, total=True): # Janus specific Architectures = Literal[ - "mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn", "sevennet" + "mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn", "sevennet", "nequip" ] Devices = Literal["cpu", "cuda", "mps", "xpu"] Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh"] diff --git a/janus_core/helpers/mlip_calculators.py b/janus_core/helpers/mlip_calculators.py index dd6c0354..0e0e3409 100644 --- a/janus_core/helpers/mlip_calculators.py +++ b/janus_core/helpers/mlip_calculators.py @@ -217,6 +217,18 @@ def choose_calculator( kwargs.setdefault("sevennet_config", None) calculator = SevenNetCalculator(model=model, device=device, **kwargs) + elif arch == "nequip": + from nequip.ase import NequIPCalculator + + model = model_path if model_path else "" + + calculator = NequIPCalculator.from_deployed_model( + model_path=model, + device=device, + **kwargs) + + + else: raise ValueError( f"Unrecognized {arch=}. Suported architectures " diff --git a/pyproject.toml b/pyproject.toml index d344e5b4..fb2a4fc5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ sevenn = { version = "0.9.3", optional = true } torchdata = {version = "0.7.1", optional = true} # Pin due to dgl issue torch_geometric = { version = "^2.5.3", optional = true } ruff = "^0.5.7" +nequip = {version = "^0.6.1", optional = true } [tool.poetry.extras] all = ["alignn", "chgnet", "matgl", "dgl", "torchdata", "sevenn", "torch_geometric"] @@ -55,6 +56,7 @@ alignn = ["alignn"] chgnet = ["chgnet"] m3gnet = ["matgl", "dgl", "torchdata"] sevennet = ["sevenn", "torch_geometric"] +nequip = ["nequip"] [tool.poetry.group.dev.dependencies] coverage = {extras = ["toml"], version = "^7.4.1"} From 64033a2685b8e9e0f925696f5d36bb50d7bacdf2 Mon Sep 17 00:00:00 2001 From: alin elena Date: Wed, 4 Sep 2024 11:22:11 +0100 Subject: [PATCH 2/2] add nequip raw --- janus_core/helpers/mlip_calculators.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/janus_core/helpers/mlip_calculators.py b/janus_core/helpers/mlip_calculators.py index 0e0e3409..648b0b34 100644 --- a/janus_core/helpers/mlip_calculators.py +++ b/janus_core/helpers/mlip_calculators.py @@ -223,11 +223,8 @@ def choose_calculator( model = model_path if model_path else "" calculator = NequIPCalculator.from_deployed_model( - model_path=model, - device=device, - **kwargs) - - + model_path=model, device=device, **kwargs + ) else: raise ValueError(