Skip to content

Commit

Permalink
[XLA:Python] Use &PyArray_Type rather than looking up numpy.ndarray v…
Browse files Browse the repository at this point in the history
…ia Python attrs.

This is slightly simpler, and avoids the disagreement that triggers jax-ml/jax#25468 so we may as well land it.

PiperOrigin-RevId: 705909895
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Dec 13, 2024
1 parent b119483 commit 79faaf0
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions xla/python/py_values.cc
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,7 @@ absl::StatusOr<DevicePutResultFn> DevicePut(nb::handle arg,
(*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] =
HandlePythonScalar<complex128, complex64>;

const auto numpy = nb::module_::import_("numpy");
(*p)[numpy.attr("ndarray").ptr()] = HandleNumpyArray;
(*p)[reinterpret_cast<PyObject*>(&PyArray_Type)] = HandleNumpyArray;

// Numpy scalar types. For some of them, we share the handler with
// Python types (np_int64, np_float64, np_complex128).
Expand Down Expand Up @@ -553,8 +552,7 @@ absl::StatusOr<PyArgSignature> PyArgSignatureOfValue(nb::handle arg,
numpy_array.ndim()),
/*weak_type=*/false);
};
const auto numpy = nb::module_::import_("numpy");
(*p)[numpy.attr("ndarray").ptr()] = numpy_handler;
(*p)[reinterpret_cast<PyObject*>(&PyArray_Type)] = numpy_handler;

ToPyArgSignatureHandler np_uint64_handler =
[](nb::handle h,
Expand Down

0 comments on commit 79faaf0

Please sign in to comment.