enp.compat
Compatibility utilities for TF, Torch, NumPy, and JAX.
Currently, each NumPy-like API behaves slightly differently. These functions
smooth over those differences so that the same code works across all of them.
In the future, these functions could be deleted and replaced by the
official NumPy API.
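To illustrate the kind of difference these helpers hide, here is a minimal sketch of backend dispatch for concatenation. It is not the library's implementation: `get_xnp` and `concat_sketch` are hypothetical names introduced for this example, and detecting the backend from the array type's module name is an assumption made to keep the sketch self-contained. Only NumPy is required to run it; the JAX and Torch branches are imported lazily and are only reached if such an array is passed. TF is omitted for brevity.

```python
import numpy as np


def get_xnp(x):
    """Return the array module that owns `x` (hypothetical helper)."""
    # Assumption: the top-level package of the array's type identifies the backend.
    module = type(x).__module__.split(".")[0]
    if module == "numpy":
        return np
    if module in ("jax", "jaxlib"):
        import jax.numpy as jnp  # imported only when a JAX array is passed
        return jnp
    if module == "torch":
        import torch  # imported only when a torch tensor is passed
        return torch
    raise TypeError(f"Unsupported array type: {type(x)!r}")


def concat_sketch(arrays, axis=0):
    """Concatenate arrays using whichever backend produced them."""
    xnp = get_xnp(arrays[0])
    if xnp.__name__ == "torch":
        # Torch spells the operation `cat` and names the argument `dim`.
        return xnp.cat(arrays, dim=axis)
    return xnp.concatenate(arrays, axis=axis)


a = np.zeros((2, 3))
b = np.ones((1, 3))
out = concat_sketch([a, b], axis=0)
print(out.shape)
```

The caller writes one call and the backend-specific spelling stays inside the helper, which is the role the functions in this module play.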
Symbols
| Function | Description |
|---|---|
| enp.compat.astype | x.astype(dtype). |
| enp.compat.concat | xnp.concatenate(x, axis=axis). |
| enp.compat.det | Like np.linalg.det, but automatically supports jnp, tnp, np. |
| enp.compat.dtype_np_to_torch | Returns the torch dtype for the given numpy dtype. |
| enp.compat.dtype_torch_to_np | Returns the numpy dtype for the given torch dtype. |
| enp.compat.expand_dims | xnp.expand_dims(x, axis=axis). |
| enp.compat.inv | Like np.linalg.inv, but automatically supports jnp, tnp, np. |
| enp.compat.is_array_xnp | isinstance(x, xnp.Array). |
| enp.compat.norm | Like np.linalg.norm, but automatically supports jnp, tnp, np. |
| enp.compat.round | x.round() for jnp, tnp, np, torch. |