enp.check_and_normalize_arrays

etils.enp.check_and_normalize_arrays(fn: None = None, *, strict: bool = True) -> Callable[[etils.enp.checking._Fn], etils.enp.checking._Fn]
etils.enp.check_and_normalize_arrays(fn: etils.enp.checking._Fn = None, *, strict: bool = True) -> etils.enp.checking._Fn

Check and normalize arrays.

This function:

  • Validates that the dtype and shape of the input arrays match the typing annotations.

  • Normalizes np, jnp, and tf arrays so that all inputs use a consistent array type.

  • Adds an optional xnp argument to convert input arrays to np/jnp/tnp.
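The three steps above can be pictured with a small stand-alone decorator. This is a minimal sketch under stated assumptions, not the etils implementation. FloatArg is a hypothetical annotation marker standing in for FloatArray, and the only check performed is on the float dtype. The optional xnp keyword defaults to numpy here, whereas the decorator described above can also target jnp or tnp.

```python
import functools
import inspect
import types

import numpy as np


class FloatArg:
    """Hypothetical annotation marker: the argument must be a float array."""


def check_and_convert(fn):
    """Minimal sketch of annotation-driven checking; not the etils implementation."""
    sig = inspect.signature(fn)
    # Parameters annotated with the marker are the ones to check and convert.
    array_params = [
        name for name, param in sig.parameters.items() if param.annotation is FloatArg
    ]

    @functools.wraps(fn)
    def wrapper(*args, xnp: types.ModuleType = np, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name in array_params:
            if name not in bound.arguments:
                continue  # Argument omitted; its default is used unchanged.
            # Convert with the requested array module (numpy by default).
            value = xnp.asarray(bound.arguments[name])
            # Validate the dtype against the annotation.
            if not np.issubdtype(value.dtype, np.floating):
                raise TypeError(f"{name!r} must be a float array, got {value.dtype}")
            bound.arguments[name] = value
        return fn(*bound.args, **bound.kwargs)

    return wrapper


@check_and_convert
def add(x: FloatArg, y: FloatArg):
    return x + y


print(add([1.0, 2.0], np.array([3.0, 4.0])))  # [4. 6.]
```

In this sketch, passing xnp=jnp (with JAX installed) would convert the inputs with jax.numpy.asarray instead of numpy.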

See the documentation at: google/etils

Example:

import jax.numpy as jnp
import numpy as np
from etils import enp
from etils.array_types import FloatArray

@enp.check_and_normalize_arrays(strict=False)
def add(x: FloatArray[...], y: FloatArray[...]) -> FloatArray[...]:
  return x + y

# Inside the function, `np` is normalized to `jnp`
add(np.array(1.), jnp.array(2.)) == jnp.array(3.)

# strict=False, so the `list` is accepted and normalized to `xnp`
add(jnp.array(1.), [1., 2., 3.]) == jnp.array([2., 3., 4.])

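The strict=False behavior shown in the second call can be sketched as a per-argument conversion rule. In this sketch normalize_arg is a hypothetical helper and numpy stands in for xnp. With strict=True a non-array input is rejected, and raising TypeError is this sketch's own choice. With strict=False the input is converted to an array.

```python
import numpy as np


def normalize_arg(value, *, strict: bool = True):
    """Hypothetical helper sketching the effect of the strict flag."""
    if isinstance(value, np.ndarray):
        return value  # Arrays pass through unchanged.
    if strict:
        # strict=True: only arrays are accepted (TypeError is this sketch's choice).
        raise TypeError(f"expected an array, got {type(value).__name__}")
    # strict=False: lists, ints, floats, ... are converted to arrays.
    return np.asarray(value)


print(normalize_arg([1.0, 2.0, 3.0], strict=False))  # [1. 2. 3.]
```
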
Parameters:
  • fn – The function to decorate. Its arguments are inferred automatically.

  • strict – If False, fn also accepts list, int, … values, which are then automatically converted to xnp.

Returns:

The decorated function, with dynamic shape checking.

Return type:

fn
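Here, dynamic shape checking is taken to mean that shapes are verified at call time against the sizes of the arrays actually passed. The sketch below assumes a simple space-separated spec string such as "b 3". In it, integer tokens must match exactly, and named tokens bind to a size on first use and must agree afterwards. The spec format and the check_shapes helper are both hypothetical and are not the etils annotation syntax.

```python
import numpy as np


def check_shapes(specs: dict, arrays: dict) -> dict:
    """Hypothetical sketch: check arrays against space-separated shape specs.

    Integer tokens must match exactly; name tokens bind to a size on first
    use and must agree for every later argument in the same call.
    """
    dims = {}  # Dimension name -> size, bound during this call only.
    for name, spec in specs.items():
        shape = np.shape(arrays[name])
        tokens = spec.split()
        if len(tokens) != len(shape):
            raise ValueError(f"{name}: expected {len(tokens)} dims, got shape {shape}")
        for token, size in zip(tokens, shape):
            if token.isdigit():
                expected = int(token)
            else:
                expected = dims.setdefault(token, size)
            if size != expected:
                raise ValueError(f"{name}: dim {token!r} is {size}, expected {expected}")
    return dims


dims = check_shapes(
    {"x": "b 3", "y": "b 3"},
    {"x": np.zeros((4, 3)), "y": np.ones((4, 3))},
)
print(dims)  # {'b': 4}
```

Because named sizes are bound afresh on each call, the same function accepts a batch of 4 on one call and a batch of 8 on the next. It still rejects a call whose two arguments disagree on b.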