flax_extra.util module
General utility functions.
- flax_extra.util.batch_per_device(inputs: jax._src.numpy.lax_numpy.ndarray, n_devices: Optional[int] = None) -> jax._src.numpy.lax_numpy.ndarray
Splits the head axis of an array evenly across the number of devices.
The function changes input array shape as follows:
(head, *tail) -> (n_devices, reduced_head // n_devices, *tail)
- Parameters
inputs – an array.
n_devices – number of devices. Defaults to all available devices.
- Returns
an array whose leading axis is split into n_devices equal groups, one per device.
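
The shape change described above can be sketched as follows. This is a minimal illustration, not the library's source: `batch_per_device_sketch` is a hypothetical stand-in, it assumes that "all available devices" means `jax.local_device_count()`, and it assumes that `reduced_head` is the head axis truncated to the largest multiple of `n_devices`, so leftover items are dropped.

```python
import jax
import jax.numpy as jnp


def batch_per_device_sketch(inputs, n_devices=None):
    """Hypothetical stand-in for batch_per_device (not the library's code)."""
    # Assumption: the default is the number of local devices.
    n_devices = n_devices or jax.local_device_count()
    per_device = inputs.shape[0] // n_devices
    # Assumption: items that do not fill a whole per-device group are dropped.
    reduced_head = per_device * n_devices
    new_shape = (n_devices, per_device) + inputs.shape[1:]
    return inputs[:reduced_head].reshape(new_shape)


# A batch of 10 items with 3 features each, split across 4 devices.
batch = jnp.arange(10 * 3).reshape(10, 3)
sharded = batch_per_device_sketch(batch, n_devices=4)
print(sharded.shape)  # (4, 2, 3)
```

An array shaped this way can be passed to `jax.pmap`, which maps the leading axis over devices.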
- flax_extra.util.originate(tree: Any) -> Any
Recovers the original pytree from a replicated pytree.
Assuming that each replica of the per-device replicated pytree is a copy of the original pytree, the function simply returns the first replica.
- Parameters
tree – replicated pytree.
- Returns
a regular pytree.
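
As a rough illustration of the behaviour described for `originate`, the sketch below takes the first replica of every leaf. `originate_sketch` is a hypothetical stand-in rather than the library's implementation, and it assumes replication adds a leading device axis to each leaf.

```python
import jax
import jax.numpy as jnp


def originate_sketch(tree):
    """Hypothetical stand-in for originate (not the library's code)."""
    # Assumption: every leaf has a leading replica axis; keep replica 0.
    return jax.tree_util.tree_map(lambda leaf: leaf[0], tree)


params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}

# Build a replicated pytree by stacking 4 copies of each leaf.
n_replicas = 4
replicated = jax.tree_util.tree_map(
    lambda leaf: jnp.stack([leaf] * n_replicas), params
)

restored = originate_sketch(replicated)
print(restored["w"].shape)  # (2, 3)
```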