Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Minor improvements in jagged tensor identification (#919)
Summary: Pull Request resolved: #919 ATT. Concretely: 1. Tagging jagged tensor nodes in the fx graph now proceeds even if the nodes have a rank different from 2. This is to acomodate for the cases when jagged tensors are unsqueezed / reshaped at the path from the input (where they should be tagged for downstream shape inference) to the fbgemm op (where they are detected). 2. Inputs with the `shape[0]` being equal to one of the JT `shape[0]`, but not having an offsets tag attached are now ignored instead of failing the whole jagged tensor map inference. 3. Instead of falling back to the jagged batch-dim based JT shape inference, we now either fully rely on the inferred jagged tensor map or not at all if we fail to infer one. 4. Add `jagged_index_select` to the list of anchor ops for recognizing and tagging the jagged tensors and offsets in the fx graph. Reviewed By: qxy11 Differential Revision: D48825713 fbshipit-source-id: 32851a0180dd47bfbc6669216cdf79817af64670
- Loading branch information