etree.tree_utils.TreeAPI

class etils.etree.tree_utils.TreeAPI(backend: etils.etree.backend.Backend)

Bases: object

Tree API, implemented on top of either the jax.tree_util, tf.nest or tree (dm-tree) backend, as selected by the backend argument.

map(map_fn: Callable[..., etils.etree.tree_utils._Tout], *trees: Any | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]], is_leaf: collections.abc.Callable[[Any], bool] | None = None) -> etils.etree.tree_utils._Tout | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]]

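Same as tree.map_structure: applies map_fn to every leaf of the input tree(s) and returns a tree with the same structure holding the results.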

Parameters:
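  • map_fn – Worker function, called on each leaf with one argument per input tree

  • *trees – Nested inputs whose leaves are passed to map_fn; all trees must share the same structure

  • is_leaf – Optional predicate; a node for which is_leaf(node) is True is treated as a leaf and not recursed into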

Returns:

The nested structure after map_fn has been applied.
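The leaf-wise semantics described above can be illustrated in pure Python. This is a minimal sketch, assuming trees built only from plain list, tuple and dict containers; tree_map is a hypothetical name, not the etils implementation, and it does not check that the input trees share a structure.

```python
from typing import Any, Callable, Optional


def tree_map(
    map_fn: Callable[..., Any],
    *trees: Any,
    is_leaf: Optional[Callable[[Any], bool]] = None,
) -> Any:
    """Applies `map_fn` to corresponding leaves of same-structured trees."""
    first = trees[0]
    # A node is a leaf if `is_leaf` says so, or if it is not a list/tuple/dict.
    if (is_leaf is not None and is_leaf(first)) or not isinstance(
        first, (list, tuple, dict)
    ):
        return map_fn(*trees)
    if isinstance(first, dict):
        # Recurse per key; assumes every tree has the same keys.
        return {
            key: tree_map(map_fn, *(t[key] for t in trees), is_leaf=is_leaf)
            for key in first
        }
    # list/tuple: recurse position by position and rebuild the same container.
    return type(first)(
        tree_map(map_fn, *children, is_leaf=is_leaf) for children in zip(*trees)
    )
```

For example, tree_map(lambda x, y: x + y, (1, 2), (10, 20)) returns (11, 22), and passing is_leaf=lambda n: isinstance(n, list) makes whole lists reach map_fn instead of their elements.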

parallel_map(map_fn: Callable[..., etils.etree.tree_utils._Tout], *trees: Any | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]], num_threads: int | None = None, progress_bar: bool = False, is_leaf: collections.abc.Callable[[Any], bool] | None = None) -> etils.etree.tree_utils._Tout | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]]

Same as tree.map_structure, but applies map_fn to the leaves in parallel using a pool of worker threads.

Parameters:
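  • map_fn – Worker function, called on each leaf with one argument per input tree

  • *trees – Nested inputs whose leaves are passed to map_fn; all trees must share the same structure

  • num_threads – Number of worker threads (defaults to CPU count * 5)

  • progress_bar – If True, display a progress bar.

  • is_leaf – Optional predicate; a node for which is_leaf(node) is True is treated as a leaf and not recursed into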

Returns:

The nested structure after map_fn has been applied.
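One way to picture the parallel variant: flatten the trees, run map_fn over the leaves in a thread pool, then rebuild the structure. This is a minimal sketch, assuming plain list/tuple/dict containers; parallel_tree_map and its helpers are hypothetical names, it omits is_leaf and progress_bar, and it is not how etils implements parallel_map.

```python
import concurrent.futures
import os
from typing import Any, Callable, Optional


def _flatten(tree: Any, leaves: list) -> Any:
    """Collects leaves in order; returns a skeleton holding leaf indices."""
    if isinstance(tree, dict):
        return {k: _flatten(v, leaves) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(_flatten(v, leaves) for v in tree)
    leaves.append(tree)
    return len(leaves) - 1


def _unflatten(skeleton: Any, values: list) -> Any:
    """Rebuilds the tree, replacing each leaf index with `values[index]`."""
    if isinstance(skeleton, dict):
        return {k: _unflatten(v, values) for k, v in skeleton.items()}
    if isinstance(skeleton, (list, tuple)):
        return type(skeleton)(_unflatten(v, values) for v in skeleton)
    return values[skeleton]


def parallel_tree_map(
    map_fn: Callable[..., Any],
    *trees: Any,
    num_threads: Optional[int] = None,
) -> Any:
    """Applies `map_fn` to corresponding leaves of `trees` in a thread pool."""
    leaves_per_tree = []
    for tree in trees:
        leaves: list = []
        skeleton = _flatten(tree, leaves)  # Identical for same-structured trees.
        leaves_per_tree.append(leaves)
    if num_threads is None:
        # Mirrors the documented default (CPU count * 5).
        num_threads = (os.cpu_count() or 1) * 5
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as pool:
        # Executor.map yields results in input order, matching leaf indices.
        results = list(pool.map(map_fn, *leaves_per_tree))
    return _unflatten(skeleton, results)
```

Threads suit map_fn calls that mostly wait on I/O (file reads, network requests); CPU-bound pure-Python work would gain little because of the GIL.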

unzip(tree: Iterable[Any] | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]]) -> Iterator[Any | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]]]

Unpack a tree of iterables into an iterator of trees.

This is the reverse operation of tree.map_structure(zip, *trees).

Example:

list(etree.unzip({'a': np.array([1, 2, 3])})) == [{'a': 1}, {'a': 2}, {'a': 3}]

Parameters:

tree – The tree of iterables to unzip

Yields:

Trees with the same structure as the input, each holding one element from every leaf iterable.
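The unzip behavior can be sketched by flattening the tree and walking all leaf iterables in lockstep. This is a minimal sketch, assuming plain list/tuple/dict containers; because lists and tuples count as tree nodes here, leaves must be other iterables (for example range, str or NumPy arrays). tree_unzip and its helpers are hypothetical names, not the etils implementation.

```python
from typing import Any, Iterator


def _flatten(tree: Any, leaves: list) -> Any:
    """Collects leaves in order; returns a skeleton holding leaf indices."""
    if isinstance(tree, dict):
        return {k: _flatten(v, leaves) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(_flatten(v, leaves) for v in tree)
    leaves.append(tree)
    return len(leaves) - 1


def _unflatten(skeleton: Any, values: list) -> Any:
    """Rebuilds the tree, replacing each leaf index with `values[index]`."""
    if isinstance(skeleton, dict):
        return {k: _unflatten(v, values) for k, v in skeleton.items()}
    if isinstance(skeleton, (list, tuple)):
        return type(skeleton)(_unflatten(v, values) for v in skeleton)
    return values[skeleton]


def tree_unzip(tree: Any) -> Iterator[Any]:
    """Yields one tree per position of the leaf iterables."""
    leaves: list = []
    skeleton = _flatten(tree, leaves)
    # Walk all leaf iterables in lockstep; zip stops at the shortest one.
    for elements in zip(*leaves):
        yield _unflatten(skeleton, list(elements))
```

For example, list(tree_unzip({'a': range(1, 4), 'b': 'xyz'})) gives [{'a': 1, 'b': 'x'}, {'a': 2, 'b': 'y'}, {'a': 3, 'b': 'z'}]. Being a generator, it produces each tree lazily.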

stack(trees: Iterable[etils.enp.array_types.typing.Array | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]]]) -> etils.enp.array_types.typing.Array | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]]

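Stack an iterable of trees of arrays into a single tree, stacking corresponding leaves along a new leading axis.

Supports JAX, TensorFlow and NumPy arrays.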

Example:

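import numpy as np
from etils import etree

out = etree.stack([
    {'a': np.array([1])},
    {'a': np.array([2])},
    {'a': np.array([3])},
])
# Leaves are stacked along a new leading axis: shape (3, 1).
np.testing.assert_array_equal(out['a'], np.array([[1], [2], [3]]))

Parameters: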

trees – The trees to stack; all must share the same structure

Returns:

A single tree with the same structure, whose leaves are the stacked arrays.
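Stacking amounts to transposing a list of same-structured trees into one tree of grouped leaves. This is a minimal sketch, assuming plain list/tuple/dict containers; tree_stack, its helpers and the stack_fn parameter are hypothetical, not part of the etils API. To keep it dependency-free it groups leaves with list by default; with NumPy one would pass stack_fn=np.stack.

```python
from typing import Any, Callable, Iterable


def _flatten(tree: Any, leaves: list) -> Any:
    """Collects leaves in order; returns a skeleton holding leaf indices."""
    if isinstance(tree, dict):
        return {k: _flatten(v, leaves) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(_flatten(v, leaves) for v in tree)
    leaves.append(tree)
    return len(leaves) - 1


def _unflatten(skeleton: Any, values: list) -> Any:
    """Rebuilds the tree, replacing each leaf index with `values[index]`."""
    if isinstance(skeleton, dict):
        return {k: _unflatten(v, values) for k, v in skeleton.items()}
    if isinstance(skeleton, (list, tuple)):
        return type(skeleton)(_unflatten(v, values) for v in skeleton)
    return values[skeleton]


def tree_stack(trees: Iterable[Any], stack_fn: Callable[[list], Any] = list) -> Any:
    """Combines corresponding leaves of same-structured trees with `stack_fn`."""
    leaves_per_tree = []
    for tree in trees:
        leaves: list = []
        skeleton = _flatten(tree, leaves)  # Identical for same-structured trees.
        leaves_per_tree.append(leaves)
    # Transpose: each group holds the leaf at one position from every tree.
    stacked = [stack_fn(list(group)) for group in zip(*leaves_per_tree)]
    return _unflatten(skeleton, stacked)
```

For example, tree_stack([{'a': 1, 'b': (2,)}, {'a': 3, 'b': (4,)}]) returns {'a': [1, 3], 'b': ([2, 4],)}. Making the combining function a parameter keeps the tree traversal independent of the array library, which is roughly the role the backend plays in TreeAPI.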

spec_like(tree: etils.enp.array_types.typing.Array | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]], *, ignore_other: bool = True) -> etils.enp.array_spec.ArraySpec | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]]

Inspect a tree of arrays by replacing each array with its enp.ArraySpec (shape and dtype). Works with any array type (NumPy, JAX, TensorFlow).

Example:

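import jax
from etils import etree

# `MyModel` (e.g. a Flax `nn.Module`) and its input `x` are placeholders.
model = MyModel()
variables = model.init(jax.random.PRNGKey(0), x)

# Inspect the `variables` tree structure (shapes and dtypes)
print(etree.spec_like(variables))

Parameters: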
  • tree – The tree of arrays

  • ignore_other – If True, non-array leaves are forwarded as-is.

Returns:

A tree with the same structure, in which each array is replaced by its enp.ArraySpec.
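The idea behind spec_like can be sketched by duck-typing: any leaf exposing .shape and .dtype is treated as an array and summarized. This is a rough sketch under stated assumptions: Spec is a hypothetical stand-in for enp.ArraySpec, tree_spec_like is a hypothetical name, and raising TypeError when ignore_other=False is an assumption rather than documented etils behavior.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class Spec:
    """Hypothetical stand-in for `enp.ArraySpec`: only shape and dtype."""

    shape: tuple
    dtype: Any


def tree_spec_like(tree: Any, *, ignore_other: bool = True) -> Any:
    """Replaces every array-like leaf (has `.shape` and `.dtype`) with a Spec."""
    if isinstance(tree, dict):
        return {k: tree_spec_like(v, ignore_other=ignore_other) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_spec_like(v, ignore_other=ignore_other) for v in tree)
    if hasattr(tree, 'shape') and hasattr(tree, 'dtype'):
        return Spec(shape=tuple(tree.shape), dtype=tree.dtype)
    if ignore_other:
        return tree  # Non-array leaves are forwarded unchanged.
    # Assumed behavior for ignore_other=False (not taken from etils).
    raise TypeError(f'Unsupported leaf: {tree!r}')
```

With NumPy installed, tree_spec_like({'w': np.zeros((2, 3))}) would replace the array with Spec(shape=(2, 3), dtype=dtype('float64')), which is usually far more readable than printing the raw values of a large parameter tree.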