enp.check_and_normalize_arrays


etils.enp.check_and_normalize_arrays(
fn: None = None,
*,
strict: bool = True,
) → Callable[[etils.enp.checking._Fn], etils.enp.checking._Fn]
etils.enp.check_and_normalize_arrays(
fn: etils.enp.checking._Fn = None,
*,
strict: bool = True,
) → etils.enp.checking._Fn

Check and normalize arrays.

This function:

  • Validates that the dtype and shape of the input arrays match the typing annotations

  • Normalizes np, jnp, and tf types so that they are consistent

  • Adds an optional xnp argument to convert input arrays to np/jnp/tnp.
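The first bullet can be pictured with a small stand-alone checker. This is a minimal sketch, not the etils implementation: the helper `check_float_shapes` and its space-separated shape specs (such as `"h w"`) are hypothetical, and it only handles float dtypes plus named dimensions that must agree across arguments.

```python
import numpy as np


def check_float_shapes(arrays, specs):
  """Hypothetical checker, NOT the etils implementation.

  `arrays` maps argument names to NumPy arrays; `specs` maps the same
  names to space-separated dimension names such as "h w". A dimension
  name shared by several arguments must bind to the same size.
  Returns the mapping from dimension names to sizes.
  """
  dims = {}
  for name, spec in specs.items():
    x = arrays[name]
    # dtype check: every annotation in this sketch means "float array"
    if not np.issubdtype(x.dtype, np.floating):
      raise TypeError(f"{name}: expected a float dtype, got {x.dtype}")
    axes = spec.split()
    # rank check: one dimension name per axis
    if x.ndim != len(axes):
      raise ValueError(f"{name}: expected {len(axes)} dims, got shape {x.shape}")
    # consistency check: bind each name on first use, compare afterwards
    for axis, size in zip(axes, x.shape):
      bound = dims.setdefault(axis, size)
      if bound != size:
        raise ValueError(f"{name}: dim {axis!r} is {size}, expected {bound}")
  return dims
```

For example, an `x` of shape `(2, 3)` with spec `"h w"` and a `y` of shape `(3, 4)` with spec `"w c"` return `{'h': 2, 'w': 3, 'c': 4}`, while a `y` of shape `(4, 4)` raises `ValueError` because `w` would be both 3 and 4.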

See doc at: google/etils

Example:

import jax.numpy as jnp
import numpy as np
from etils import enp
from etils.array_types import Array, FloatArray

@enp.check_and_normalize_arrays(strict=False)
def add(x: FloatArray[...], y: FloatArray[...]) -> FloatArray[...]:
  return x + y

# Inside the function, the `np` input is normalized to `jnp`
add(np.array(1.), jnp.array(2.)) == jnp.array(3.)

# strict=False, so the `list` is accepted and normalized to `xnp`
add(jnp.array(1.), [1., 2., 3.]) == jnp.array([2., 3., 4.])

Example with an explicit xnp argument, which is automatically passed to the wrapped function:

@enp.check_and_normalize_arrays(strict=False)
def clip(x: Array, *, xnp: enp.NpModule = ...) -> Array:
  # `xnp` is supplied by the decorator
  return xnp.clip(x, 0.0, 1.0)

If present in the function signature, the xnp argument must be a keyword-only argument.
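The keyword-only rule can be illustrated with a stripped-down decorator built on `inspect.signature`. This is a hypothetical sketch, not how etils is implemented: `inject_xnp` is an invented name, and it always passes `numpy`, whereas the real decorator picks np, jnp, or tnp according to the inputs.

```python
import functools
import inspect

import numpy as np


def inject_xnp(fn):
  """Hypothetical decorator: passes `xnp=np` to `fn` if `fn` declares `xnp`.

  Mirrors only the rule that `xnp` must be keyword-only; unlike etils,
  it never inspects the inputs and always uses NumPy.
  """
  param = inspect.signature(fn).parameters.get("xnp")
  if param is not None and param.kind is not inspect.Parameter.KEYWORD_ONLY:
    raise TypeError(f"`xnp` must be a keyword-only argument of {fn.__name__}")

  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    # Fill in `xnp` only when the function accepts it and the caller omitted it.
    if param is not None and "xnp" not in kwargs:
      kwargs["xnp"] = np
    return fn(*args, **kwargs)

  return wrapper


@inject_xnp
def clip(x, *, xnp=...):
  return xnp.clip(x, 0.0, 1.0)
```

In this sketch, `clip(np.array([-1.0, 0.5, 2.0]))` returns an array with values `[0.0, 0.5, 1.0]`, and decorating a function whose `xnp` parameter is positional raises `TypeError` at decoration time.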

Parameters:
  • fn – The function to decorate. Arguments are automatically inferred.

  • strict – If False, fn also accepts list, int,… inputs, which are automatically converted to xnp.
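The effect of `strict` on non-array inputs can be sketched as a small conversion helper. This is a minimal, NumPy-only illustration under stated assumptions: `normalize_input` is a hypothetical name, only `np.ndarray` counts as an array here, and `np.asarray` stands in for the conversion to whichever xnp module etils selects.

```python
import numpy as np


def normalize_input(value, *, strict=True):
  """Hypothetical helper, NOT part of etils.

  NumPy arrays pass through unchanged. With strict=False, Python scalars
  and sequences (list, int, float, ...) are converted with `np.asarray`;
  with strict=True they are rejected.
  """
  if isinstance(value, np.ndarray):
    return value
  if strict:
    raise TypeError(f"expected an array, got {type(value).__name__}")
  return np.asarray(value)
```

Under these rules, `normalize_input([1., 2., 3.], strict=False)` yields a float array, whereas the default `strict=True` raises `TypeError`; this mirrors why `add` in the first example accepts a `list` only because it is decorated with `strict=False`.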

Returns:

The decorated function, with dynamic shape checking

Return type:

fn