etree.tree_utils.TreeAPI
- class etils.etree.tree_utils.TreeAPI(backend: etils.etree.backend.Backend)
Bases: object
Tree API, using either the jax.tree_util, tf.nest or tree backend.
- map(map_fn: Callable[[...], etils.etree.tree_utils._Tout], *trees: Any | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]], is_leaf: collections.abc.Callable[[Any], bool] | None = None) -> etils.etree.tree_utils._Tout | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]]
Same as tree.map_structure.
- Parameters:
map_fn – Worker function
*trees – Nested input to pass to the map_fn
is_leaf – If is_leaf(node) is True, treat node as a leaf and do not recurse into it
- Returns:
The nested structure after map_fn has been applied.
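To make the map semantics concrete, here is a minimal pure-Python sketch of a structure-preserving map with an is_leaf predicate. It is not the etils implementation: the helper name map_structure and its restriction to dict, list and tuple containers are assumptions made for illustration only.

```python
def map_structure(map_fn, *trees, is_leaf=None):
    """Apply map_fn leaf-wise over trees that share the same structure.

    Hypothetical stand-in for illustration; only dict, list and tuple
    are treated as containers.
    """
    first = trees[0]
    # A user-supplied predicate can stop the recursion early.
    if is_leaf is not None and is_leaf(first):
        return map_fn(*trees)
    if isinstance(first, dict):
        return {
            key: map_structure(map_fn, *(t[key] for t in trees), is_leaf=is_leaf)
            for key in first
        }
    if isinstance(first, (list, tuple)):
        mapped = [
            map_structure(map_fn, *items, is_leaf=is_leaf)
            for items in zip(*trees)
        ]
        return type(first)(mapped)
    # Anything else is a leaf.
    return map_fn(*trees)


tree_a = {'x': [1, 2], 'y': (3,)}
tree_b = {'x': [10, 20], 'y': (30,)}
summed = map_structure(lambda a, b: a + b, tree_a, tree_b)

# With is_leaf, whole lists are passed to map_fn instead of their items.
lists = {'x': [1, 2], 'y': [3]}
lengths = map_structure(len, lists, is_leaf=lambda node: isinstance(node, list))
```

In this sketch, summed keeps the container types of tree_a, and lengths holds one integer per list because the predicate marks each list as a leaf.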
- parallel_map(map_fn: Callable[[...], etils.etree.tree_utils._Tout], *trees: Any | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]], num_threads: int | None = None, progress_bar: bool = False, is_leaf: collections.abc.Callable[[Any], bool] | None = None) -> etils.etree.tree_utils._Tout | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]]
Same as tree.map_structure but apply map_fn in parallel.
- Parameters:
map_fn – Worker function
*trees – Nested input to pass to the map_fn
num_threads – Number of workers (defaults to CPU count * 5)
progress_bar – If True, display a progress bar.
is_leaf – If is_leaf(node) is True, treat node as a leaf and do not recurse into it
- Returns:
The nested structure after map_fn has been applied.
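The parallel variant can be pictured as submitting one task per leaf to a thread pool and then rebuilding the same structure from the results. The sketch below assumes a single input tree and uses concurrent.futures from the standard library; the name parallel_map_structure is hypothetical, and the actual etils implementation may differ (for example, no progress bar is modelled here).

```python
import os
from concurrent.futures import ThreadPoolExecutor


def parallel_map_structure(map_fn, tree, num_threads=None):
    """Apply map_fn to every leaf of tree using a thread pool (illustrative)."""
    if num_threads is None:
        # Mirrors the documented default of "CPU count * 5".
        num_threads = (os.cpu_count() or 1) * 5

    with ThreadPoolExecutor(max_workers=num_threads) as pool:

        def submit(node):
            # Replace each leaf by a Future, keeping the container layout.
            if isinstance(node, dict):
                return {key: submit(value) for key, value in node.items()}
            if isinstance(node, (list, tuple)):
                return type(node)(submit(value) for value in node)
            return pool.submit(map_fn, node)

        def resolve(node):
            # Replace each Future by its result, keeping the container layout.
            if isinstance(node, dict):
                return {key: resolve(value) for key, value in node.items()}
            if isinstance(node, (list, tuple)):
                return type(node)(resolve(value) for value in node)
            return node.result()

        return resolve(submit(tree))


doubled = parallel_map_structure(
    lambda x: x * 2, {'a': [1, 2], 'b': (3,)}, num_threads=2
)
```

Threads are most useful when map_fn is I/O-bound (reading files, network calls); for CPU-bound pure-Python work the gain is typically small.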
- unzip(tree: Iterable[Any] | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]]) -> Iterator[Any | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]]]
Unpack a tree of iterables.
This is the reverse operation of tree.map_structure(zip, *trees).
Example:
list(etree.unzip({'a': np.array([1, 2, 3])})) == [{'a': 1}, {'a': 2}, {'a': 3}]
- Parameters:
tree – The tree to unzip
- Yields:
Trees with the same structure as the input, but with individual elements.
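The unzip behaviour can be sketched in a few lines of pure Python. This is an illustration under simplifying assumptions, not the etils code: only dict is treated as a container, so that lists, strings and other iterables can stand in for the array leaves of the real API. The function name unzip_sketch is hypothetical.

```python
def unzip_sketch(tree):
    """Yield one tree per element of the iterable leaves (illustrative)."""
    if not isinstance(tree, dict):
        # Leaf: an iterable whose elements are yielded one by one.
        yield from tree
        return
    keys = list(tree)
    # Advance all sub-trees in lockstep and rebuild a dict at each step.
    for values in zip(*(unzip_sketch(tree[key]) for key in keys)):
        yield dict(zip(keys, values))


flat = list(unzip_sketch({'a': [1, 2, 3]}))
nested = list(unzip_sketch({'a': [1, 2], 'b': {'c': 'xy'}}))
```

Because the function is a generator, elements are produced lazily; wrapping the call in list() materializes all of the per-element trees at once.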
- stack(trees: Iterable[etils.enp.array_types.typing.Array | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]]]) -> etils.enp.array_types.typing.Array | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]]
Stack a tree of Iterable[Array].
Supports jax, tf, np.
Example:
etree.stack([
    {'a': np.array([1])},
    {'a': np.array([2])},
    {'a': np.array([3])},
]) == {
    'a': np.array([[1], [2], [3]])
}
- Parameters:
trees – The list of trees to stack
- Returns:
Tree of arrays.
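As a rough mental model, stacking collects the matching leaves of every input tree along a new leading axis. The sketch below uses nested Python lists in place of arrays and treats only dict as a container; both choices, and the name stack_sketch, are assumptions made to keep the example dependency-free. The real method returns arrays from the chosen backend.

```python
def stack_sketch(trees):
    """Stack same-structure trees leaf-wise; lists stand in for arrays."""
    trees = list(trees)
    first = trees[0]
    if isinstance(first, dict):
        # Recurse per key, gathering that key's sub-tree from every input.
        return {key: stack_sketch(t[key] for t in trees) for key in first}
    # Leaf: the list of collected leaves plays the role of the new leading axis.
    return trees


stacked = stack_sketch([
    {'a': [1]},
    {'a': [2]},
    {'a': [3]},
])
```

Here stacked['a'] is a list of three one-element lists, which corresponds to an array of shape (3, 1) in the array-based example.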
- spec_like(tree: etils.enp.array_types.typing.Array | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]], *, ignore_other: bool = True) -> etils.enp.array_spec.ArraySpec | list[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]] | tuple[etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree], ...] | dict[str, etils.etree.typing._T | list[Tree] | tuple[Tree, ...] | dict[str, Tree]]
Inspect a tree of arrays. Works with any array type.
Example:
model = MyModel()
variables = model.init(jax.random.PRNGKey(0), x)
# Inspect the `variables` tree structure
print(etree.spec_like(variables))
- Parameters:
tree – The tree of arrays
ignore_other – If True, non-array values are forwarded as-is.
- Returns:
The tree of enp.ArraySpec.
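The idea behind spec_like can be illustrated without any array library: each array leaf is replaced by a small record of its shape and dtype, while other values pass through. Everything in this sketch is a stand-in: FakeArray and Spec are hypothetical classes (not enp.ArraySpec), arrays are detected by duck typing on shape and dtype, and raising TypeError when ignore_other is False is an assumption, since the documentation above only describes the True case.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeArray:
    """Hypothetical array stand-in exposing only shape and dtype."""
    shape: tuple
    dtype: str


@dataclass(frozen=True)
class Spec:
    """Hypothetical stand-in for an array spec (shape and dtype only)."""
    shape: tuple
    dtype: str


def spec_like_sketch(tree, *, ignore_other=True):
    """Replace array-like leaves by Spec records (illustrative only)."""
    if isinstance(tree, dict):
        return {
            key: spec_like_sketch(value, ignore_other=ignore_other)
            for key, value in tree.items()
        }
    if isinstance(tree, (list, tuple)):
        return type(tree)(
            spec_like_sketch(value, ignore_other=ignore_other) for value in tree
        )
    if hasattr(tree, 'shape') and hasattr(tree, 'dtype'):
        return Spec(tuple(tree.shape), str(tree.dtype))
    if ignore_other:
        # Non-array values are forwarded unchanged.
        return tree
    # Assumption: the strict mode rejects non-array leaves.
    raise TypeError(f'Not an array: {tree!r}')


variables = {'w': FakeArray((2, 3), 'float32'), 'name': 'layer'}
specs = spec_like_sketch(variables)
```

Printing a tree of such records gives a compact overview of shapes and dtypes without dumping the array contents, which is the use case shown in the example above.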