I am getting a data mismatch for the embedding op when the input indices are int64. It turns out that the input indices are invalid right before the call to `ttnn::embedding` in runtime, in `third_party/tt-mlir/runtime/lib/ttnn/operations/embedding/embedding.cpp`.
Namely, the input indices in torch look like this:

```
16044 8239 2933 13760 16963 16379 31427 6503 31497 9683 14101 26866
```

In the `run` method of `third_party/tt-mlir/runtime/lib/ttnn/operations/embedding/embedding.cpp` they look like this:

```
16044 0 8239 0 2933 0 13760 0 16963 0 16379 0
```

with data type `UINT32`. So every other index is now 0, which causes the embedding to return invalid rows from the embedding matrix.
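The interleaved zeros are consistent with the int64 index buffer being *reinterpreted* as uint32 (a byte-level view) rather than cast element-wise; this is a hypothesis about the cause, not something confirmed in the runtime code. A minimal numpy sketch contrasting the two behaviors on a little-endian machine:

```python
import numpy as np

# Torch's default integer dtype is int64 (8 bytes per index).
idx64 = np.array([16044, 8239, 2933, 13760, 16963, 16379], dtype=np.int64)

# Byte-level reinterpretation (hypothesized bug): each int64 splits into
# (low 32 bits, high 32 bits) = (value, 0) for small positive values
# on a little-endian machine.
reinterpreted = idx64.view(np.uint32)
print(reinterpreted.tolist())
# [16044, 0, 8239, 0, 2933, 0, 13760, 0, 16963, 0, 16379, 0]

# Element-wise cast (expected behavior): values are preserved.
cast = idx64.astype(np.uint32)
print(cast.tolist())
# [16044, 8239, 2933, 13760, 16963, 16379]
```

The view output matches the corrupted indices observed in `run` exactly, which is why the reinterpret-vs-cast hypothesis seems plausible.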
I've verified that the input tensor is still valid in `forge/csrc/runtime/runtime.cpp::run_binary`, so the issue is somewhere in between.

To repro:
- checkout: `dgolubovic/repro-embedding-input-indices-issue`
- run: `pytest -svv forge/test/mlir/test_ops.py::test_embedding`
One additional note: this issue occurs only when the input indices are int64. If the input indices are int32, everything works.
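The int32/int64 difference also fits the reinterpretation hypothesis: a 4-byte index reinterprets cleanly as uint32, while an 8-byte int64 splits into a (value, 0) pair on a little-endian machine. A stdlib-only sketch of the byte layouts (the `<` format explicitly selects little-endian, so the behavior here is deterministic):

```python
import struct

vals = [16044, 8239, 2933]

# int64 storage (8 bytes each, little-endian), read back as uint32:
# every value is followed by its zeroed-out high 32 bits.
as_i64_bytes = struct.pack("<3q", *vals)
print(struct.unpack("<6I", as_i64_bytes))
# (16044, 0, 8239, 0, 2933, 0)

# int32 storage (4 bytes each), read back as uint32: values survive intact,
# matching the observation that int32 indices work end to end.
as_i32_bytes = struct.pack("<3i", *vals)
print(struct.unpack("<3I", as_i32_bytes))
# (16044, 8239, 2933)
```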
On the other hand, the `ttnn::embedding` op accepts `dtype=DataType::UINT32`, so the runtime tensor has to have that datatype. @jnie-TT do you know where we do this datatype conversion from pytorch tensors to `ttnn.Tensor`?