
Generating onnx file for the inference of Mamba? #200

Open · llmexperiment opened this issue Feb 28, 2024 · 11 comments

@llmexperiment

Dear @tridao, @albertfgu,

It looks like it is not straightforward to generate an ONNX file with torch.onnx.export, for the following reasons:

  1. The underlying scan operator is implemented in Triton.
  2. Inference needs the recurrent version of the scan, which I believe starts at line 119 (lines 119 to 133, where 133 is the return), as shown here: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119 (a plain-PyTorch sketch is below).

The two points above prevent (based on my understanding) generating an ONNX file. It would be great to have an ONNX file for the inference path of the smallest model.

Any suggestions on how we can generate an ONNX file for inference (and, separately, for training)?
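For reference, the recurrent form I mean in point 2 is easy to express in plain PyTorch. A rough sketch, following my reading of selective_scan_ref (shapes and names are approximate, not the repo's exact code):

```python
import torch

def selective_scan_recurrent(u, delta, A, B, C, D):
    """Sequential (recurrent) form of the selective scan.

    Shapes (batch b, channels d, state n, length l):
      u, delta: (b, d, l)   A: (d, n)   B, C: (b, n, l)   D: (d,)
    An O(l) loop over time -- the form needed for step-by-step inference.
    """
    b, d, l = u.shape
    n = A.shape[1]
    x = u.new_zeros(b, d, n)                  # hidden state
    ys = []
    for t in range(l):
        # Discretize: x_t = exp(delta_t * A) * x_{t-1} + delta_t * B_t * u_t
        dA = torch.exp(delta[:, :, t, None] * A)                        # (b, d, n)
        dBu = delta[:, :, t, None] * B[:, None, :, t] * u[:, :, t, None]
        x = dA * x + dBu
        ys.append((x * C[:, None, :, t]).sum(-1))                       # y_t = C_t x_t
    return torch.stack(ys, dim=-1) + u * D[:, None]                     # (b, d, l) with skip
```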

@tridao
Collaborator

tridao commented Feb 28, 2024

We have no experience with ONNX. Do you have ideas on how to generate ONNX for custom operations? If so, would you like to contribute?

@llmexperiment
Author

> We have no experience with ONNX. Do you have ideas on how to generate ONNX for custom operations? If so, would you like to contribute?

Thanks @tridao! I am working in that direction (here is a tutorial on how to do it: https://github.com/onnx/tutorials/blob/master/PyTorchCustomOperator/README.md).

Could you let me know what the other custom operators are, if any, besides the scan?

I believe the code for the scan is here: https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.cpp

By any chance, do you have a standalone implementation of the scan?
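Following that tutorial, the export side would look roughly like this. Note that mamba::selective_scan is a hypothetical name; the repo does not register a torch custom op under that name today, so this only shows the shape of the approach:

```python
import torch
import torch.onnx

# Symbolic function: tells the exporter which node to emit when it meets
# the custom op. "mamba::SelectiveScan" is a made-up op in a custom
# domain; an ONNX runtime would need a matching kernel registered for it.
def selective_scan_symbolic(g, u, delta, A, B, C, D):
    return g.op("mamba::SelectiveScan", u, delta, A, B, C, D)

# "mamba::selective_scan" is assumed to be a torch custom op registered
# elsewhere (e.g. via TorchScript custom ops, as in the tutorial).
torch.onnx.register_custom_op_symbolic(
    "mamba::selective_scan", selective_scan_symbolic, opset_version=17
)
```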

@xenova

xenova commented Mar 22, 2024

@llmexperiment You may be interested in looking at the HF transformers implementation (PR here), which supports a fallback if causal-conv1d is not found in the environment. I've also been trying to convert Mamba models to ONNX for transformers.js, but I've been running into a few issues. If I figure something out, I'll update the thread.
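For anyone trying the same route, the export attempt is roughly the following (the checkpoint name and export settings are just what I have been experimenting with; this is where the issues show up, so treat it as a starting point rather than a working recipe):

```python
import torch
from transformers import MambaForCausalLM

# With mamba-ssm / causal-conv1d absent, transformers uses its
# pure-PyTorch fallback, which is what the tracer needs to see.
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
model.eval()
model.config.return_dict = False  # tuple outputs trace more cleanly

dummy = torch.randint(0, model.config.vocab_size, (1, 8))
torch.onnx.export(
    model, (dummy,), "mamba-130m.onnx",
    input_names=["input_ids"], output_names=["logits"],
    dynamic_axes={"input_ids": {0: "batch", 1: "sequence"}},
    opset_version=17,
)
```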

@IamShubhamGupto
Contributor

Hey, I'm interested in converting the Vision Mamba (Vim) model to ONNX but have not had success, so I decided to start by working with the Mamba layers first and then proceed from there.

This is the current status of my code (stack trace):

root@ubuntu:/home/xavier01/Documents/workspace/embedded-vpr/Vim/vim# python3 demo_export.py
Fetching 4 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 28149.69it/s]
/home/xavier01/Documents/workspace/embedded-vpr/Vim/vim/models_mamba.py:57: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert H == self.img_size[0] and W == self.img_size[1], \
/home/xavier01/Documents/workspace/embedded-vpr/Vim/vim/models_mamba.py:420: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if if_random_token_rank:
/usr/local/lib/python3.8/dist-packages/mamba_ssm-1.1.1-py3.8-linux-aarch64.egg/mamba_ssm/ops/triton/layernorm.py:133: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert weight.shape == (N,)
/usr/local/lib/python3.8/dist-packages/mamba_ssm-1.1.1-py3.8-linux-aarch64.egg/mamba_ssm/ops/triton/layernorm.py:150: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
/usr/local/lib/python3.8/dist-packages/mamba_ssm-1.1.1-py3.8-linux-aarch64.egg/mamba_ssm/ops/triton/layernorm.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if N > BLOCK_N:
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 1124, in ast_to_ttir
    generator.visit(fn.parse())
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 293, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/usr/lib/python3.8/ast.py", line 379, in generic_visit
    self.visit(item)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 362, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 288, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 414, in visit_Assign
    values = self.visit(node.value)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 946, in visit_Call
    return fn(*args, **extra_kwargs, **kws)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/language/core.py", line 30, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/language/core.py", line 813, in arange
    return semantic.arange(start, end, _builder)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/language/semantic.py", line 485, in arange
    raise ValueError("arange's arguments must be of type tl.constexpr")
ValueError: arange's arguments must be of type tl.constexpr

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "demo_export.py", line 61, in <module>
    torch.onnx.export(model, dummy_input, "vim_s_midclstok_ft_81p6acc_fp16.onnx", input_names=["input"],
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1596, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1135, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1011, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 915, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 1285, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 133, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 124, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/xavier01/Documents/workspace/embedded-vpr/Vim/vim/models_mamba.py", line 541, in forward
    x = self.forward_features(x, inference_params, if_random_cls_token_position=if_random_cls_token_position, if_random_token_rank=if_random_token_rank)
  File "/home/xavier01/Documents/workspace/embedded-vpr/Vim/vim/models_mamba.py", line 478, in forward_features
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/xavier01/Documents/workspace/embedded-vpr/Vim/vim/models_mamba.py", line 115, in forward
    hidden_states, residual = fused_add_norm_fn(
  File "/usr/local/lib/python3.8/dist-packages/mamba_ssm-1.1.1-py3.8-linux-aarch64.egg/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.8/dist-packages/mamba_ssm-1.1.1-py3.8-linux-aarch64.egg/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "/usr/local/lib/python3.8/dist-packages/mamba_ssm-1.1.1-py3.8-linux-aarch64.egg/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/runtime/autotuner.py", line 100, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/runtime/autotuner.py", line 100, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/runtime/autotuner.py", line 83, in _bench
    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/testing.py", line 104, in do_bench
    fn()
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/runtime/autotuner.py", line 81, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "<string>", line 63, in _layer_norm_fwd_1pass_kernel
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/compiler.py", line 476, in compile
    next_module = compile_kernel(module)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/compiler.py", line 381, in <lambda>
    lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 1133, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 31:24:    HAS_BIAS: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
                        ^
ValueError("arange's arguments must be of type tl.constexpr")

From what I understand, tracing turns the shape-derived arguments into dynamic (traced) values, and tl.arange requires compile-time constants (tl.constexpr), which is what Triton is complaining about.
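A tiny repro of that behavior, separate from Vim (under tracing, shape accesses are recorded as traced values so the exporter can follow data flow, but Triton launch constants must be plain Python ints):

```python
import torch

def shape_dependent(x):
    n = x.shape[-1]   # under tracing, n is a traced value, not a plain int
    if n > 2:         # -> TracerWarning: converting a tensor to a Python bool
        return x[..., :2]
    return x

traced = torch.jit.trace(shape_dependent, torch.randn(3, 4))

# Triton launch parameters (like BLOCK_N above) must be real Python ints
# (tl.constexpr), so one common workaround is to force concrete values, e.g.
#   BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(int(N)))
# at the cost of baking that shape into the exported graph.
```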

@zengjie617789

Replace the optimized Triton operations with the original torch operations.

@IamShubhamGupto
Contributor

> Replace the optimized Triton operations with the original torch operations.

Could you share more details on your implementation?

@chengyupku

I noticed that when using the original torch operations, G_intermediate becomes very large, causing many configurations to run out of memory. Is there any way to avoid this issue?

@zengjie617789

> > Replace the optimized Triton operations with the original torch operations.
>
> Could you share more details on your implementation?

  1. In the layernorm.py file there are reference functions such as rms_norm_ref; use these naive functions instead of the Triton implementations.
  2. Rewrite all of the einsum calls (a sketch of both is below).
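For (1), a plain-PyTorch stand-in might look like this (written from my reading of rms_norm_ref, so double-check it against the layernorm.py in your installed version); a broadcasting rewrite for one common einsum is shown at the end:

```python
import torch

def rms_norm_torch(x, weight, bias=None, residual=None, eps=1e-6):
    """Plain-PyTorch stand-in for the Triton rms_norm path.

    Mirrors the reference semantics: optional residual add, RMS
    normalization over the last dim, then scale (and optional shift).
    """
    if residual is not None:
        x = x + residual
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = x * rstd * weight
    if bias is not None:
        out = out + bias
    return out

# For (2), einsum calls can be rewritten with broadcasting, e.g. the
# discretization step in selective_scan_ref (shapes: delta (b,d,l), A (d,n)):
#   torch.einsum("bdl,dn->bdln", delta, A)
# is equivalent to
#   delta.unsqueeze(-1) * A[None, :, None, :]
```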

@bhack

bhack commented Dec 15, 2024

If you are interested, we also have a few threads at:

pytorch/pytorch#130150 (selective_scan custom ops)

pytorch/pytorch#95408 (comment) (native selective_scan and associative_scan)

pytorch/pytorch#120189 (Mamba native).
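The first of those threads discusses wrapping the CUDA selective scan as a PyTorch custom op via torch.library, so that tracing and compilation see a single opaque node. A hedged sketch of that idea (the op name and wrapper are illustrative, not merged code anywhere):

```python
import torch

# Illustrative torch.library wrapper around the repo's CUDA kernel, in the
# spirit of pytorch/pytorch#130150; "mamba::selective_scan" is a made-up name.
@torch.library.custom_op("mamba::selective_scan", mutates_args=())
def selective_scan(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
                   B: torch.Tensor, C: torch.Tensor,
                   D: torch.Tensor) -> torch.Tensor:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
    return selective_scan_fn(u, delta, A, B, C, D=D)

# Fake (meta) implementation so shape inference can run without
# executing the CUDA kernel.
@selective_scan.register_fake
def _(u, delta, A, B, C, D):
    return torch.empty_like(u)
```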

@HBSDLJZ

HBSDLJZ commented Dec 16, 2024

> > > Replace the optimized Triton operations with the original torch operations.
> >
> > Could you share more details on your implementation?
>
> 1. In the layernorm.py file there are reference functions such as rms_norm_ref; use these naive functions instead of the Triton implementations.
> 2. Rewrite all of the einsum calls.

Can you show me the code? Thanks!

@HBSDLJZ

HBSDLJZ commented Dec 16, 2024

> Can you show me the code? Thanks!

Can you show me the code? Thanks!
