flax_extra.operator package

An operator is just a function.

class flax_extra.operator.Rearrange(pattern: str, bindings: Mapping[str, int] = <factory>)

Bases: object

Rearranges the shape of an array according to the pattern.

>>> from jax import numpy as jnp
>>> from flax_extra import operator as op
>>> rearrange = op.Rearrange(pattern="(a da) db -> a (da db)", bindings=dict(da=2))
>>> rearrange(jnp.ones((4,3))).shape
(2, 6)

bindings: Mapping[str, int]

Integer sizes bound to dimension names used in the pattern (for example, da=2).

pattern: str

The rearrangement pattern, written as "input dimensions -> output dimensions".
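
For readers unfamiliar with the pattern notation, the example above can be spelled out with plain reshapes. This is only an illustrative sketch of what the pattern "(a da) db -> a (da db)" means for a 2-D input, not the library's implementation; the helper name rearrange_sketch is hypothetical.

```python
import jax.numpy as jnp


def rearrange_sketch(x, da):
    """Hypothetical helper: emulates "(a da) db -> a (da db)" for a 2-D array."""
    # "(a da) db": split the leading axis into (a, da); `a` is inferred from the shape.
    a = x.shape[0] // da
    db = x.shape[1]
    split = x.reshape(a, da, db)
    # "a (da db)": merge the last two axes into one.
    return split.reshape(a, da * db)


out = rearrange_sketch(jnp.ones((4, 3)), da=2)
print(out.shape)  # (2, 6)
```

Only da needs a binding here: the size of a follows from the input shape once da is known.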

class flax_extra.operator.ReshapeBatch(shape: tuple[int, ...])

Bases: object

Changes the shape of an array preserving its batch dimension.

>>> from jax import numpy as jnp
>>> from flax_extra import operator as op
>>> reshape = op.ReshapeBatch(shape=(1, 6))
>>> reshape(jnp.ones((2, 2, 3))).shape
(2, 1, 6)

shape: tuple[int, ...]

The target array shape, excluding the batch dimension.
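
The behaviour shown in the example can be approximated with a single jnp.reshape call. The sketch below assumes the batch dimension is the leading axis, which matches the example but is not stated explicitly in this page; reshape_batch_sketch is a hypothetical name, not part of flax_extra.

```python
import jax.numpy as jnp


def reshape_batch_sketch(x, shape):
    """Hypothetical helper: reshape every axis except the leading (batch) one."""
    # Assumption: axis 0 is the batch axis and is kept unchanged.
    batch_size = x.shape[0]
    return jnp.reshape(x, (batch_size, *shape))


out = reshape_batch_sketch(jnp.ones((2, 2, 3)), shape=(1, 6))
print(out.shape)  # (2, 1, 6)
```

The product of the target shape must equal the number of elements per batch item (here 2 * 3 = 1 * 6), otherwise the reshape raises an error.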

class flax_extra.operator.ShiftRight(n_positions: int = 1, pad_id: int = 0, axis: int = 0)

Bases: object

Shifts a sequence to the right by inserting padding at its start. The length is preserved, so trailing elements are dropped.

>>> from jax import numpy as jnp
>>> from flax_extra import operator as op
>>> shift_right = op.ShiftRight(axis=0, n_positions=1)
>>> shift_right(jnp.array([1, 2, 3]))
DeviceArray([0, 1, 2], dtype=int32)

axis: int = 0

The axis along which the shift is performed.

n_positions: int = 1

The number of positions to shift.

pad_id: int = 0

The padding ID to insert into the vacated positions.
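
To make the roles of the three parameters concrete, the shift can be sketched as "pad at the front, then trim back to the original length". This is one way to reproduce the example output under that assumption, not necessarily how ShiftRight is implemented; shift_right_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def shift_right_sketch(x, n_positions=1, pad_id=0, axis=0):
    """Hypothetical helper: pad at the start of `axis`, then trim to the original length."""
    axis = axis % x.ndim  # normalize negative axes
    # Pad only the chosen axis, and only on its leading side.
    pad_widths = [(0, 0)] * x.ndim
    pad_widths[axis] = (n_positions, 0)
    padded = jnp.pad(x, pad_widths, mode="constant", constant_values=pad_id)
    # Drop the trailing elements so the output shape matches the input shape.
    return jax.lax.slice_in_dim(padded, 0, x.shape[axis], axis=axis)


print(shift_right_sketch(jnp.array([1, 2, 3])).tolist())  # [0, 1, 2]
```

With a 2-D input and axis=1, each row is shifted independently and the leading positions of every row are filled with pad_id.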