
MACE 0.3.9 can not jit compile model trained with previous versions #741

Open · lohedges opened this issue Dec 9, 2024 · 11 comments
Labels: bug Something isn't working
lohedges commented Dec 9, 2024

It appears that the MACE models generated by version 0.3.9 aren't TorchScript compliant. To reproduce:

import torch
from mace.calculators.foundations_models import mace_off

model = mace_off(return_raw_model=True)

script_model = torch.jit.script(model)

Gives (truncated):

...

RuntimeError:
Module 'reshape_irreps' has no attribute 'cueq_config' :
  File "/home/lester/.conda/envs/emle/lib/python3.10/site-packages/mace/modules/irreps_tools.py", line 89
            field = tensor[:, ix : ix + mul * d]  # [batch, sample, mul * repr]
            ix += mul * d
            if hasattr(self, "cueq_config") and self.cueq_config is not None:
                                                ~~~~~~~~~~~~~~~~ <--- HERE
                if self.cueq_config.layout_str == "mul_ir":
                    field = field.reshape(batch, mul, d)

The same error is triggered when compiling the model with e3nn, e.g.:

from e3nn.util import jit

script_model = jit.compile(model)

This works fine with version 0.3.6, which is what I was using previously. For reference, I am using Python 3.10 on Linux x86 and installed mace-torch via pip.

Cheers.

ilyes319 (Contributor) commented Dec 9, 2024

What is your torch version? It is normal that the torch jit does not work.

I think it is just that you cannot compile old models, only models trained with 0.3.9 (we have tests for that). If you want to compile a model trained with an older version, you need to use that version.

lohedges (Author) commented Dec 9, 2024

Thanks, I'll check the torch version this evening and report back. This is just using your own MACE-OFF models in a completely fresh environment; I'm not training anything.

ilyes319 (Contributor) commented Dec 9, 2024

Exactly, mace-off was trained with an older version of MACE, so you cannot compile it with 0.3.9. You need to use an older version (probably anything up to 0.3.8).
Any model trained with 0.3.9 can be jit compiled with 0.3.9.

lohedges (Author) commented Dec 9, 2024

Thanks, that makes sense. However, if that's the case, shouldn't there be logic to download a compatible version at runtime? Surely the models encode the version of MACE that they were trained with? I guess this is probably a question for the MACE-OFF team, so I can ask there.

ilyes319 (Contributor) commented Dec 9, 2024

Unfortunately no, models at the time did not encode the version. You can still evaluate the model with the current version, just not jit compile it.
I can probably fix that in the next release to make even jit compile backward compatible. I'll ping you.

For now, I recommend just compiling using 0.3.6 and saving the compiled model somewhere.
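A minimal sketch of that compile-once-and-save workflow, using the standard `torch.jit` save/load API. The toy module here is a stand-in for the scripted MACE model (which would come from `torch.jit.script(model)` in a mace-torch 0.3.6 environment); the filename is arbitrary:

```python
import torch

# Toy stand-in for the MACE model; in practice this would be the model
# loaded by mace_off(return_raw_model=True) under mace-torch 0.3.6.
class ToyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 2.0 * x

script_model = torch.jit.script(ToyModel())

# Save the compiled model; the .pt archive is self-contained.
script_model.save("compiled_model.pt")

# Later (or in another environment), load it without needing the
# original model-defining code on the import path.
loaded = torch.jit.load("compiled_model.pt")
print(loaded(torch.ones(3)))  # tensor([2., 2., 2.])
```

Because the TorchScript archive embeds the serialized code, the saved model can be reloaded without pinning mace-torch at all.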

@ilyes319 ilyes319 changed the title MACE 0.3.9 models are not TorchScript compliant MACE 0.3.9 can not jit compile model trained with previous versions Dec 9, 2024
lohedges (Author) commented Dec 9, 2024

Thanks, no problem. I'll just add a pin for now.

lohedges (Author) commented
Just to confirm that, as suggested, MACE-OFF works with MACE up to and including version 0.3.8.

> I can probably fix that in the next release to make even jit compile backward compatible. I'll ping you.

Thanks, this would be really helpful. Our use case is to create a dual MACE-EMLE model at runtime (EMLE does the electrostatic embedding). We need it to be serializable so that it can be loaded with OpenMM-ML, or directly in C++. Having self-consistency between the versions of MACE and MACE-OFF would be ideal.

RokasEl (Collaborator) commented Dec 10, 2024

@lohedges, while you wait for the update you can just load the state dict of the old model into an instance of the model with the latest MACE version.

# Imports assumed from context: numpy, the mace_off loader, and the
# ScaleShiftMACE model class from mace.modules.
import numpy as np
from mace.calculators.foundations_models import mace_off
from mace.modules import ScaleShiftMACE

calc = mace_off("medium", device="cuda")
model = calc.models[0]
z_table = calc.z_table
zs = z_table.zs
updated_model = ScaleShiftMACE(
    atomic_numbers=list(zs),
    atomic_energies=np.zeros(len(zs)),  # will be loaded from the state dict
    num_elements=len(zs),
    atomic_inter_scale=model.scale_shift.scale,
    atomic_inter_shift=model.scale_shift.shift,
    **hypers.to_dict(),  # hyperparameter container; values given below
)
updated_model.load_state_dict(model.state_dict())

from e3nn.util import jit
script_model = jit.compile(updated_model)  # works fine!

For the medium mace_off the hyperparams are:

MACE_Hyperparameters(r_max=5.0, num_bessel=8, num_polynomial_cutoff=5, max_ell=3, num_interactions=2, interaction_cls_first='RealAgnosticInteractionBlock', interaction_cls='RealAgnosticResidualInteractionBlock', hidden_irreps='128x0e+128x1o', MLP_irreps='16x0e', avg_num_neighbors=18.41771125793457, correlation=3, gate='silu')
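If the `MACE_Hyperparameters` helper above is not to hand, the same values can be held in a plain dict and splatted into the constructor with `**hypers` in place of `**hypers.to_dict()`. This is a sketch assuming the keys match the `ScaleShiftMACE` keyword arguments; note the two interaction class names would still need to be resolved to the actual block classes from `mace.modules` before construction:

```python
# Medium MACE-OFF hyperparameters as a plain dict, transcribed from the
# repr above (the interaction_cls entries are class names, not classes).
hypers = {
    "r_max": 5.0,
    "num_bessel": 8,
    "num_polynomial_cutoff": 5,
    "max_ell": 3,
    "num_interactions": 2,
    "interaction_cls_first": "RealAgnosticInteractionBlock",
    "interaction_cls": "RealAgnosticResidualInteractionBlock",
    "hidden_irreps": "128x0e+128x1o",
    "MLP_irreps": "16x0e",
    "avg_num_neighbors": 18.41771125793457,
    "correlation": 3,
    "gate": "silu",
}
```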

ilyes319 (Contributor) commented Dec 10, 2024

Thanks @RokasEl, that's a good idea. One can even extract the hypers directly using:

from mace.calculators.foundations_models import mace_off
from mace.tools.scripts_utils import extract_config_mace_model

device = "cuda"
calc = mace_off("medium", device=device)
source_model = calc.models[0]
config = extract_config_mace_model(source_model)
target_model = source_model.__class__(**config).to(device)
target_model.load_state_dict(source_model.state_dict())

from e3nn.util import jit
script_model = jit.compile(target_model)

You can swap in any MACE model as the source model; it will extract the hypers.

lohedges (Author) commented
Brilliant, many thanks for this @RokasEl. I'll add this in now.

lohedges (Author) commented
Can confirm that this works perfectly. Thanks again for your help.

lohedges added a commit to chemle/emle-engine that referenced this issue Dec 10, 2024
@ilyes319 ilyes319 added the bug Something isn't working label Dec 12, 2024