etree.backend.Jax

class etils.etree.backend.Jax

Bases: etils.etree.backend.Backend

Backend for the etree API, implemented with jax.tree_util.
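Because this backend is built on jax.tree_util, a "tree" here presumably follows JAX's pytree rules. As a rough illustration that calls jax.tree_util directly (not the etils wrapper), lists, tuples, dicts and None act as container nodes, and any other object, including strings and arrays, is a leaf. The tree contents below are arbitrary.

```python
import jax
import numpy as np

# Sketch using jax.tree_util directly, not etils.etree.backend.Jax.
# A nested structure mixing containers (dict, list, tuple, None) and leaves.
tree = {"weights": [np.zeros(3), np.ones(2)], "name": "layer", "extra": (1.5, None)}

leaves = jax.tree_util.tree_leaves(tree)

# Dict keys are visited in sorted order: "extra", "name", "weights".
# None is an empty container, so it contributes no leaf; the string is a leaf.
print(len(leaves))  # 4
print(leaves[0], leaves[1])  # 1.5 layer
```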

import_module()

Import and return the module.
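This page does not say which module is returned; for a jax.tree_util backend it is presumably jax, but treat that as an assumption. The hypothetical LazyJaxBackend class below is only a sketch of why such a method is useful: the import is deferred until the backend is first used and then cached, so jax is only required when this backend is actually selected.

```python
from __future__ import annotations

import importlib
import types


class LazyJaxBackend:
    """Hypothetical stand-in illustrating a deferred, cached import."""

    def __init__(self) -> None:
        self._module: types.ModuleType | None = None

    def import_module(self) -> types.ModuleType:
        # Assumption: the backend module is `jax`. Importing here, rather
        # than at file level, means jax is only needed once this runs.
        return importlib.import_module("jax")

    @property
    def module(self) -> types.ModuleType:
        # Import on first access, then reuse the cached module object.
        if self._module is None:
            self._module = self.import_module()
        return self._module


backend = LazyJaxBackend()
print(backend.module.__name__)  # jax
```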

map(map_fn, *trees, is_leaf=None)

Applies map_fn to the corresponding leaves of trees, like tf.nest.map_structure.
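Assuming map forwards to jax.tree_util.tree_map (an assumption based on the class description, not stated on this page), the sketch below calls JAX directly to show two behaviours: with several trees, map_fn receives the matching leaf from each one, and is_leaf stops recursion at selected subtrees. The trees are made up for illustration.

```python
import jax

# Assumption: etils' map() forwards to jax.tree_util.tree_map;
# this sketch calls JAX directly.
params = {"w": [1.0, 2.0], "b": 0.5}
grads = {"w": [0.5, 1.0], "b": 0.25}

# With several trees, map_fn receives the matching leaf from each tree.
updated = jax.tree_util.tree_map(lambda p, g: p - g, params, grads)
print(updated["w"], updated["b"])  # [0.5, 1.0] 0.25

# is_leaf stops recursion early: each (mean, std) tuple is passed whole.
stats = {"x": (0.0, 1.0), "y": (2.0, 4.0)}
means = jax.tree_util.tree_map(
    lambda pair: pair[0], stats, is_leaf=lambda node: isinstance(node, tuple)
)
print(means["x"], means["y"])  # 0.0 2.0
```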

flatten(tree, *, is_leaf=None)

Returns the leaves of tree as a flat list, like tf.nest.flatten.
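A minimal sketch of the flattening behaviour, assuming flatten corresponds to jax.tree_util.tree_leaves (this page does not confirm the exact call). Leaves come out in a deterministic order, and is_leaf can keep a whole subtree together as a single leaf.

```python
import jax

# Assumption: etils' flatten() behaves like jax.tree_util.tree_leaves.
tree = {"a": [1, 2], "b": {"c": 3}}

# Dict keys are sorted during flattening, so "a" comes before "b".
flat = jax.tree_util.tree_leaves(tree)
print(flat)  # [1, 2, 3]

# With is_leaf, every list is returned whole instead of being recursed into.
flat_lists = jax.tree_util.tree_leaves(tree, is_leaf=lambda n: isinstance(n, list))
print(flat_lists)  # [[1, 2], 3]
```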

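unflatten(structure, flat_sequence)

Rebuilds a tree with the structure of structure from the values in flat_sequence, like tf.nest.pack_sequence_as.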
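With jax.tree_util, unflattening is a two-step process: derive a treedef from an example tree, then fill it with leaves. Whether this method accepts a tree or a treedef as structure is not stated on this page; the sketch below assumes an example ("template") tree and calls JAX directly.

```python
import jax

# Assumption: structure is an example tree whose layout is reused.
template = {"a": [0, 0], "b": 0}
treedef = jax.tree_util.tree_structure(template)

# Leaves are assigned in flattening order (dict keys sorted), so "a"
# receives 10 and 20, and "b" receives 30.
rebuilt = jax.tree_util.tree_unflatten(treedef, [10, 20, 30])
print(rebuilt["a"], rebuilt["b"])  # [10, 20] 30

# Round trip: flattening and then unflattening yields an equal tree.
leaves, treedef2 = jax.tree_util.tree_flatten(template)
roundtrip = jax.tree_util.tree_unflatten(treedef2, leaves)
print(roundtrip == template)  # True
```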

assert_same_structure(tree0: Tree[Any], tree1: Tree[Any])

Raises an error if tree0 and tree1 do not have the same nested structure. Tree is the recursive alias from etils.etree.typing: Tree[_T] = _T | list[Tree] | tuple[Tree, ...] | dict[str, Tree].
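The exact checks and exception type of assert_same_structure are not documented on this page. As a minimal sketch of the idea, the hypothetical check_same_structure helper below compares jax.tree_util treedefs and raises ValueError on a mismatch; only the container layout (types, lengths, dict keys) is compared, not leaf values or leaf types.

```python
import jax


def check_same_structure(tree0, tree1) -> None:
    """Hypothetical helper: raise ValueError if the tree structures differ."""
    def0 = jax.tree_util.tree_structure(tree0)
    def1 = jax.tree_util.tree_structure(tree1)
    # Treedefs encode containers and keys, not the leaf values themselves.
    if def0 != def1:
        raise ValueError(f"Tree structures differ: {def0} vs {def1}")


# Same containers and keys, different leaf values and types: passes.
check_same_structure({"a": [1, 2]}, {"a": [3.0, 4.0]})

# Different list length: raises.
try:
    check_same_structure({"a": [1, 2]}, {"a": [1]})
except ValueError:
    print("mismatch detected")
```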