Source code for flax_extra.random
r"""Random number generation."""
from typing import List, Mapping, Generator, Optional, Union
import jax
from jax import random, numpy as jnp
from jax.random import KeyArray
Array = jnp.ndarray
KeyGenerator = Generator[KeyArray, None, None]
[docs]def split_per_device(key: KeyArray, n_devices: Optional[int] = None) -> KeyArray:
r"""Splits a random number generator key according to the number of devices."""
if n_devices is None:
n_devices = jax.local_device_count()
return jax.random.split(key, num=n_devices)
[docs]def into_collection(key: KeyArray, labels: List[str]) -> Mapping[str, KeyArray]:
r"""Splits a random number generator key into a few.
New keys are associated with provided collection labels.
Args:
labels: a collection of labels.
key: an initial random number generator key.
Returns:
a dictionary with random number generator keys as values
and collection labels as keys.
"""
keys = {labels[i]: k for i, k in enumerate(random.split(key=key, num=len(labels)))}
return keys
[docs]def into_sequence(key: KeyArray) -> KeyGenerator:
r"""Creates a generator of random number generator keys.
Args:
key: an initial generator key.
Yields:
a new random number generator key.
"""
initial_key = key
while True:
initial_key, key = random.split(key=initial_key, num=2)
yield key
[docs]def sequence(seed: int, num: int = 1) -> Union[KeyGenerator, List[KeyGenerator]]:
r"""Creates generators of random number generator keys.
.. code-block:: python
from flax_extra import random
rnkeyg = random.sequence(seed=0)
next(rnkeyg)
Args:
seed: a seed integer to create an initial generator.
num: the number of sequences to produce.
Yields:
a new random number generator key.
"""
initial_key = random.PRNGKey(seed=seed)
if num == 1:
return into_sequence(key=initial_key)
keys = random.split(key=initial_key, num=num)
return [into_sequence(key) for key in keys]