archimedes.tree

Utilities for working with hierarchical "pytree" data structures.

Functions

all(tree[, is_leaf])

Check if all leaves in the pytree evaluate to True.

flatten(x[, is_leaf])

Flatten a pytree into a list of leaves and a treedef.

leaves(tree[, is_leaf])

Extract all leaf values from a pytree.

map(f, tree, *rest[, is_leaf])

Apply a function to each leaf in a pytree.

ravel(pytree)

Flatten a pytree to a single 1D array.

reduce(function, tree, initializer[, is_leaf])

Reduce a pytree to a single value using a function and initializer.

register_dataclass(nodetype[, data_fields, ...])

Register a dataclass as a pytree node with customized field handling.

register_pytree_node(ty, to_iter, from_iter)

Register a custom type as a pytree node.

structure(tree[, is_leaf])

Extract the structure of a pytree without the leaf values.

unflatten(treedef, xs)

Reconstruct a pytree from a list of leaves and a treedef.
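
The relationship between flattening, unflattening, mapping, and reducing can be illustrated with a small pure-Python sketch. This is not the implementation in archimedes.tree and does not call it: the helper names (sketch_flatten, sketch_unflatten, sketch_map), the tuple-based treedef, and the sorted-key ordering for dictionaries are all assumptions made for illustration, and the sketch only handles dicts, lists, and tuples.

```python
from functools import reduce


def sketch_flatten(tree):
    """Hypothetical helper: return (leaves, treedef) for nested containers."""
    if isinstance(tree, dict):
        keys = sorted(tree)  # assumed ordering: sorted keys
        parts = [sketch_flatten(tree[k]) for k in keys]
        leaves = [leaf for sub, _ in parts for leaf in sub]
        return leaves, ("dict", keys, [td for _, td in parts])
    if isinstance(tree, (list, tuple)):
        parts = [sketch_flatten(item) for item in tree]
        leaves = [leaf for sub, _ in parts for leaf in sub]
        return leaves, (type(tree).__name__, None, [td for _, td in parts])
    # Anything else is treated as a leaf.
    return [tree], ("leaf", None, [])


def sketch_unflatten(treedef, leaves):
    """Hypothetical helper: rebuild a tree from a treedef and its leaves."""
    it = iter(leaves)

    def build(td):
        kind, keys, children = td
        if kind == "leaf":
            return next(it)  # consume leaves in flattening order
        items = [build(child) for child in children]
        if kind == "dict":
            return dict(zip(keys, items))
        return tuple(items) if kind == "tuple" else items

    return build(treedef)


def sketch_map(f, tree):
    """Hypothetical helper: apply f to every leaf, keeping the structure."""
    leaves, treedef = sketch_flatten(tree)
    return sketch_unflatten(treedef, [f(x) for x in leaves])


state = {"pos": [1.0, 2.0], "vel": (3.0, 4.0), "mass": 5.0}

leaves, treedef = sketch_flatten(state)
restored = sketch_unflatten(treedef, leaves)    # round trip
doubled = sketch_map(lambda x: 2 * x, state)    # leaf-wise map
total = reduce(lambda acc, x: acc + x, leaves, 0.0)  # reduce over leaves
```

In this sketch, the treedef carries only container types and keys, so the same treedef can rebuild a tree from any list of leaves of the right length. That separation of structure from values is the idea behind the flatten/unflatten and structure functions listed above.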