flax_extra.random module

Random number generation.

flax_extra.random.into_collection(key: jax._src.prng.PRNGKeyArray, labels: List[str]) → Mapping[str, jax._src.prng.PRNGKeyArray]

Splits a random number generator key into several new keys, one for each provided collection label.

Parameters
  • labels – a collection of labels.

  • key – an initial random number generator key.

Returns

a dictionary that maps each collection label to its own random number generator key.
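
The behavior described above can be approximated with plain JAX. The following is a minimal sketch, not the library's implementation: the helper name `into_collection_sketch` is hypothetical, and it assumes the split is equivalent to `jax.random.split` with one subkey per label.

```python
import jax


def into_collection_sketch(key, labels):
    """Hypothetical stand-in: derive one subkey per label (an assumption)."""
    # Split the initial key into as many subkeys as there are labels.
    subkeys = jax.random.split(key, num=len(labels))
    # Pair each label with its subkey.
    return dict(zip(labels, subkeys))


rngs = into_collection_sketch(jax.random.PRNGKey(0), ["params", "dropout"])
```

A dictionary of this shape is what Flax modules accept through the `rngs` argument of `init` and `apply`.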

flax_extra.random.into_sequence(key: jax._src.prng.PRNGKeyArray) → Generator[jax._src.prng.PRNGKeyArray, None, None]

Creates a generator of random number generator keys.

Parameters

key – an initial random number generator key.

Yields

a new random number generator key.
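
One common way to build such a generator is to split the key on every step, keep one half as the carried state, and yield the other. The sketch below illustrates that pattern under the assumption that `into_sequence` works along these lines; `into_sequence_sketch` is a hypothetical name and not part of the library.

```python
import jax


def into_sequence_sketch(key):
    """Hypothetical stand-in: yield an endless stream of fresh subkeys."""
    while True:
        # Keep one half as the next state and hand out the other half.
        key, subkey = jax.random.split(key)
        yield subkey


keys = into_sequence_sketch(jax.random.PRNGKey(0))
first = next(keys)
second = next(keys)
```

Because the carried key is replaced on every step, successive calls to `next` return different keys without the caller having to manage the split manually.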

flax_extra.random.sequence(seed: int, num: int = 1) → Union[Generator[jax._src.prng.PRNGKeyArray, None, None], List[Generator[jax._src.prng.PRNGKeyArray, None, None]]]

Creates one or more generators of random number generator keys from an integer seed.

from flax_extra import random

# Create a generator of keys from an integer seed.
rnkeyg = random.sequence(seed=0)
# Draw a new key from the generator.
next(rnkeyg)

Parameters
  • seed – an integer seed used to create the initial random number generator key.

  • num – the number of sequences to produce.

Yields

a new random number generator key.

flax_extra.random.split_per_device(key: jax._src.prng.PRNGKeyArray, n_devices: Optional[int] = None) → jax._src.prng.PRNGKeyArray

Splits a random number generator key according to the number of devices.
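
A per-device split can be sketched with plain JAX as shown below. This is an illustration under stated assumptions, not the library's code: `split_per_device_sketch` is a hypothetical name, and falling back to `jax.local_device_count()` when `n_devices` is `None` is an assumption about the default.

```python
import jax


def split_per_device_sketch(key, n_devices=None):
    """Hypothetical stand-in: return one subkey per device, stacked."""
    if n_devices is None:
        # Assumed default: the number of devices local to this process.
        n_devices = jax.local_device_count()
    # The result has a leading axis of length n_devices.
    return jax.random.split(key, num=n_devices)


device_keys = split_per_device_sketch(jax.random.PRNGKey(0), n_devices=4)
```

The leading axis of the result lines up with the device axis expected by `jax.pmap`, so each device receives its own key.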