enp.compat



Compatibility utilities between TF, Torch, NumPy and JAX.

Currently, each NumPy-like API (np, jnp, tnp, torch) behaves slightly differently. These functions provide a common interface so that the same code works seamlessly across all of them.

In the future, these functions could be removed and replaced by the official NumPy API.

Symbols

Function

enp.compat.astype

x.astype(dtype).
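A minimal sketch of why a helper is needed here: NumPy and JAX arrays expose .astype(), while torch.Tensor uses .to(dtype) instead. The astype function below is hypothetical (not the etils source), only checks for torch if torch is already imported, and assumes the caller passes a torch dtype when x is a torch tensor.

```python
import sys

import numpy as np


def astype(x, dtype):
    """Hypothetical sketch of a backend-agnostic x.astype(dtype)."""
    # Look up torch without importing it: if it was never imported,
    # x cannot be a torch tensor.
    torch = sys.modules.get("torch")
    if torch is not None and isinstance(x, torch.Tensor):
        # torch.Tensor has no .astype(); .to() performs the cast.
        # Assumes `dtype` is already a torch dtype (e.g. torch.float32).
        return x.to(dtype)
    # NumPy (and JAX) arrays support .astype() directly.
    return x.astype(dtype)
```

With a NumPy input, the cast behaves exactly like ndarray.astype, so floats are truncated toward zero when converted to integers.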

enp.compat.concat

xnp.concatenate(x, axis=axis).
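The naming difference this hides can be sketched as follows: NumPy spells the operation concatenate(..., axis=...), while torch spells it torch.cat(..., dim=...). This concat is a hypothetical illustration that handles only NumPy and torch, and checks for torch without importing it.

```python
import sys

import numpy as np


def concat(arrays, axis):
    """Hypothetical sketch: concatenate a sequence of arrays along `axis`."""
    torch = sys.modules.get("torch")
    if torch is not None and isinstance(arrays[0], torch.Tensor):
        # torch names the function `cat` and the axis argument `dim`.
        return torch.cat(arrays, dim=axis)
    return np.concatenate(arrays, axis=axis)
```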

enp.compat.det

Like np.linalg.det, but automatically supports jnp, tnp and np.
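A sketch of the kind of dispatch such a wrapper performs, assuming torch.linalg.det and tf.linalg.det as the backend equivalents. The det function here is hypothetical: it covers only NumPy, torch and TensorFlow, leaves the JAX case out for brevity, and never imports a backend itself.

```python
import sys

import numpy as np


def det(x):
    """Hypothetical sketch of a backend-dispatching determinant."""
    torch = sys.modules.get("torch")
    if torch is not None and isinstance(x, torch.Tensor):
        return torch.linalg.det(x)
    tf = sys.modules.get("tensorflow")
    if tf is not None and isinstance(x, tf.Tensor):
        return tf.linalg.det(x)
    # Fallback: treat the input as a NumPy array.
    return np.linalg.det(x)
```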

enp.compat.dtype_np_to_torch

Returns the torch dtype for the given numpy dtype.

enp.compat.dtype_torch_to_np

Returns the numpy dtype for the given torch dtype.

enp.compat.expand_dims

xnp.expand_dims(x, axis=axis).
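A hypothetical sketch of the dispatch: NumPy's expand_dims takes an axis argument, while torch provides torch.unsqueeze, whose dim argument accepts a single integer. Only NumPy and torch are handled, and torch is only checked if it was already imported.

```python
import sys

import numpy as np


def expand_dims(x, axis):
    """Hypothetical sketch: insert a new size-1 dimension at `axis`."""
    torch = sys.modules.get("torch")
    if torch is not None and isinstance(x, torch.Tensor):
        # torch calls this operation `unsqueeze` and the argument `dim`.
        return torch.unsqueeze(x, dim=axis)
    return np.expand_dims(x, axis=axis)
```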

enp.compat.inv

Like np.linalg.inv, but automatically supports jnp, tnp and np.
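The same pattern applies to the matrix inverse. In this hypothetical sketch, JAX arrays go to jnp.linalg.inv, torch tensors go to torch.linalg.inv, and everything else falls back to np.linalg.inv. Neither JAX nor torch is imported by the helper itself.

```python
import sys

import numpy as np


def inv(x):
    """Hypothetical sketch: matrix inverse for NumPy, JAX or torch inputs."""
    jnp = sys.modules.get("jax.numpy")
    if jnp is not None and isinstance(x, jnp.ndarray):
        return jnp.linalg.inv(x)
    torch = sys.modules.get("torch")
    if torch is not None and isinstance(x, torch.Tensor):
        return torch.linalg.inv(x)
    return np.linalg.inv(x)
```

A quick sanity check is that inv(a) @ a is close to the identity matrix.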

enp.compat.is_array_xnp

isinstance(x, xnp.Array).
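A minimal sketch of the check, assuming xnp is passed as the module object itself (np, torch, ...). The helper is a hypothetical reimplementation: it covers NumPy and torch only, and raises for any other module rather than guessing its array type.

```python
import sys

import numpy as np


def is_array_xnp(x, xnp):
    """Hypothetical sketch: return True if `x` is an array of backend `xnp`."""
    if xnp is np:
        return isinstance(x, np.ndarray)
    torch = sys.modules.get("torch")
    if torch is not None and xnp is torch:
        return isinstance(x, torch.Tensor)
    raise NotImplementedError(f"Unsupported array module: {xnp!r}")
```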

enp.compat.norm

Like np.linalg.norm, but automatically supports jnp, tnp and np.
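One incompatibility a norm wrapper has to absorb is keyword naming: NumPy uses axis and keepdims, whereas torch.linalg.norm uses dim and keepdim. The hypothetical norm below translates between the two and handles NumPy and torch only.

```python
import sys

import numpy as np


def norm(x, axis=None, keepdims=False):
    """Hypothetical sketch: 2-norm with NumPy-style keyword names."""
    torch = sys.modules.get("torch")
    if torch is not None and isinstance(x, torch.Tensor):
        # torch uses `dim` / `keepdim` instead of `axis` / `keepdims`.
        return torch.linalg.norm(x, dim=axis, keepdim=keepdims)
    return np.linalg.norm(x, axis=axis, keepdims=keepdims)
```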

enp.compat.round

x.round() for jnp, tnp, np and torch.
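A sketch of the fallback logic, assuming that NumPy, JAX and torch arrays all provide a .round() method while a plain tf.Tensor may not (TensorFlow exposes tf.math.round instead). The function is named round_ to avoid shadowing Python's built-in round; it is a hypothetical illustration, not the etils source.

```python
import sys

import numpy as np


def round_(x):
    """Hypothetical sketch: element-wise rounding across backends."""
    if hasattr(x, "round"):
        # NumPy, JAX and torch arrays all provide .round().
        return x.round()
    tf = sys.modules.get("tensorflow")
    if tf is not None and isinstance(x, tf.Tensor):
        return tf.math.round(x)
    raise TypeError(f"Unsupported array type: {type(x)!r}")
```

Note that NumPy rounds halves to the nearest even value, so 2.5 becomes 2.0.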

Attribute

enp.compat.lazy