From 9783adb438e717be6f1fd409a2adab217a3e5f8d Mon Sep 17 00:00:00 2001 From: Ivy Zheng Date: Mon, 2 Dec 2024 17:37:33 -0800 Subject: [PATCH] Implement C++ flatten one level with keys and use it for the prefix/equality error printing. With this, we should be able to safely delete the python with-path registry after the new jaxlib release. PiperOrigin-RevId: 702138746 --- xla/python/pytree.cc | 60 ++++++++++++++++++++++++++--- xla/python/pytree.h | 5 +++ xla/python/xla_client.py | 2 +- xla/python/xla_extension/pytree.pyi | 3 ++ 4 files changed, 64 insertions(+), 6 deletions(-) diff --git a/xla/python/pytree.cc b/xla/python/pytree.cc index 5a165cde069201..2ece98efd747ea 100644 --- a/xla/python/pytree.cc +++ b/xla/python/pytree.cc @@ -291,6 +291,15 @@ bool PyTreeDef::operator==(const PyTreeDef& other) const { } nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { + return FlattenOneLevelImpl(x, /*with_keys=*/false); +} + +nb::object PyTreeRegistry::FlattenOneLevelWithKeys(nb::handle x) const { + return FlattenOneLevelImpl(x, /*with_keys=*/true); +} + +nb::object PyTreeRegistry::FlattenOneLevelImpl(nb::handle x, + bool with_keys) const { PyTreeRegistry::Registration const* custom; PyTreeKind kind = KindOfObject(x, &custom); switch (kind) { @@ -298,6 +307,16 @@ nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { return nb::make_tuple(nb::make_tuple(), nb::none()); case PyTreeKind::kTuple: case PyTreeKind::kList: + if (with_keys) { + auto size = PyTuple_GET_SIZE(x.ptr()); + nb::tuple key_leaves = nb::steal(PyTuple_New(size)); + for (int i = 0; i < size; ++i) { + const auto key = make_nb_class(i); + PyTuple_SET_ITEM(key_leaves.ptr(), i, + nb::make_tuple(key, x[i]).release().ptr()); + } + return nb::make_tuple(std::move(key_leaves), nb::none()); + } return nb::make_tuple(nb::borrow(x), nb::none()); case PyTreeKind::kDict: { nb::dict dict = nb::borrow(x); @@ -305,8 +324,13 @@ nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { nb::tuple keys = nb::steal(PyTuple_New(sorted_keys.size())); nb::tuple values = nb::steal(PyTuple_New(sorted_keys.size())); for (size_t i = 0; i < sorted_keys.size(); ++i) { - PyTuple_SET_ITEM(values.ptr(), i, - nb::object(dict[sorted_keys[i]]).release().ptr()); + nb::object value = nb::object(dict[sorted_keys[i]]); + if (with_keys) { + // value = nb::make_tuple( + // make_nb_class(nb::object(sorted_keys[i])), value); + value = nb::make_tuple(make_nb_class(sorted_keys[i]), value); + } + PyTuple_SET_ITEM(values.ptr(), i, value.release().ptr()); PyTuple_SET_ITEM(keys.ptr(), i, sorted_keys[i].release().ptr()); } return nb::make_tuple(std::move(values), std::move(keys)); @@ -314,12 +338,32 @@ nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { case PyTreeKind::kNamedTuple: { nb::tuple in = nb::borrow(x); nb::list out; + if (with_keys) { + // Get key names from NamedTuple fields. + nb::tuple fields; + if (!nb::try_cast(nb::getattr(in, "_fields"), fields) || + in.size() != fields.size()) { + throw std::invalid_argument( + "A namedtuple's _fields attribute should have the same size as " + "the tuple."); + } + auto field_iter = fields.begin(); + for (nb::handle entry : in) { + out.append(nb::make_tuple( + make_nb_class(nb::str(*field_iter)), entry)); + } + return nb::make_tuple(std::move(out), x.type()); + } for (size_t i = 0; i < in.size(); ++i) { out.append(in[i]); } return nb::make_tuple(std::move(out), x.type()); } case PyTreeKind::kCustom: { + if (with_keys) { + auto [leaves, aux_data] = custom->ToIterableWithKeys(x); + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } auto [leaves, aux_data] = custom->ToIterable(x); return nb::make_tuple(std::move(leaves), std::move(aux_data)); } @@ -327,9 +371,12 @@ nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { auto data_size = custom->data_fields.size(); nb::list leaves = nb::steal(PyList_New(data_size)); for (int leaf = 0; leaf < data_size; ++leaf) { - PyList_SET_ITEM( - leaves.ptr(), leaf, - nb::getattr(x, custom->data_fields[leaf]).release().ptr()); + nb::object value = nb::getattr(x, custom->data_fields[leaf]); + if (with_keys) { + value = nb::make_tuple( + make_nb_class(custom->data_fields[leaf]), value); + } + PyList_SET_ITEM(leaves.ptr(), leaf, value.release().ptr()); } auto meta_size = custom->meta_fields.size(); nb::object aux_data = nb::steal(PyTuple_New(meta_size)); @@ -1577,6 +1624,9 @@ void BuildPytreeSubmodule(nb::module_& m) { nb::arg("tree").none(), nb::arg("leaf_predicate").none() = std::nullopt); registry.def("flatten_one_level", &PyTreeRegistry::FlattenOneLevel, nb::arg("tree").none()); + registry.def("flatten_one_level_with_keys", + &PyTreeRegistry::FlattenOneLevelWithKeys, + nb::arg("tree").none()); registry.def( "flatten_with_path", [](nb_class_ptr registry, nb::object x, diff --git a/xla/python/pytree.h b/xla/python/pytree.h index 1dc8c6effc24e8..55ddf041232d58 100644 --- a/xla/python/pytree.h +++ b/xla/python/pytree.h @@ -115,6 +115,11 @@ class PyTreeRegistry { // Flattens a pytree one level, returning either a tuple of the leaves and // the node data, or None, if the entry is a leaf. nanobind::object FlattenOneLevel(nanobind::handle x) const; + // Similar to above but returns a key-leaf pair for each leaf. + nanobind::object FlattenOneLevelWithKeys(nanobind::handle x) const; + // Underlying implementation of FlattenOneLevel and FlattenOneLevelWithKeys. + nanobind::object FlattenOneLevelImpl(nanobind::handle x, + bool with_keys) const; static PyType_Slot slots_[]; diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index d15ff8201d4b37..5db45df4a7f32d 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 299 +_version = 300 # Version number for MLIR:Python components. mlir_api_version = 57 diff --git a/xla/python/xla_extension/pytree.pyi b/xla/python/xla_extension/pytree.pyi index 0aaee09a8e9cc8..81bdcbbb1aec75 100644 --- a/xla/python/xla_extension/pytree.pyi +++ b/xla/python/xla_extension/pytree.pyi @@ -32,6 +32,9 @@ class PyTreeRegistry: def flatten_one_level( self, tree: Any ) -> Optional[Tuple[Iterable[Any], Any]]: ... + def flatten_one_level_with_keys( + self, tree: Any + ) -> Optional[Tuple[Iterable[_KeyLeafPair], Any]]: ... def flatten_with_path( self, tree: Any,