Export a model with multiple entry points #7458

Open
LaurentMazare opened this issue Dec 30, 2024 · 3 comments
Labels
module: doc (Related to our documentation, both in docs/ and docblocks), module: exir (Issues related to Export IR)

Comments

@LaurentMazare

Hello,
I hope this is the appropriate place to ask such questions; let me know if it would be better suited elsewhere.
We would like to run a model via ExecuTorch. The tricky part is that this model has multiple methods that we want to expose, and these methods can manipulate its state (one method is the actual forward pass, the other resets the internal state). I haven't found a proper way to do this.
I came across "export multiple functions of a pytorch module", which suggests calling export multiple times and relying on "Later when ExecuTorch serializes to a binary, the weights/buffer in that structure are then merged into one state dict", but I didn't manage to get that to work.
First, even though the torch.export documentation mentions here that it can apply to a callable, it seems to only work on modules. And after calling torch.export on two different modules with a common state, the state doesn't actually seem to be shared.
Do you know if it's possible to achieve this with ExecuTorch at the moment? Any pointer to a model that already does this? (I looked into the examples and googled around, but no luck.)
In case it helps, you can see the code where I tried to export two modules that share a common state here.

@cccclai
Contributor

cccclai commented Dec 30, 2024

We need more documentation regarding multimethods... you can check out the unit test written in #7281.

cccclai added the module: doc and module: exir labels on Dec 30, 2024
@LaurentMazare
Author

LaurentMazare commented Dec 30, 2024

Thanks! Would you expect multimethods to work when they modify some inner state of a module?
I tried the following example, based on the unit test you mentioned.

import torch

import executorch.exir
from executorch.exir.backend.test.demos.rpc.executor_backend_partitioner import (
    ExecutorBackendPartitioner,
)

# Holds the state shared between the two exported methods.
class SharedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._v = torch.nn.Parameter(torch.ones(1, dtype=torch.float))


# The actual forward pass: accumulates the input into the shared state.
class Module1(torch.nn.Module):
    def __init__(self, shared_module):
        super().__init__()
        self.shared_module = shared_module

    def forward(self, x):
        self.shared_module._v[:] = self.shared_module._v + x
        return self.shared_module._v


# Resets the shared state to zero.
class Module2(torch.nn.Module):
    def __init__(self, shared_module):
        super().__init__()
        self.shared_module = shared_module

    def forward(self, x):
        self.shared_module._v.fill_(0.0)
        return x

def export():
    shared_module = SharedModule()
    module_1 = Module1(shared_module)
    module_2 = Module2(shared_module)
    example_inputs = (torch.randn(1),)
    module_1(*example_inputs)
    module_2(*example_inputs)

    # Export each entry point separately...
    ep1 = torch.export.export_for_training(module_1, example_inputs)
    ep2 = torch.export.export_for_training(module_2, example_inputs)

    # ...then lower both into a single edge program, keyed by method name.
    edge_program_manager = executorch.exir.to_edge(
        {
            "forward1": ep1,
            "forward2": ep2,
        },
        compile_config=executorch.exir.EdgeCompileConfig(
            _check_ir_validity=False, _use_edge_ops=True
        ),
    )
    edge_program_manager = edge_program_manager.to_backend(ExecutorBackendPartitioner()).to_executorch()

with torch.no_grad():
    export()

However, this resulted in the following error. Maybe there is something wrong with the way I'm trying to export the modules?

Traceback (most recent call last):
    ep1 = torch.export.export(module_1, example_inputs)
    ....
    return _create_aot_dispatcher_function(
RuntimeError: Found a graph input that requires gradients, and received a mutation.
This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue.

Edit
If I add requires_grad=False to the torch.nn.Parameter, then I get a different error (traceback below).
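Concretely, that is this one-line change in SharedModule.__init__ from the snippet above:

    self._v = torch.nn.Parameter(torch.ones(1, dtype=torch.float), requires_grad=False)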

Traceback (most recent call last):
  File "/Users/laurent/github/xctch/xctch-core/tests/shared.py", line 55, in <module>
    export()
  File "/Users/laurent/github/xctch/xctch-core/tests/shared.py", line 43, in export
    edge_program_manager = executorch.exir.to_edge(
                           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/laurent/venvs/executorch/lib/python3.12/site-packages/executorch/exir/program/_program.py", line 1142, in to_edge
    program = program.run_decompositions(_default_decomposition_table())
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/laurent/venvs/executorch/lib/python3.12/site-packages/torch/export/exported_program.py", line 114, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/laurent/venvs/executorch/lib/python3.12/site-packages/torch/export/exported_program.py", line 1003, in run_decompositions
    return _decompose_exported_program(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/laurent/venvs/executorch/lib/python3.12/site-packages/torch/export/exported_program.py", line 617, in _decompose_exported_program
    gm, new_graph_signature = _decompose_and_get_gm_with_new_signature_constants(
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/laurent/venvs/executorch/lib/python3.12/site-packages/torch/export/exported_program.py", line 362, in _decompose_and_get_gm_with_new_signature_constants
    aten_export_artifact = _export_to_aten_ir(
                           ^^^^^^^^^^^^^^^^^^^
  File "/Users/laurent/venvs/executorch/lib/python3.12/site-packages/torch/export/_trace.py", line 637, in _export_to_aten_ir
    gm, graph_signature = transform(aot_export_module)(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/laurent/venvs/executorch/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1301, in aot_export_module
    return fx_g, create_graph_signature(
                 ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/laurent/venvs/executorch/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/input_output_analysis.py", line 481, in create_graph_signature
    return GraphSignature.from_tracing_metadata(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/laurent/venvs/executorch/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/schemas.py", line 752, in from_tracing_metadata
    assert idx >= len(parameters)
           ^^^^^^^^^^^^^^^^^^^^^^
AssertionError

@kimishpatel
Contributor

I would be somewhat surprised if we have support for shared state that can be modified from two different methods, like self._v in your example. Although note that instead of using Parameter you might want to use the register_buffer API from torch. cc @JacobSzwejbka for shared mutable state.

Now regarding the actual error, I suspect it will also be resolved if you register a buffer, since only buffers can be mutated/changed by the program, not parameters. That is why I think you are seeing that error (https://github.com/pytorch/pytorch/blob/main/torch/_functorch/_aot_autograd/schemas.py#L824)
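To illustrate that suggestion, here is a minimal sketch (not code from this thread) of SharedModule rewritten to hold its state in a buffer rather than a Parameter; the two wrapper modules above can stay unchanged since they only access self.shared_module._v:

import torch

class SharedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A buffer, unlike a Parameter, is non-trainable state that an
        # exported program is allowed to mutate.
        self.register_buffer("_v", torch.ones(1, dtype=torch.float))

Whether the buffer actually ends up shared across the two exported entry points is the open question in this thread.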
