Skip to content

Commit

Permalink
Avoid jnp import in utils/generic.py (#30322)
Browse files Browse the repository at this point in the history
fix

Co-authored-by: ydshieh <[email protected]>
  • Loading branch information
ydshieh and ydshieh authored Apr 18, 2024
1 parent 60d5f8f commit 01ae3b8
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/transformers/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@
)


if is_flax_available():
import jax.numpy as jnp


class cached_property(property):
"""
Descriptor that mimics @property but caches output in member variable.
Expand Down Expand Up @@ -624,6 +620,8 @@ def transpose(array, axes=None):

return tf.transpose(array, perm=axes)
elif is_jax_tensor(array):
import jax.numpy as jnp

return jnp.transpose(array, axes=axes)
else:
raise ValueError(f"Type not supported for transpose: {type(array)}.")
Expand All @@ -643,6 +641,8 @@ def reshape(array, newshape):

return tf.reshape(array, newshape)
elif is_jax_tensor(array):
import jax.numpy as jnp

return jnp.reshape(array, newshape)
else:
raise ValueError(f"Type not supported for reshape: {type(array)}.")
Expand All @@ -662,6 +662,8 @@ def squeeze(array, axis=None):

return tf.squeeze(array, axis=axis)
elif is_jax_tensor(array):
import jax.numpy as jnp

return jnp.squeeze(array, axis=axis)
else:
raise ValueError(f"Type not supported for squeeze: {type(array)}.")
Expand All @@ -681,6 +683,8 @@ def expand_dims(array, axis):

return tf.expand_dims(array, axis=axis)
elif is_jax_tensor(array):
import jax.numpy as jnp

return jnp.expand_dims(array, axis=axis)
else:
raise ValueError(f"Type not supported for expand_dims: {type(array)}.")
Expand Down

0 comments on commit 01ae3b8

Please sign in to comment.