flax_extra.util module

General utility functions.

flax_extra.util.batch_per_device(inputs: jax._src.numpy.lax_numpy.ndarray, n_devices: Optional[int] = None) → jax._src.numpy.lax_numpy.ndarray

Splits the leading (head) axis of an array evenly across devices.

The function changes the shape of the input array as follows:

(head, *tail) -> (n_devices, reduced_head // n_devices, *tail)

Parameters

  • inputs – an array.

  • n_devices – number of devices. Defaults to all available devices.

Returns

an array whose new leading axis, of size n_devices, indexes the devices.
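A minimal sketch of the reshape described above, written in JAX. `batch_per_device_sketch` is a hypothetical name, not the library's own code. It assumes that "all available devices" means `jax.local_device_count()` and that the leading axis is divisible by `n_devices` (it raises otherwise); the library may treat a remainder differently.

```python
from typing import Optional

import jax
import jax.numpy as jnp


def batch_per_device_sketch(
    inputs: jnp.ndarray, n_devices: Optional[int] = None
) -> jnp.ndarray:
    """Hypothetical sketch: (head, *tail) -> (n_devices, head // n_devices, *tail)."""
    if n_devices is None:
        # Assumption: "all available devices" = devices attached to this host.
        n_devices = jax.local_device_count()
    head, *tail = inputs.shape
    # Assumption: require an even split instead of dropping leftover items.
    if head % n_devices != 0:
        raise ValueError(
            f"leading axis of size {head} is not divisible by {n_devices} devices"
        )
    # Row-major reshape: device i receives items [i * per_device, (i + 1) * per_device).
    return jnp.reshape(inputs, (n_devices, head // n_devices, *tail))


x = jnp.arange(24).reshape(8, 3)
y = batch_per_device_sketch(x, n_devices=2)
print(y.shape)  # (2, 4, 3)
```

An array in this layout is what `jax.pmap` expects, since `pmap` maps a function over the leading axis, one slice per device.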

flax_extra.util.originate(tree: Any) → Any

Recovers the original pytree from a replicated pytree.

Assuming that each per-device replica of the replicated pytree is a copy of the original pytree, the function simply retrieves the first replica.

Parameters

tree – replicated pytree.

Returns

a regular pytree.
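A minimal sketch of the first-replica idea in JAX: index 0 is taken along the leading (device) axis of every leaf with `jax.tree_util.tree_map`. `originate_sketch` is a hypothetical name, and replication is simulated with `jnp.broadcast_to` rather than by placing copies on real devices.

```python
from typing import Any

import jax
import jax.numpy as jnp


def originate_sketch(tree: Any) -> Any:
    """Hypothetical sketch: keep the first replica of every leaf."""
    # Each leaf is assumed to carry a leading axis with one entry per device.
    return jax.tree_util.tree_map(lambda leaf: leaf[0], tree)


params = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}
n = jax.local_device_count()

# Simulated replication: add a leading axis of size n holding identical copies.
replicated = jax.tree_util.tree_map(lambda leaf: jnp.broadcast_to(leaf, (n, *leaf.shape)), params)

restored = originate_sketch(replicated)
print(restored["w"].shape)  # (3, 2)
```

Because every replica is assumed identical, picking index 0 is sufficient; Flax's `flax.jax_utils.unreplicate` follows the same first-replica approach.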