flax_extra.random module
Random number generation: utilities for creating and splitting JAX random number generator keys.
- flax_extra.random.into_collection(key: jax._src.prng.PRNGKeyArray, labels: List[str]) -> Mapping[str, jax._src.prng.PRNGKeyArray]
Splits a random number generator key into several new keys, one for each of the provided collection labels.
- Parameters
key – an initial random number generator key.
labels – a list of collection labels.
- Returns
a dictionary that maps each collection label to its new random number generator key.
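To make the splitting concrete, here is a minimal sketch of the documented behavior written directly against `jax.random`. It is not the flax_extra source: the helper name `into_collection_sketch` is hypothetical, and the assumption that the key is split once into `len(labels)` subkeys with `jax.random.split` is ours.

```python
from typing import Dict, List

import jax


def into_collection_sketch(key: jax.Array, labels: List[str]) -> Dict[str, jax.Array]:
    """Hypothetical re-implementation: one fresh subkey per collection label."""
    # Assumption: a single split produces as many subkeys as there are labels.
    subkeys = jax.random.split(key, len(labels))
    # Pair each label with its own subkey (iterating over the leading axis).
    return {label: subkey for label, subkey in zip(labels, subkeys)}


rngs = into_collection_sketch(jax.random.PRNGKey(0), ["params", "dropout"])
print(sorted(rngs))
```

A label-to-key mapping of this shape is what Flax's `Module.init` and `Module.apply` accept through their `rngs` argument, for example one key for `"params"` and another for `"dropout"`.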
- flax_extra.random.into_sequence(key: jax._src.prng.PRNGKeyArray) -> Generator[jax._src.prng.PRNGKeyArray, None, None]
Creates a generator of random number generator keys.
- Parameters
key – an initial random number generator key.
- Yields
a new random number generator key.
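A common way to build such a generator is to split the carried key on every step, hand out one half, and keep the other as state. The sketch below assumes that strategy; `into_sequence_sketch` is a hypothetical name, and the real flax_extra implementation may derive its keys differently.

```python
from typing import Generator

import jax


def into_sequence_sketch(key: jax.Array) -> Generator[jax.Array, None, None]:
    """Hypothetical re-implementation: an endless stream of fresh keys."""
    while True:
        # Split off a subkey to yield and carry the other half forward,
        # so the same key is never handed out twice.
        key, subkey = jax.random.split(key)
        yield subkey


keys = into_sequence_sketch(jax.random.PRNGKey(42))
noise = jax.random.normal(next(keys), (3,))
mask = jax.random.bernoulli(next(keys), 0.5, (3,))
```

Calling `next` once per random operation keeps every draw on its own key without threading `split` calls through the surrounding code by hand.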
- flax_extra.random.sequence(seed: int, num: int = 1) -> Union[Generator[jax._src.prng.PRNGKeyArray, None, None], List[Generator[jax._src.prng.PRNGKeyArray, None, None]]]
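Creates one or more generators of random number generator keys from an integer seed. Depending on num, the result is either a single generator or a list of generators.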
    from flax_extra import random
    rnkeyg = random.sequence(seed=0)
    next(rnkeyg)
- Parameters
seed – an integer seed used to create the initial generator.
num – the number of sequences to produce.
- Returns
a generator of random number generator keys, or a list of num such generators.
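The union return type suggests that a single generator comes back when num is 1 and a list otherwise. The sketch below encodes that reading as an assumption; `sequence_sketch` and `_key_stream` are hypothetical names, and deriving one starting key per sequence with `jax.random.split` is our choice rather than something the documentation states.

```python
from typing import Generator, List, Union

import jax

KeyGen = Generator[jax.Array, None, None]


def _key_stream(key: jax.Array) -> KeyGen:
    # Endless stream: split, yield one half, carry the other.
    while True:
        key, subkey = jax.random.split(key)
        yield subkey


def sequence_sketch(seed: int, num: int = 1) -> Union[KeyGen, List[KeyGen]]:
    """Hypothetical re-implementation of the documented behavior."""
    key = jax.random.PRNGKey(seed)
    if num == 1:
        # Assumption: a bare generator when only one sequence is requested.
        return _key_stream(key)
    # Otherwise give every sequence its own independent starting key.
    return [_key_stream(k) for k in jax.random.split(key, num)]


rnkeyg = sequence_sketch(seed=0)
first = next(rnkeyg)
train_keys, eval_keys = sequence_sketch(seed=0, num=2)
```

Because everything is derived from the seed, rerunning with the same seed reproduces the same keys in the same order.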