
Invalid Runtime inputs to embedding #952

Open
dgolubovicTT opened this issue Dec 23, 2024 · 2 comments
@dgolubovicTT
Contributor

I am getting a data mismatch for the embedding op when the input indices are int64. It turns out the input indices are already invalid right before the call to ttnn::embedding in the runtime: third_party/tt-mlir/runtime/lib/ttnn/operations/embedding/embedding.cpp

Namely, input indices in torch look like this:

16044 8239 2933 13760 16963 16379 31427 6503 31497 9683 14101 26866

In the run method of third_party/tt-mlir/runtime/lib/ttnn/operations/embedding/embedding.cpp they look like this:

16044 0 8239 0 2933 0 13760 0 16963 0 16379 0 , data type is UINT32

So every other index is now 0, which causes embedding to return invalid rows from embedding matrix.
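The alternating-zeros pattern is consistent with the int64 buffer being reinterpreted as uint32 without an actual value cast (my hypothesis, not yet confirmed in the runtime code): on a little-endian machine each 64-bit index splits into its low 32 bits (the value) and its high 32 bits (zero). A minimal NumPy sketch of the difference:

```python
import numpy as np

# int64 indices, as torch.randint produces by default.
indices_i64 = np.array([16044, 8239, 2933, 13760, 16963, 16379], dtype=np.int64)

# Reinterpreting the raw bytes as uint32 (no value conversion) doubles the
# element count: little-endian, each index becomes (low 32 bits, 0).
reinterpreted = indices_i64.view(np.uint32)
print(list(reinterpreted))  # [16044, 0, 8239, 0, 2933, 0, 13760, 0, 16963, 0, 16379, 0]

# A proper value cast keeps the indices intact.
converted = indices_i64.astype(np.uint32)
print(list(converted))  # [16044, 8239, 2933, 13760, 16963, 16379]
```

This reproduces exactly the corrupted sequence observed in the runtime, which points at a reinterpret-style conversion rather than an element-wise cast somewhere on the path.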

I've verified that the input tensor is still valid in forge/csrc/runtime/runtime.cpp::run_binary, so the issue is somewhere in between.

To repro:

checkout: dgolubovic/repro-embedding-input-indices-issue
run: pytest -svv forge/test/mlir/test_ops.py::test_embedding

One additional note: If the input indices are int64:

    inputs = [
        torch.randint(0, vocab_size, (1, token_num)),
    ]

this issue occurs. However, if the input indices are int32, everything works:

    inputs = [
        torch.randint(0, vocab_size, (1, token_num), dtype=torch.int32),
    ]
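For what it's worth, narrowing the indices to 32 bits before they reach the runtime should always be lossless for embedding inputs, since valid indices are non-negative and bounded by the vocabulary size. A sketch (NumPy standing in for torch; the bounds check is mine, not from the runtime):

```python
import numpy as np

vocab_size = 32000
rng = np.random.default_rng(0)
# Hypothetical stand-in for torch.randint(0, vocab_size, (1, token_num)):
indices = rng.integers(0, vocab_size, size=(1, 12), dtype=np.int64)

# Embedding indices lie in [0, vocab_size), so the narrowing cast to int32
# (or uint32) cannot lose information as long as vocab_size <= 2**31 - 1.
assert indices.min() >= 0 and indices.max() < 2**31
indices_i32 = indices.astype(np.int32)
assert (indices_i32 == indices).all()
```

So an explicit element-wise cast on the Forge side would sidestep the bug, though the reinterpretation in the runtime still looks like the real defect.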

On the other hand, the ttnn::embedding op accepts dtype=DataType::UINT32, so the runtime tensor has to have that datatype. @jnie-TT, do you know where we do this datatype conversion from PyTorch tensors to ttnn.Tensor?

@jnie-TT
Contributor

jnie-TT commented Dec 23, 2024

@dgolubovicTT can you also include the dumped MLIR graph (in the ttnn dialect) here?

@dgolubovicTT
Contributor Author

Of course. It is the same in both cases (input indices int32 and int64), which is expected.

Embedding_test_ops_ttnn.txt
