diff --git a/xla/python/pytree.cc b/xla/python/pytree.cc index 5a165cde06920..138316c722d56 100644 --- a/xla/python/pytree.cc +++ b/xla/python/pytree.cc @@ -40,6 +40,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep @@ -100,7 +101,7 @@ void PyTreeRegistry::Register( if (!it.second) { throw std::invalid_argument( absl::StrFormat("Duplicate custom PyTreeDef type registration for %s.", - nb::cast(nb::repr(type)))); + nb::cast(nb::repr(type)))); } } @@ -116,7 +117,7 @@ void PyTreeRegistry::RegisterDataclass(nb::object type, if (!it.second) { throw std::invalid_argument(absl::StrFormat( "Duplicate custom dataclass PyTreeDef type registration for %s.", - nb::cast(nb::repr(std::move(type))))); + nb::cast(nb::repr(std::move(type))))); } } @@ -129,7 +130,7 @@ PyTreeRegistry::Registration::ToIterable(nanobind::handle o) const { throw std::invalid_argument(absl::StrCat( "The to_iterable function for a custom PyTree node should return " "a (children, aux_data) tuple, got ", - nb::cast(nb::repr(out)))); + nb::cast(nb::repr(out)))); } nb::iterable leaves; if (!nb::try_cast(leaves_and_aux_data[0], leaves)) { @@ -137,7 +138,7 @@ PyTreeRegistry::Registration::ToIterable(nanobind::handle o) const { "The to_iterable function for a custom PyTree node should return " "a (children, aux_data) tuple where 'children' is iterable, " "got ", - nb::cast(nb::repr(out)))); + nb::cast(nb::repr(out)))); } return std::make_pair(std::move(leaves), nb::object(leaves_and_aux_data[1])); } @@ -161,7 +162,7 @@ PyTreeRegistry::Registration::ToIterableWithKeys(nb::handle o) const { throw std::invalid_argument(absl::StrCat( "The to_iterable_with_keys function for a custom PyTree " "node should return a (key_leaf_pairs, aux_data) tuple, got ", - nb::cast(nb::repr(out)))); + nb::cast(nb::repr(out)))); } nb::iterable key_leaf_pairs; if (!nb::try_cast(leaves_and_aux_data[0], key_leaf_pairs)) { @@ -169,7 +170,7 @@ PyTreeRegistry::Registration::ToIterableWithKeys(nb::handle o) const { "The to_iterable_with_keys function for a custom PyTree node should " "return a (key_leaf_pairs, aux_data) tuple where 'key_leaf_pairs' is " "iterable, got ", - nb::cast(nb::repr(leaves_and_aux_data)))); + nb::cast(nb::repr(leaves_and_aux_data)))); } for (nb::handle key_leaf_pair : key_leaf_pairs) { nb::tuple key_leaf_pair_tuple; @@ -178,7 +179,7 @@ PyTreeRegistry::Registration::ToIterableWithKeys(nb::handle o) const { throw std::invalid_argument(absl::StrCat( "The to_iterable_with_keys function for a custom PyTree node should " "return a (key_leaf_pairs, aux_data) tuple where 'child", - nb::cast(nb::repr(key_leaf_pair)))); + nb::cast(nb::repr(key_leaf_pair)))); } result.push_back(std::make_pair(nb::borrow(key_leaf_pair_tuple[0]), nb::borrow(key_leaf_pair_tuple[1]))); @@ -291,22 +292,62 @@ bool PyTreeDef::operator==(const PyTreeDef& other) const { } nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { + return FlattenOneLevelImpl(x, /*with_keys=*/false); +} + +nb::object PyTreeRegistry::FlattenOneLevelWithKeys(nb::handle x) const { + return FlattenOneLevelImpl(x, /*with_keys=*/true); +} + +nb::object PyTreeRegistry::FlattenOneLevelImpl(nb::handle x, + bool with_keys) const { PyTreeRegistry::Registration const* custom; PyTreeKind kind = KindOfObject(x, &custom); switch (kind) { case PyTreeKind::kNone: return nb::make_tuple(nb::make_tuple(), nb::none()); - case PyTreeKind::kTuple: - case PyTreeKind::kList: + case PyTreeKind::kTuple: { + if (with_keys) { + auto size = PyTuple_GET_SIZE(x.ptr()); + nb::object key_leaves = nb::steal(PyTuple_New(size)); + for (int i = 0; i < size; ++i) { + nb::object key = make_nb_class(i); + nb::object value = + nb::borrow(PyTuple_GET_ITEM(x.ptr(), i)); + PyTuple_SET_ITEM(key_leaves.ptr(), i, + nb::make_tuple(key, value).release().ptr()); + } + return nb::make_tuple(std::move(key_leaves), nb::none()); + } + return nb::make_tuple(nb::borrow(x), nb::none()); + } + case PyTreeKind::kList: { + if (with_keys) { + auto size = PyList_GET_SIZE(x.ptr()); + nb::object key_leaves = nb::steal(PyTuple_New(size)); + for (int i = 0; i < size; ++i) { + nb::object key = make_nb_class(i); + nb::object value = + nb::borrow(PyList_GET_ITEM(x.ptr(), i)); + PyTuple_SET_ITEM(key_leaves.ptr(), i, + nb::make_tuple(key, value).release().ptr()); + } + return nb::make_tuple(std::move(key_leaves), nb::none()); + } return nb::make_tuple(nb::borrow(x), nb::none()); + } case PyTreeKind::kDict: { nb::dict dict = nb::borrow(x); std::vector sorted_keys = GetSortedPyDictKeys(dict.ptr()); nb::tuple keys = nb::steal(PyTuple_New(sorted_keys.size())); nb::tuple values = nb::steal(PyTuple_New(sorted_keys.size())); for (size_t i = 0; i < sorted_keys.size(); ++i) { - PyTuple_SET_ITEM(values.ptr(), i, - nb::object(dict[sorted_keys[i]]).release().ptr()); + nb::object& key = sorted_keys[i]; + nb::object value = nb::object(dict[key]); + if (with_keys) { + value = nb::make_tuple(make_nb_class(key), value); + } + PyTuple_SET_ITEM(values.ptr(), i, value.release().ptr()); PyTuple_SET_ITEM(keys.ptr(), i, sorted_keys[i].release().ptr()); } return nb::make_tuple(std::move(values), std::move(keys)); @@ -314,12 +355,32 @@ nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { case PyTreeKind::kNamedTuple: { nb::tuple in = nb::borrow(x); nb::list out; + if (with_keys) { + // Get key names from NamedTuple fields. + nb::tuple fields; + if (!nb::try_cast(nb::getattr(in, "_fields"), fields) || + in.size() != fields.size()) { + throw std::invalid_argument( + "A namedtuple's _fields attribute should have the same size as " + "the tuple."); + } + auto field_iter = fields.begin(); + for (nb::handle entry : in) { + out.append(nb::make_tuple( + make_nb_class(nb::str(*field_iter)), entry)); + } + return nb::make_tuple(std::move(out), x.type()); + } for (size_t i = 0; i < in.size(); ++i) { out.append(in[i]); } return nb::make_tuple(std::move(out), x.type()); } case PyTreeKind::kCustom: { + if (with_keys) { + auto [leaves, aux_data] = custom->ToIterableWithKeys(x); + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } auto [leaves, aux_data] = custom->ToIterable(x); return nb::make_tuple(std::move(leaves), std::move(aux_data)); } @@ -327,9 +388,12 @@ nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { auto data_size = custom->data_fields.size(); nb::list leaves = nb::steal(PyList_New(data_size)); for (int leaf = 0; leaf < data_size; ++leaf) { - PyList_SET_ITEM( - leaves.ptr(), leaf, - nb::getattr(x, custom->data_fields[leaf]).release().ptr()); + nb::object value = nb::getattr(x, custom->data_fields[leaf]); + if (with_keys) { + value = nb::make_tuple( + make_nb_class(custom->data_fields[leaf]), value); + } + PyList_SET_ITEM(leaves.ptr(), leaf, value.release().ptr()); } auto meta_size = custom->meta_fields.size(); nb::object aux_data = nb::steal(PyTuple_New(meta_size)); @@ -401,21 +465,21 @@ std::string SequenceKey::ToReprString() const { } std::string DictKey::ToString() const { - return absl::StrFormat("[%s]", nb::cast(nb::repr(key_))); + return absl::StrFormat("[%s]", nb::cast(nb::repr(key_))); } std::string DictKey::ToReprString() const { return absl::StrFormat("DictKey(key=%s)", - nb::cast(nb::repr(key_))); + nb::cast(nb::repr(key_))); } std::string GetAttrKey::ToString() const { - return absl::StrFormat(".%s", nb::cast(name_)); + return absl::StrFormat(".%s", nb::cast(name_)); } std::string GetAttrKey::ToReprString() const { return absl::StrFormat("GetAttrKey(name='%s')", - nb::cast(name_)); + nb::cast(name_)); } std::string FlattenedIndexKey::ToString() const { @@ -483,7 +547,7 @@ void PyTreeDef::FlattenImpl(nb::handle handle, T& leaves, } else if (!nb::try_cast(o, is_known_leaf)) { throw std::invalid_argument(absl::StrCat( "is_leaf predicate returned a non-boolean value ", - nb::cast(nb::repr(o)), "; expected a boolean")); + nb::cast(nb::repr(o)), "; expected a boolean")); } } if (is_known_leaf) { @@ -836,7 +900,7 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { if (it == traversal_.rend()) { throw std::invalid_argument(absl::StrFormat( "Tree structures did not match: %s vs %s", - nb::cast(nb::repr(xs)), ToString())); + nb::cast(nb::repr(xs)), ToString())); } const Node& node = *it; nb::object object = agenda.back(); @@ -861,7 +925,7 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { "the previous behavior, you can usually write:\n" " jax.tree.map(lambda x, y: None if x is None else f(x, y), a, " "b, is_leaf=lambda x: x is None)", - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(object)))); } break; @@ -869,13 +933,13 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { if (!PyTuple_CheckExact(object.ptr())) { throw std::invalid_argument( absl::StrFormat("Expected tuple, got %s.", - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(object)))); } nb::tuple tuple = nb::borrow(object); if (tuple.size() != node.arity) { throw std::invalid_argument(absl::StrFormat( "Tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), - node.arity, nb::cast(nb::repr(object)))); + node.arity, nb::cast(nb::repr(object)))); } for (nb::handle entry : tuple) { agenda.push_back(nb::borrow(entry)); @@ -887,13 +951,13 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { if (!PyList_CheckExact(object.ptr())) { throw std::invalid_argument( absl::StrFormat("Expected list, got %s.", - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(object)))); } nb::list list = nb::borrow(object); if (list.size() != node.arity) { throw std::invalid_argument(absl::StrFormat( "List arity mismatch: %d != %d; list: %s.", list.size(), - node.arity, nb::cast(nb::repr(object)))); + node.arity, nb::cast(nb::repr(object)))); } for (nb::handle entry : list) { agenda.push_back(nb::borrow(entry)); @@ -905,7 +969,7 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { if (!PyDict_CheckExact(object.ptr())) { throw std::invalid_argument( absl::StrFormat("Expected dict, got %s.", - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(object)))); } nb::dict dict = nb::borrow(object); std::vector keys = GetSortedPyDictKeys(dict.ptr()); @@ -914,9 +978,9 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { // vector. This is error path so it is fine to pay conversion cost. throw std::invalid_argument( absl::StrFormat("Dict key mismatch; expected keys: %s; dict: %s.", - nb::cast( + nb::cast( nb::repr(nb::cast(node.sorted_dict_keys))), - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(object)))); } for (nb::handle key : keys) { agenda.push_back(dict[key]); @@ -929,19 +993,19 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { !nb::hasattr(object, "_fields")) { throw std::invalid_argument( absl::StrFormat("Expected named tuple, got %s.", - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(object)))); } nb::tuple tuple = nb::borrow(object); if (tuple.size() != node.arity) { throw std::invalid_argument(absl::StrFormat( "Named tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), - node.arity, nb::cast(nb::repr(object)))); + node.arity, nb::cast(nb::repr(object)))); } if (tuple.type().not_equal(node.node_data)) { throw std::invalid_argument(absl::StrFormat( "Named tuple type mismatch: expected type: %s, tuple: %s.", - nb::cast(nb::repr(node.node_data)), - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(object)))); } for (nb::handle entry : tuple) { agenda.push_back(nb::borrow(entry)); @@ -954,16 +1018,16 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { if (registration != node.custom) { throw std::invalid_argument(absl::StrFormat( "Custom node type mismatch: expected type: %s, value: %s.", - nb::cast(nb::repr(node.custom->type)), - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(node.custom->type)), + nb::cast(nb::repr(object)))); } auto [leaves, aux_data] = node.custom->ToIterable(object); if (node.node_data.not_equal(aux_data)) { throw std::invalid_argument(absl::StrFormat( "Mismatch custom node data: %s != %s; value: %s.", - nb::cast(nb::repr(node.node_data)), - nb::cast(nb::repr(aux_data)), - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(aux_data)), + nb::cast(nb::repr(object)))); } int arity = 0; for (nb::handle entry : leaves) { @@ -973,7 +1037,7 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { if (arity != node.arity) { throw std::invalid_argument(absl::StrFormat( "Custom type arity mismatch: %d != %d; value: %s.", arity, - node.arity, nb::cast(nb::repr(object)))); + node.arity, nb::cast(nb::repr(object)))); } break; } @@ -984,8 +1048,8 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { throw std::invalid_argument(absl::StrFormat( "Custom dataclasss node type mismatch: expected type: %s, value: " "%s.", - nb::cast(nb::repr(node.custom->type)), - nb::cast(nb::repr(std::move(object))))); + nb::cast(nb::repr(node.custom->type)), + nb::cast(nb::repr(std::move(object))))); } auto meta_size = node.custom->meta_fields.size(); nb::object aux_data = nb::steal(PyTuple_New(meta_size)); @@ -999,15 +1063,15 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { if (node.node_data.not_equal(aux_data)) { throw std::invalid_argument(absl::StrFormat( "Mismatch custom dataclass node data: %s != %s; value: %s.", - nb::cast(nb::repr(node.node_data)), - nb::cast(nb::repr(aux_data)), - nb::cast(nb::repr(object)))); + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(aux_data)), + nb::cast(nb::repr(object)))); } auto data_size = node.custom->data_fields.size(); if (data_size != node.arity) { throw std::invalid_argument(absl::StrFormat( "Custom type arity mismatch: %d != %d; value: %s.", data_size, - node.arity, nb::cast(nb::repr(object)))); + node.arity, nb::cast(nb::repr(object)))); } for (int leaf = 0; leaf < data_size; ++leaf) { agenda.push_back(nb::borrow( @@ -1020,7 +1084,7 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { if (it != traversal_.rend() || leaf != -1) { throw std::invalid_argument( absl::StrFormat("Tree structures did not match: %s vs %s", - nb::cast(nb::repr(xs)), ToString())); + nb::cast(nb::repr(xs)), ToString())); } return leaves; } @@ -1213,7 +1277,7 @@ std::string PyTreeDef::ToString() const { auto child_iter = agenda.end() - node.arity; for (const nb::handle& key : node.sorted_dict_keys) { absl::StrAppendFormat(&representation, "%s%s: %s", separator, - nb::cast(nb::repr(key)), + nb::cast(nb::repr(key)), *child_iter); child_iter++; separator = ", "; @@ -1232,7 +1296,7 @@ std::string PyTreeDef::ToString() const { if (node.node_data) { // Node data for named tuples is the type. data = absl::StrFormat( - "[%s]", nb::cast( + "[%s]", nb::cast( nb::str(nb::getattr(node.node_data, "__name__")))); } } else { @@ -1240,7 +1304,7 @@ std::string PyTreeDef::ToString() const { nb::str(nb::getattr(node.custom->type, "__name__"))); if (node.node_data) { data = absl::StrFormat( - "[%s]", nb::cast(nb::str(node.node_data))); + "[%s]", nb::cast(nb::str(node.node_data))); } } @@ -1309,7 +1373,7 @@ void PyTreeDef::FromPickle(nb::object pickle) { if (node.custom == nullptr) { throw xla::XlaRuntimeError( absl::StrCat("Unknown custom type in pickled PyTreeDef: ", - nb::cast(nb::repr(t[3])))); + nb::cast(nb::repr(t[3])))); } } else { if (!t[3].is_none()) { @@ -1503,7 +1567,7 @@ nb_class_ptr PyTreeDef::MakeFromNodeDataAndChildren( if (registration == nullptr) { throw std::logic_error(absl::StrFormat( "Could not find type: %s.", - nb::cast(nb::repr(node_data->first)))); + nb::cast(nb::repr(node_data->first)))); } node.kind = registration->kind; if (node.kind == PyTreeKind::kCustom || node.kind == PyTreeKind::kDataclass) { @@ -1577,6 +1641,9 @@ void BuildPytreeSubmodule(nb::module_& m) { nb::arg("tree").none(), nb::arg("leaf_predicate").none() = std::nullopt); registry.def("flatten_one_level", &PyTreeRegistry::FlattenOneLevel, nb::arg("tree").none()); + registry.def("flatten_one_level_with_keys", + &PyTreeRegistry::FlattenOneLevelWithKeys, + nb::arg("tree").none()); registry.def( "flatten_with_path", [](nb_class_ptr registry, nb::object x, @@ -1637,7 +1704,7 @@ void BuildPytreeSubmodule(nb::module_& m) { "deserialize_using_proto", [](nb_class_ptr registry, nb::bytes data) { jax::PyTreeDefProto input; - std::string_view serialized(data.c_str(), data.size()); + absl::string_view serialized(data.c_str(), data.size()); if (serialized.size() > std::numeric_limits::max()) { throw xla::XlaRuntimeError( "Pytree serialization too large to deserialize."); diff --git a/xla/python/pytree.h b/xla/python/pytree.h index 1dc8c6effc24e..55ddf041232d5 100644 --- a/xla/python/pytree.h +++ b/xla/python/pytree.h @@ -115,6 +115,11 @@ class PyTreeRegistry { // Flattens a pytree one level, returning either a tuple of the leaves and // the node data, or None, if the entry is a leaf. nanobind::object FlattenOneLevel(nanobind::handle x) const; + // Similar to above but returns a key-leaf pair for each leaf. + nanobind::object FlattenOneLevelWithKeys(nanobind::handle x) const; + // Underlying implementation of FlattenOneLevel and FlattenOneLevelWithKeys. + nanobind::object FlattenOneLevelImpl(nanobind::handle x, + bool with_keys) const; static PyType_Slot slots_[]; diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index aadc1c2f6c71c..1f04470846690 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 300 +_version = 301 # Version number for MLIR:Python components. mlir_api_version = 57 diff --git a/xla/python/xla_extension/pytree.pyi b/xla/python/xla_extension/pytree.pyi index a777e364e6503..a90bb59ad876f 100644 --- a/xla/python/xla_extension/pytree.pyi +++ b/xla/python/xla_extension/pytree.pyi @@ -48,6 +48,9 @@ class PyTreeRegistry: def flatten_one_level( self, tree: Any ) -> Optional[Tuple[Iterable[Any], Any]]: ... + def flatten_one_level_with_keys( + self, tree: Any + ) -> Optional[Tuple[Iterable[_KeyLeafPair], Any]]: ... def flatten_with_path( self, tree: Any,