Skip to content

Commit

Permalink
Update tensor_shape.py for JAX/NumPy backend.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681829342
  • Loading branch information
brianwa84 authored and tensorflower-gardener committed Oct 3, 2024
1 parent 4b8182b commit 59b3651
Showing 1 changed file with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1435,11 +1435,12 @@ def is_compatible_with(self, other):
"""
other = as_shape(other)
if self.dims is not None and other.dims is not None:
if self._dims is not None and other._dims is not None: # pylint: disable=protected-access
if self.rank != other.rank:
return False
for x_dim, y_dim in zip(self.dims, other.dims):
if not x_dim.is_compatible_with(y_dim):
for x_dim, y_dim in zip(self._dims, other._dims): # pylint: disable=protected-access
# Inline TensorShape.dims logic for performance in tight loops.
if x_dim is not None and y_dim is not None and x_dim != y_dim:
return False
return True

Expand Down

0 comments on commit 59b3651

Please sign in to comment.