enp.check_and_normalize_arrays
- etils.enp.check_and_normalize_arrays(fn: None = None, *, strict: bool = True) -> Callable[[etils.enp.checking._Fn], etils.enp.checking._Fn]
- etils.enp.check_and_normalize_arrays(fn: etils.enp.checking._Fn = None, *, strict: bool = True) -> etils.enp.checking._Fn
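The two overloads above describe a decorator that can be used either bare (fn is the decorated function) or with keyword arguments only (fn is None, and a decorator is returned). The sketch below shows that general dual-use pattern with the standard library. The name check_arrays_sketch and its strict attribute are hypothetical and are not the etils implementation; no validation or normalization is performed.

```python
import functools
from typing import Any, Callable, Optional, TypeVar

_Fn = TypeVar("_Fn", bound=Callable[..., Any])


def check_arrays_sketch(fn: Optional[_Fn] = None, *, strict: bool = True) -> Any:
    """Hypothetical decorator usable as @deco or @deco(strict=...)."""
    if fn is None:
        # Called as @check_arrays_sketch(strict=...): return the real decorator.
        return functools.partial(check_arrays_sketch, strict=strict)

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Placeholder: a real implementation would check/normalize arrays here.
        return fn(*args, **kwargs)

    # Exposed only so the chosen setting is easy to inspect in this sketch.
    wrapper.strict = strict  # type: ignore[attr-defined]
    return wrapper


@check_arrays_sketch  # first overload's counterpart: fn passed directly
def f(x):
    return x


@check_arrays_sketch(strict=False)  # second form: fn is None, decorator returned
def g(x):
    return x
```

Using functools.partial keeps a single function body for both call styles, which is why the signature needs fn to default to None.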
Check and normalize arrays.
This function:
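- Validates that the dtype and shape of the input arrays match the typing annotations.
- Normalizes np, jnp and tf arrays to a consistent type (for example, np inputs are converted to jnp when mixed with jnp inputs).
- Adds an optional xnp argument to convert the input arrays to np/jnp/tnp.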
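The normalization step listed above can be pictured with a small, hypothetical sketch that picks one array library for all inputs by looking at the module each input's type comes from. The labels NP/JNP/TNP, the helpers array_lib and infer_xnp, and the rule "NumPy is promoted to the other library; mixing JAX and TF is an error" are assumptions made for illustration, not the actual etils logic. It uses only the standard library, so no array framework needs to be installed.

```python
from typing import Any, Optional

# Hypothetical labels for the three array libraries mentioned above.
NP, JNP, TNP = "numpy", "jax.numpy", "tensorflow.experimental.numpy"


def array_lib(x: Any) -> Optional[str]:
    """Guesses which array library `x` belongs to from its type's module."""
    root = type(x).__module__.split(".")[0]
    if root == "numpy":
        return NP
    if root in ("jax", "jaxlib"):
        return JNP
    if root == "tensorflow":
        return TNP
    return None  # plain Python values such as list or int


def infer_xnp(*values: Any) -> str:
    """Picks one library for all inputs (assumed rule, not etils' code)."""
    libs = {array_lib(v) for v in values} - {None}
    non_np = libs - {NP}
    if len(non_np) > 1:
        # Assumption: JAX and TF inputs are not reconciled automatically.
        raise TypeError(f"Cannot mix array libraries: {sorted(non_np)}")
    # NumPy inputs are promoted to the other library, if there is one.
    return non_np.pop() if non_np else NP
```

A full decorator would then map the chosen name to a module (for example with importlib.import_module) and convert each input with that module's asarray.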
See the documentation at: google/etils
Example:
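```python
# Imports added for completeness; FloatArray is assumed to come from etils.array_types.
import jax.numpy as jnp
import numpy as np
from etils import enp
from etils.array_types import FloatArray


@enp.check_and_normalize_arrays(strict=False)
def add(x: FloatArray[...], y: FloatArray[...]) -> FloatArray[...]:
  return x + y

# Inside the function, `np` normalized to `jnp`
add(np.array(1.), jnp.array(2.)) == jnp.array(3.)

# strict=False, so `list` accepted and normalized to `xnp`
add(jnp.array(1.), [1., 2., 3.]) == jnp.array([2., 3., 4.])
```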
- Parameters:
fn – The function to decorate. Arguments are inferred automatically.
strict – If False, fn also accepts list, int,… values, which are automatically converted to xnp arrays.
- Returns:
The decorated function, with dynamic shape checking
- Return type:
fn
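The strict parameter described above can be illustrated with a hypothetical per-argument sketch: with strict=True, plain Python inputs are rejected; with strict=False, they are converted by whatever function builds arrays for the chosen xnp. The helper normalize_arg, the to_array callback and the _PY_INPUTS tuple are assumptions for illustration only; the library's exact set of accepted types is elided in the parameter description and is not reproduced here.

```python
from typing import Any, Callable

# Illustrative subset of non-array inputs (hypothetical, not the library's list).
_PY_INPUTS = (list, int, float)


def normalize_arg(value: Any, to_array: Callable[[Any], Any], *, strict: bool = True) -> Any:
    """Applies a hypothetical version of the strict rule to one argument."""
    if isinstance(value, _PY_INPUTS):
        if strict:
            raise TypeError(f"strict=True: expected an array, got {type(value).__name__}")
        # strict=False: convert with the chosen xnp (e.g. its asarray function).
        return to_array(value)
    return value  # assumed to already be an array; passed through unchanged
```

With NumPy installed, normalize_arg([1., 2.], np.asarray, strict=False) would return a NumPy array, while the same call with strict=True raises TypeError.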