Numpy utils#

Code that works with np.array, jnp.ndarray, tf.Tensor#

Use enp.get_np_module(t) to write function which works with both tf, jax and numpy:

def my_function(array: Array):
  xnp = enp.get_np_module(array)
  return xnp.sum(array) + 1


my_function(tf.constant([1, 2]))  # Returns tf.Tensor
my_function(np.array([1, 2]))  # Returns np.ndarray
my_function(jnp.array([1, 2]))  # Returns jnp.ndarray

Runtime dtype/shape checking#

See array_types.

Interpolation util#

enp.interp linearly scale an array. API is: np.interp(array, from_=(min, max), to=(min, max))

  • Each dimension in the axis can be scaled by a different factor (broadcasting).

  • Values outside the boundaries are extrapolated.

  • Support np, jnp, tnp

Examples:

  • Normalize np.uint8 image to np.float32:

    img = enp.interp(img, (0, 255), (-1, 1))
    
  • Converting normalized 3d coordinates to world coordinates:

    coords = enp.interp(coords, from_=(-1, 1), to=(0, (h, w, d)))
    
    • coords[:, 0] is interpolated from (-1, 1) to (0, h)

    • coords[:, 1] is interpolated from (-1, 1) to (0, w)

    • coords[:, 2] is interpolated from (-1, 1) to (0, d)