Skip to content

Commit

Permalink
Implement C++ flatten one level with keys and use it for the prefix/e…
Browse files Browse the repository at this point in the history
…quality error printing.

With this, we should be able to safely delete the python with-path registry after the new jaxlib release.

PiperOrigin-RevId: 702138746
  • Loading branch information
IvyZX authored and Google-ML-Automation committed Dec 3, 2024
1 parent a7c38b7 commit 9783adb
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 6 deletions.
60 changes: 55 additions & 5 deletions xla/python/pytree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,45 +291,92 @@ bool PyTreeDef::operator==(const PyTreeDef& other) const {
}

nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const {
return FlattenOneLevelImpl(x, /*with_keys=*/false);
}

nb::object PyTreeRegistry::FlattenOneLevelWithKeys(nb::handle x) const {
return FlattenOneLevelImpl(x, /*with_keys=*/true);
}

nb::object PyTreeRegistry::FlattenOneLevelImpl(nb::handle x,
bool with_keys) const {
PyTreeRegistry::Registration const* custom;
PyTreeKind kind = KindOfObject(x, &custom);
switch (kind) {
case PyTreeKind::kNone:
return nb::make_tuple(nb::make_tuple(), nb::none());
case PyTreeKind::kTuple:
case PyTreeKind::kList:
if (with_keys) {
auto size = PyTuple_GET_SIZE(x.ptr());
nb::tuple key_leaves = nb::steal<nb::tuple>(PyTuple_New(size));
for (int i = 0; i < size; ++i) {
const auto key = make_nb_class<SequenceKey>(i);
PyTuple_SET_ITEM(key_leaves.ptr(), i,
nb::make_tuple(key, x[i]).release().ptr());
}
return nb::make_tuple(std::move(key_leaves), nb::none());
}
return nb::make_tuple(nb::borrow(x), nb::none());
case PyTreeKind::kDict: {
nb::dict dict = nb::borrow<nb::dict>(x);
std::vector<nb::object> sorted_keys = GetSortedPyDictKeys(dict.ptr());
nb::tuple keys = nb::steal<nb::tuple>(PyTuple_New(sorted_keys.size()));
nb::tuple values = nb::steal<nb::tuple>(PyTuple_New(sorted_keys.size()));
for (size_t i = 0; i < sorted_keys.size(); ++i) {
PyTuple_SET_ITEM(values.ptr(), i,
nb::object(dict[sorted_keys[i]]).release().ptr());
nb::object value = nb::object(dict[sorted_keys[i]]);
if (with_keys) {
// value = nb::make_tuple(
// make_nb_class<DictKey>(nb::object(sorted_keys[i])), value);
value = nb::make_tuple(make_nb_class<DictKey>(sorted_keys[i]), value);
}
PyTuple_SET_ITEM(values.ptr(), i, value.release().ptr());
PyTuple_SET_ITEM(keys.ptr(), i, sorted_keys[i].release().ptr());
}
return nb::make_tuple(std::move(values), std::move(keys));
}
case PyTreeKind::kNamedTuple: {
nb::tuple in = nb::borrow<nb::tuple>(x);
nb::list out;
if (with_keys) {
// Get key names from NamedTuple fields.
nb::tuple fields;
if (!nb::try_cast<nb::tuple>(nb::getattr(in, "_fields"), fields) ||
in.size() != fields.size()) {
throw std::invalid_argument(
"A namedtuple's _fields attribute should have the same size as "
"the tuple.");
}
auto field_iter = fields.begin();
for (nb::handle entry : in) {
out.append(nb::make_tuple(
make_nb_class<GetAttrKey>(nb::str(*field_iter)), entry));
}
return nb::make_tuple(std::move(out), x.type());
}
for (size_t i = 0; i < in.size(); ++i) {
out.append(in[i]);
}
return nb::make_tuple(std::move(out), x.type());
}
case PyTreeKind::kCustom: {
if (with_keys) {
auto [leaves, aux_data] = custom->ToIterableWithKeys(x);
return nb::make_tuple(std::move(leaves), std::move(aux_data));
}
auto [leaves, aux_data] = custom->ToIterable(x);
return nb::make_tuple(std::move(leaves), std::move(aux_data));
}
case PyTreeKind::kDataclass: {
auto data_size = custom->data_fields.size();
nb::list leaves = nb::steal<nb::list>(PyList_New(data_size));
for (int leaf = 0; leaf < data_size; ++leaf) {
PyList_SET_ITEM(
leaves.ptr(), leaf,
nb::getattr(x, custom->data_fields[leaf]).release().ptr());
nb::object value = nb::getattr(x, custom->data_fields[leaf]);
if (with_keys) {
value = nb::make_tuple(
make_nb_class<GetAttrKey>(custom->data_fields[leaf]), value);
}
PyList_SET_ITEM(leaves.ptr(), leaf, value.release().ptr());
}
auto meta_size = custom->meta_fields.size();
nb::object aux_data = nb::steal(PyTuple_New(meta_size));
Expand Down Expand Up @@ -1577,6 +1624,9 @@ void BuildPytreeSubmodule(nb::module_& m) {
nb::arg("tree").none(), nb::arg("leaf_predicate").none() = std::nullopt);
registry.def("flatten_one_level", &PyTreeRegistry::FlattenOneLevel,
nb::arg("tree").none());
registry.def("flatten_one_level_with_keys",
&PyTreeRegistry::FlattenOneLevelWithKeys,
nb::arg("tree").none());
registry.def(
"flatten_with_path",
[](nb_class_ptr<PyTreeRegistry> registry, nb::object x,
Expand Down
5 changes: 5 additions & 0 deletions xla/python/pytree.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ class PyTreeRegistry {
// Flattens a pytree one level, returning either a tuple of the leaves and
// the node data, or None, if the entry is a leaf.
nanobind::object FlattenOneLevel(nanobind::handle x) const;
// Similar to above but returns a key-leaf pair for each leaf.
nanobind::object FlattenOneLevelWithKeys(nanobind::handle x) const;
// Underlying implementation of FlattenOneLevel and FlattenOneLevelWithKeys.
nanobind::object FlattenOneLevelImpl(nanobind::handle x,
bool with_keys) const;

static PyType_Slot slots_[];

Expand Down
2 changes: 1 addition & 1 deletion xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 299
_version = 300

# Version number for MLIR:Python components.
mlir_api_version = 57
Expand Down
3 changes: 3 additions & 0 deletions xla/python/xla_extension/pytree.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class PyTreeRegistry:
def flatten_one_level(
self, tree: Any
) -> Optional[Tuple[Iterable[Any], Any]]: ...
def flatten_one_level_with_keys(
self, tree: Any
) -> Optional[Tuple[Iterable[_KeyLeafPair], Any]]: ...
def flatten_with_path(
self,
tree: Any,
Expand Down

0 comments on commit 9783adb

Please sign in to comment.