Tree utils
`etree` can be used with 5 different backends, depending on usage:
```python
from etils import etree  # Default: Pure Python backend (no extra deps)
from etils.etree import jax as etree  # Backend: jax.tree_utils
from etils.etree import optree as etree  # Backend: optree
from etils.etree import nest as etree  # Backend: tf.nest
from etils.etree import tree as etree  # Backend: tree (DeepMind)
from etils.etree import py as etree  # Backend: Pure Python (no extra deps)
```
parallel_map
Similar to `tree.map_structure`, but the function is applied to all leaves in parallel.
```python
import imageio
from etils import etree

img_paths = {'train': ['img0.png', ...], 'test': ['img1.png', ...]}
imgs = etree.parallel_map(imageio.imread, img_paths)  # Load images in parallel
```
Kwargs:

* `progress_bar`: If `True`, display a progress bar.
* `num_threads`: Number of parallel threads (defaults to the number of CPUs * 5).
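To make the behaviour concrete, here is a minimal pure-Python sketch of the idea behind `parallel_map`. It is not the etils implementation. It assumes the tree is built only from `dict`, `list` and `tuple` nodes, runs the function on a `concurrent.futures.ThreadPoolExecutor`, and leaves out the progress bar. The helper `_map_structure` is a hypothetical name used only here.

```python
import concurrent.futures
import os
from typing import Any, Callable, Optional


def _map_structure(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Hypothetical helper: apply `fn` to every leaf, keeping the structure."""
    if isinstance(tree, dict):
        return {key: _map_structure(fn, value) for key, value in tree.items()}
    if isinstance(tree, list):
        return [_map_structure(fn, value) for value in tree]
    if isinstance(tree, tuple):
        return tuple(_map_structure(fn, value) for value in tree)
    return fn(tree)  # Anything else is treated as a leaf.


def parallel_map(
    fn: Callable[[Any], Any], tree: Any, *, num_threads: Optional[int] = None
) -> Any:
    """Sketch of a threaded tree map (no progress bar)."""
    if num_threads is None:
        # Mirrors the documented default of "number of CPUs * 5".
        num_threads = (os.cpu_count() or 1) * 5
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Pass 1: submit one task per leaf; the returned tree holds futures.
        futures = _map_structure(lambda leaf: executor.submit(fn, leaf), tree)
        # Pass 2: block on each future and swap in its result.
        # Futures are not dict/list/tuple, so they are visited as leaves.
        return _map_structure(lambda future: future.result(), futures)
```

Since the threads share one interpreter, this pays off mainly for I/O-bound work such as reading files. An exception raised by `fn` resurfaces when the corresponding `future.result()` is called.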
unzip
Unpack a tree of iterables. This is the reverse operation of `tree.map_structure(zip, *trees)`.
Example:

```python
import numpy as np
from etils import etree

list(etree.unzip({'a': np.array([1, 2, 3])})) == [{'a': 1}, {'a': 2}, {'a': 3}]
```
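A pure-Python sketch of the same idea follows. It is not the etils implementation. It assumes dicts are the only container nodes and that every other value is an iterable leaf (a plain list stands in for the `np.array` above). The helpers `_leaves` and `_rebuild` are hypothetical names. The sketch flattens the leaves, `zip`s them, and rebuilds one tree per position.

```python
from typing import Any, Iterator, List


def _leaves(tree: Any) -> List[Any]:
    """Hypothetical helper: collect leaves in dict insertion order."""
    if isinstance(tree, dict):
        return [leaf for value in tree.values() for leaf in _leaves(value)]
    return [tree]


def _rebuild(tree: Any, values: Iterator[Any]) -> Any:
    """Hypothetical helper: copy `tree`'s structure, filling leaves from `values`."""
    if isinstance(tree, dict):
        return {key: _rebuild(value, values) for key, value in tree.items()}
    return next(values)


def unzip(tree: Any) -> Iterator[Any]:
    """Yield one tree per position of the leaf iterables."""
    # zip(*leaves) yields tuples (leaf0[i], leaf1[i], ...) in the same order
    # that _rebuild walks the tree, so each tuple refills one copy of the tree.
    for row in zip(*_leaves(tree)):
        yield _rebuild(tree, iter(row))
```

As with `zip`, iteration stops at the shortest leaf, so leaves of unequal length are silently truncated.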
spec_like
Extract the shape/dtype of each array in a tree. Useful for inspecting large nested structures such as model variables:
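```python
import jax
from etils import etree

model = MyModel()  # Placeholder for your own model (e.g. a Flax module)
variables = model.init(jax.random.PRNGKey(0), x)  # `x`: placeholder input batch

# Inspect the `variables` tree structure
print(etree.spec_like(variables))
```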
This prints the shape/dtype of every leaf:
```
FrozenDict({
    params: {
        Conv_0: {
            kernel: f32[1 1 3 32],
            bias: f32[32],
        },
        Conv_1: {
            kernel: f32[1 1 32 32],
            bias: f32[32],
        },
        [...]
```
This works with NumPy, PyTorch, JAX, and TensorFlow arrays.
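The formatting above can be approximated in a few lines. This is a hypothetical sketch, not etils' `spec_like`. It duck-types an array as any object with `shape` and `dtype` attributes and returns plain strings, a choice made here for readability. It also assumes `str(dtype)` yields a NumPy-style name such as `'float32'` and uses a hand-picked, hypothetical table of dtype abbreviations. Any `Mapping` (which covers Flax's `FrozenDict`) is rebuilt as a plain `dict`.

```python
from collections.abc import Mapping
from typing import Any

# Hypothetical abbreviation table; extend as needed.
_SHORT_DTYPES = {
    'float16': 'f16', 'float32': 'f32', 'float64': 'f64',
    'int8': 'i8', 'int32': 'i32', 'int64': 'i64',
    'uint8': 'ui8', 'bool': 'bool',
}


def spec_like(tree: Any) -> Any:
    """Replace every array-like leaf with a 'dtype[d0 d1 ...]' string."""
    if isinstance(tree, Mapping):
        return {key: spec_like(value) for key, value in tree.items()}
    if isinstance(tree, list):
        return [spec_like(value) for value in tree]
    if isinstance(tree, tuple):
        return tuple(spec_like(value) for value in tree)
    if hasattr(tree, 'shape') and hasattr(tree, 'dtype'):
        dtype = str(tree.dtype)
        short = _SHORT_DTYPES.get(dtype, dtype)  # Unknown dtypes kept verbatim.
        dims = ' '.join(str(dim) for dim in tree.shape)
        return f'{short}[{dims}]'
    return tree  # Non-array leaves are passed through unchanged.
```

For example, `spec_like({'w': np.zeros((3, 2), np.float32)})` returns `{'w': 'f32[3 2]'}`. PyTorch dtypes print with a prefix (e.g. `torch.float32`), so in this sketch they fall through the abbreviation table unchanged.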