
(sharktank) PagedLlamaAttentionBlockTest::testExportNondecomposed fails with torch 2.4.0 and above #684

Open · marbre opened this issue Dec 12, 2024 · 0 comments

marbre (Collaborator) commented Dec 12, 2024

Using

  • HEAD at commit d279aff
  • Python 3.11.10
  • iree-base-compiler==3.1.0rc20241212
  • iree-base-runtime==3.1.0rc20241212
  • iree-turbine==3.1.0rc20241211

the test PagedLlamaAttentionBlockTest::testExportNondecomposed fails with

  • torch==2.4.0+cpu and torch==2.4.1+cpu

FAILED tests/layers/paged_llama_attention_block_test.py::PagedLlamaAttentionBlockTest::testExportNondecomposed - iree.compiler._mlir_libs._site_initialize..MLIRError: Verification failed:

and

  • torch==2.5.0+cpu and torch==2.5.1+cpu

FAILED tests/layers/paged_llama_attention_block_test.py::PagedLlamaAttentionBlockTest::testExportNondecomposed - AssertionError: 'torch.aten._scaled_dot_product_flash_attention_for_cpu' not found in 'module @module {\n util.global private @__auto.constant_660_660_torch.float32 = dense_resource<__a...
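For reference, the failing path can be sketched without the full test harness. This is a condensed, hypothetical repro, assuming the iree.turbine aot.export entry point shown in the traceback below; the tensor shapes come from the error message, everything else is illustrative:

    # Condensed sketch of the failing export path (hypothetical repro,
    # not the actual test): export a module that keeps SDPA non-decomposed
    # and hand the ExportedProgram to iree-turbine.
    import torch
    import torch.nn.functional as F
    from iree.turbine import aot

    class Sdpa(torch.nn.Module):
        def forward(self, q, k, v):
            # scaled_dot_product_attention is the op whose torch-dialect
            # signature differs across torch releases (see error below).
            return F.scaled_dot_product_attention(q, k, v)

    q = k = v = torch.rand(3, 30, 119, 22)
    ep = torch.export.export(Sdpa(), args=(q, k, v))
    aot.export(ep)  # raises MLIRError under torch==2.4.*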

Full error log for torch==2.4.0+cpu and torch==2.4.1+cpu (the log for 2.5.* is too large to include):

self = <tests.layers.paged_llama_attention_block_test.PagedLlamaAttentionBlockTest testMethod=testExportNondecomposed>

    def testExportNondecomposed(self):
        dtype = torch.float32
    
        cache = PagedKVCache(
            transformer_block_count=self.transformer_block_count,
            attn_head_count=self.head_count_kv,
            attn_head_dim=self.attention_head_dim,
            cache_partition_count=self.cache_partition_count,
            block_seq_stride=self.block_seq_stride,
            dtype=dtype,
        )
    
        cache_state = cache.paged.allocate(self.page_count)
        cache_state[0] = torch.rand(cache_state[0].shape, dtype=dtype)
    
        theta = make_llama_attention_block_theta(
            block_idx=0,
            head_count=self.attention_head_count,
            head_count_kv=self.head_count_kv,
            head_dim=self.attention_head_dim,
            embedding_length=self.embedding_length,
        )
        attn = PagedLlamaAttentionBlock(
            theta=theta,
            block_index=self.block_index,
            cache=cache,
            head_count=self.attention_head_count,
            head_dim=self.attention_head_dim,
            head_count_kv=self.head_count_kv,
            rms_epsilon=self.rms_epsilon,
            attention_kernel="torch",
        )
    
        seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view(
            self.batch_size, -1
        )
    
        embedding_module = RotaryEmbeddingLayer(
            rope_dimension_count=self.rope_dimension_count,
            max_seqlen=self.max_seqlen,
            rope_freq_base=self.rope_freq_base,
        )
    
        class MyModule(torch.nn.Module):
            def forward(self, h, seq_block_ids, cache_state):
                return attn.forward(
                    h,
                    seq_block_ids=seq_block_ids,
                    embedding=embedding_module,
                    start_index=0,
                    cache_state=cache_state,
                )
    
        mod = MyModule()
        h = torch.rand(
            [
                self.batch_size,
                self.max_seqlen,
                self.attention_head_count * self.attention_head_dim,
            ]
        )
        mod.forward(h, seq_block_ids, cache_state)
        ep = torch.export.export(
            mod,
            args=(
                h,
                seq_block_ids,
                cache_state,
            ),
        )
>       output = aot.export(ep)

tests/layers/paged_llama_attention_block_test.py:191: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.venv-uv-3.11/lib/python3.11/site-packages/iree/turbine/aot/exporter.py:351: in export
    cm = TransformedModule(context=context, import_to="import")
.venv-uv-3.11/lib/python3.11/site-packages/iree/turbine/aot/compiled_module.py:749: in __new__
    module_builder.finalize_construct()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <iree.turbine.aot.support.ir_utils.ModuleBuilder object at 0x7a6fb9a8acc0>

    def finalize_construct(self):
        try:
>           self.module_op.verify()
E           iree.compiler._mlir_libs._site_initialize.<locals>.MLIRError: Verification failed:
E           error: "/home/mbrehler/repos/shark-ai/sharktank/tests/layers/paged_llama_attention_block_test.py":166:0: 'torch.aten.scaled_dot_product_attention' op expected 8 operands, but found 7
E            note: "/home/mbrehler/repos/shark-ai/sharktank/tests/layers/paged_llama_attention_block_test.py":166:0: see current operation: %403 = "torch.aten.scaled_dot_product_attention"(%392, %395, %398, %399, %400, %401, %402) : (!torch.vtensor<[3,30,119,22],f32>, !torch.vtensor<[3,30,119,22],f32>, !torch.vtensor<[3,30,119,22],f32>, !torch.none, !torch.float, !torch.bool, !torch.none) -> !torch.vtensor<[3,30,119,22],f32>

.venv-uv-3.11/lib/python3.11/site-packages/iree/turbine/aot/support/ir_utils.py:232: MLIRError
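A plausible reading of the verifier error (my interpretation, not confirmed in this log): the torch-mlir op definition bundled with iree-turbine expects the newer 8-operand schema of torch.aten.scaled_dot_product_attention, while torch==2.4.* emits the older 7-operand form (query, key, value, attn_mask, dropout_p, is_causal, scale). The missing trailing operand would then be the enable_gqa flag that later PyTorch releases added:

    import torch
    import torch.nn.functional as F

    q = k = v = torch.rand(3, 30, 119, 22)

    # 7-argument form, as emitted by torch 2.4.* (matches the seven
    # operands listed in the MLIR note above):
    out = F.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
    )

    # Newer torch releases append enable_gqa, lining up with the
    # 8-operand signature the verifier expects:
    out = F.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False,
        scale=None, enable_gqa=False,
    )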
marbre added a commit to marbre/shark-ai that referenced this issue Dec 12, 2024
Marks `testExportNondecomposed` as expected to fail if running with torch>=2.4.0, see nod-ai#684.
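The gate described in that commit could look roughly like the sketch below; this is an assumption about its shape, not the actual patch, and the packaging-based version check is illustrative:

    import pytest
    import torch
    from packaging import version

    # Hypothetical version-gated xfail, keyed on the torch release tuple
    # so local suffixes like "+cpu" are ignored.
    @pytest.mark.xfail(
        version.parse(torch.__version__).release >= (2, 4),
        reason="nod-ai#684: SDPA export fails with torch>=2.4.0",
    )
    def testExportNondecomposed(self):
        ...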