# Source code for flax_extra.util
r"""General utility functions."""
from typing import Any, Optional
import jax
from jax import numpy as jnp
# Type alias used in annotations for a single JAX array.
Array = jnp.ndarray
# Type alias for a pytree of arrays; deliberately left as `Any`, so it is
# not checked structurally by type checkers.
ArrayTree = Any
def originate(tree: ArrayTree) -> ArrayTree:
    r"""Originate replicated pytree object.

    Assuming that each replica of the (per device) replicated pytree is
    representing a copy of the origin pytree, the function simply retrieves
    the first replica of the replicated pytree, i.e. it indexes every leaf
    with ``[0]`` along its leading axis.

    Args:
        tree: replicated pytree. Every leaf must support indexing with ``0``.

    Returns:
        a regular pytree with the same structure as ``tree`` whose leaves are
        the first items of the corresponding input leaves.
    """
    # `jax.tree_util.tree_map` is the stable spelling of the tree mapping
    # function; the top-level `jax.tree_map` alias is deprecated and is not
    # available in newer JAX releases.
    return jax.tree_util.tree_map(lambda leaf: leaf[0], tree)
def batch_per_device(inputs: Array, n_devices: Optional[int] = None) -> Array:
    r"""Splits the head axis of an array evenly across the number of devices.

    The leading axis of the input is folded into two axes, so the array
    shape changes as follows:

    ::

        (head, *tail) -> (n_devices, head // n_devices, *tail)

    ``head`` is expected to be divisible by ``n_devices``; otherwise the
    target shape holds fewer elements than the input and the reshape cannot
    succeed.

    Args:
        inputs: an array with at least one axis.
        n_devices: number of devices. Defaults to the value reported by
            ``jax.local_device_count()``.

    Returns:
        an array with items of the leading axis mapped to the number of
        devices.
    """
    device_count = jax.local_device_count() if n_devices is None else n_devices
    leading, *rest = inputs.shape
    per_device = leading // device_count
    target_shape = (device_count, per_device, *rest)
    return jnp.reshape(inputs, target_shape)  # type: ignore