Skip to content

Commit

Permalink
Minor improvements in jagged tensor identification (#919)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #919

ATT. Concretely:

1. Tagging jagged tensor nodes in the fx graph now proceeds even if the nodes have a rank different from 2. This is to acomodate for the cases when jagged tensors are unsqueezed / reshaped at the path from the input (where they should be tagged for downstream shape inference) to the fbgemm op (where they are detected).

2. Inputs with the `shape[0]` being equal to one of the JT `shape[0]`, but not having an offsets tag attached are now ignored instead of failing the whole jagged tensor map inference.

3. Instead of falling back to the jagged batch-dim based JT shape inference, we now either fully rely on the inferred jagged tensor map or not at all if we fail to infer one.

4. Add `jagged_index_select` to the list of anchor ops for recognizing and tagging the jagged tensors and offsets in the fx graph.

Reviewed By: qxy11

Differential Revision: D48825713

fbshipit-source-id: 32851a0180dd47bfbc6669216cdf79817af64670
  • Loading branch information
aakhundov authored and facebook-github-bot committed Sep 2, 2023
1 parent 0f91ea5 commit 6bbe03c
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions fx2ait/fx2ait/tensor_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,12 @@ def _try_getting_jagged_tensor_map(
for i, inp in enumerate(inputs):
if inp.shape[0] in jagged_tensor_batch_dims:
offsets_name = fx_inputs[i].meta.get("offsets_name", None)
if offsets_name is None or len(offsets_name) > 1:
if offsets_name is None:
# not a jagged tensor
continue
if len(offsets_name) > 1:
# offsets name attached to the jagged tensor's
# fx.Node is either unavailable or ambiguous
# fx.Node is either ambiguous: failing here
return None
offsets_name = list(offsets_name)[0]
if offsets_name not in seen_offsets_names:
Expand Down Expand Up @@ -387,7 +390,7 @@ def from_input_list_with_batch_size_jagged_tensor(
jagged_tensor_batch_dims=jagged_tensor_batch_dims,
fx_inputs=fx_inputs,
)
if jagged_tensor_map is not None:
if jagged_tensor_map:
logger.info("Successfully detected a jagged_tensor_map:")
for input_id, jagged_tensor_id in jagged_tensor_map.items():
logger.info(f"{input_id=}, {jagged_tensor_id=}")
Expand All @@ -407,7 +410,7 @@ def from_input_list_with_batch_size_jagged_tensor(
batch_dim_lower_bound: int = 0
batch_dim_upper_bound: int = 0
batch_dim_name: str = ""
if jagged_tensor_map is not None and ind in jagged_tensor_map:
if jagged_tensor_map and ind in jagged_tensor_map:
batch_dim_lower_bound = 0 # when all sequences are empty
# if the maximum sequence length for this jagged tensor was not
# inferred from the offsets, we use the globally configured
Expand All @@ -417,7 +420,7 @@ def from_input_list_with_batch_size_jagged_tensor(
)
batch_dim_upper_bound = max_batch_size * max_seq_len
batch_dim_name = f"batch_size_jagged_tensor_id_{jagged_tensor_map[ind]}"
elif batch_dim in jagged_tensor_batch_dims:
elif not jagged_tensor_map and batch_dim in jagged_tensor_batch_dims:
batch_dim_lower_bound = 0
max_seq_len = max_seq_lens_from_offsets.get(
batch_dim, max_sequence_length
Expand Down

0 comments on commit 6bbe03c

Please sign in to comment.