enp.check_and_normalize_arrays
- etils.enp.check_and_normalize_arrays(fn: None = None, *, strict: bool = True)
- etils.enp.check_and_normalize_arrays(fn: etils.enp.checking._Fn = None, *, strict: bool = True)
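The two signatures above appear to let the decorator be used either directly on a function or called first with keyword arguments such as strict=False. The sketch below shows that common "optional-argument decorator" pattern in plain Python. The name `hypothetical_checker` and the `strict` attribute it records are illustrative only. This is not the etils implementation.

```python
import functools
from typing import Callable, Optional


def hypothetical_checker(fn: Optional[Callable] = None, *, strict: bool = True):
    """Hypothetical decorator usable as `@hypothetical_checker` or `@hypothetical_checker(strict=False)`."""

    def decorator(f: Callable) -> Callable:
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # A real checker would validate and normalize the arguments here.
            return f(*args, **kwargs)

        # Record the setting so the two call forms can be told apart.
        wrapper.strict = strict
        return wrapper

    if fn is None:
        # Called as @hypothetical_checker(...): return the decorator itself.
        return decorator
    # Called as @hypothetical_checker: decorate `fn` immediately.
    return decorator(fn)


@hypothetical_checker
def identity(x):
    return x


@hypothetical_checker(strict=False)
def lenient_identity(x):
    return x
```

With fn defaulting to None, a single function can serve both call forms. This is one way to read the pair of overloads in the signature.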
Check and normalize arrays.
This function:
- Validates that the dtype/shape of the input arrays match the typing annotations.
- Normalizes np, jnp and tf arrays to a consistent type.
- Adds an optional xnp argument to convert input arrays to np/jnp/tnp.
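To illustrate the annotation-driven validation step, here is a minimal NumPy-only sketch. It reads a function's signature with `inspect` and rejects non-float inputs for parameters carrying a marker annotation. `HypotheticalFloat` and `check_float_annotations` are hypothetical stand-ins for etils' FloatArray and its real checking logic. This sketch checks dtype only, not shape.

```python
import functools
import inspect

import numpy as np


class HypotheticalFloat:
    """Hypothetical annotation marker standing in for a float array type."""


def check_float_annotations(fn):
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            # Only parameters annotated with the marker are checked.
            if sig.parameters[name].annotation is HypotheticalFloat:
                dtype = np.asarray(value).dtype
                if not np.issubdtype(dtype, np.floating):
                    raise TypeError(f"{name}: expected a float array, got dtype {dtype}")
        return fn(*args, **kwargs)

    return wrapper


@check_float_annotations
def add(x: HypotheticalFloat, y: HypotheticalFloat):
    return x + y
```

Here, calling add with an integer array for x raises TypeError. Calling it with float arrays returns their sum.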
See doc at: google/etils
Example:
    import jax.numpy as jnp
    import numpy as np
    from etils import enp
    from etils.array_types import FloatArray

    @enp.check_and_normalize_arrays(strict=False)
    def add(x: FloatArray[...], y: FloatArray[...]) -> FloatArray[...]:
      return x + y

    # Inside the function, `np` is normalized to `jnp`
    add(np.array(1.), jnp.array(2.)) == jnp.array(3.)

    # strict=False, so `list` is accepted and normalized to `xnp`
    add(jnp.array(1.), [1., 2., 3.]) == jnp.array([2., 3., 4.])
Example with an explicit xnp argument, which is automatically passed to the wrapped function:
    from etils import enp
    from etils.array_types import Array

    @enp.check_and_normalize_arrays(strict=False)
    def clip(x: Array, *, xnp: enp.NpModule = ...) -> Array:
      return xnp.clip(x, 0.0, 1.0)
If present in the function signature, the xnp argument must be a keyword-only argument.
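The keyword-only requirement for xnp can be enforced by inspecting the signature when decorating. The hypothetical sketch below, `inject_xnp`, rejects an xnp parameter of any other kind. At call time it supplies NumPy as xnp unless the caller passed one. The real decorator presumably chooses the module from the input arrays, so always defaulting to NumPy is a simplifying assumption.

```python
import functools
import inspect

import numpy as np


def inject_xnp(fn):
    """Hypothetical sketch: pass `xnp` to `fn` if it declares a keyword-only `xnp`."""
    xnp_param = inspect.signature(fn).parameters.get("xnp")
    if xnp_param is not None and xnp_param.kind != inspect.Parameter.KEYWORD_ONLY:
        raise ValueError("`xnp` must be a keyword-only argument")

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if xnp_param is not None:
            # Keep an explicitly passed `xnp`, otherwise default to NumPy.
            kwargs.setdefault("xnp", np)
        return fn(*args, **kwargs)

    return wrapper


@inject_xnp
def clip(x, *, xnp=...):
    return xnp.clip(x, 0.0, 1.0)
```

The check runs once at decoration time. A misplaced xnp parameter therefore fails when the function is defined, not on its first call.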
- Parameters:
fn – The function to decorate. Arguments will be automatically inferred.
strict – If False, fn also accepts list, int,… inputs, which are automatically converted to xnp arrays.
- Returns:
The decorated function, with dynamic shape checking
- Return type:
fn
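The strict parameter described above can be illustrated with a minimal NumPy-only sketch for a single argument. The helper `hypothetical_normalize` is illustrative and not part of etils. It only handles np.ndarray, whereas the real decorator converts to whichever xnp module is in use.

```python
import numpy as np


def hypothetical_normalize(value, *, strict: bool = True) -> np.ndarray:
    """Hypothetical sketch of the `strict` behavior for a single argument."""
    if isinstance(value, np.ndarray):
        # Arrays are passed through unchanged.
        return value
    if strict:
        # Strict mode: only real arrays are accepted.
        raise TypeError(f"Expected an array, got {type(value).__name__}")
    # Non-strict mode: lists, ints, floats, ... are converted to an array.
    return np.asarray(value)
```

In strict mode a plain Python list raises TypeError. With strict=False the same list becomes an array.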