flax_extra.operator package
An operator is just a function: each class below is instantiated with its parameters and then called on an array.
- class flax_extra.operator.Rearrange(pattern: str, bindings: Mapping[str, int] = <factory>)
Bases: object
Rearranges the shape of an array according to the pattern.
>>> from jax import numpy as jnp
>>> from flax_extra import operator as op
>>> rearrange = op.Rearrange(pattern="(a da) db -> a (da db)", bindings=dict(da=2))
>>> rearrange(jnp.ones((4, 3))).shape
(2, 6)
- bindings: Mapping[str, int]
sizes of named dimensions in the pattern, keyed by dimension name (for example, da=2).
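- pattern: str
a rearrangement pattern of the form "input axes -> output axes"; parentheses group axes that are split (on the input side) or merged (on the output side).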
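To make the example pattern concrete, here is a minimal sketch in plain jax.numpy of what "(a da) db -> a (da db)" amounts to, plus a variant "(a da) db -> a (db da)" that merges the same axes in the opposite order. The helpers rearrange_merge and rearrange_merge_swapped are hypothetical and hard-code these two patterns; this is not how Rearrange is implemented, only an equivalent computation.

```python
import jax.numpy as jnp


def rearrange_merge(x, da):
    """Hypothetical equivalent of "(a da) db -> a (da db)" for a 2-D input."""
    n, db = x.shape
    a = n // da  # `a` is inferred from the input; only `da` needs a binding
    # Split axis 0 into (a, da), then merge (da, db) in that order.
    return x.reshape(a, da, db).reshape(a, da * db)


def rearrange_merge_swapped(x, da):
    """Hypothetical equivalent of "(a da) db -> a (db da)" for a 2-D input."""
    n, db = x.shape
    a = n // da
    # Same split, but db precedes da in the merge, so swap them before flattening.
    return x.reshape(a, da, db).transpose(0, 2, 1).reshape(a, db * da)


x = jnp.arange(12).reshape(4, 3)
print(rearrange_merge(x, da=2).shape)                # (2, 6)
print(rearrange_merge(x, da=2)[0].tolist())          # [0, 1, 2, 3, 4, 5]
print(rearrange_merge_swapped(x, da=2)[0].tolist())  # [0, 3, 1, 4, 2, 5]
```

Both variants produce the same output shape; only the element order differs, which is why the order of names inside the output parentheses matters.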
- class flax_extra.operator.ReshapeBatch(shape: tuple[int, ...])
Bases: object
Changes the shape of an array while preserving its leading (batch) dimension.
>>> from jax import numpy as jnp
>>> from flax_extra import operator as op
>>> reshape = op.ReshapeBatch(shape=(1, 6))
>>> reshape(jnp.ones((2, 2, 3))).shape
(2, 1, 6)
- shape: tuple[int, ...]
the target shape of each example, excluding the batch dimension.
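A minimal sketch of the same behavior using jnp.reshape directly, assuming (as in the example above) that the batch dimension is the leading axis. The function reshape_batch is hypothetical and not part of flax_extra.

```python
import jax.numpy as jnp


def reshape_batch(x, shape):
    """Hypothetical sketch: reshapes every axis except the leading (batch) axis."""
    # The batch size is read from the input, so only the per-example shape is given.
    # The product of `shape` must equal the product of x.shape[1:]; one -1 is allowed.
    return jnp.reshape(x, (x.shape[0], *shape))


x = jnp.ones((2, 2, 3))
print(reshape_batch(x, (1, 6)).shape)  # (2, 1, 6)
print(reshape_batch(x, (-1,)).shape)   # (2, 6)
```

Taking the batch size from the input at call time means the same operator works for any batch size, which is presumably why the shape parameter excludes it.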
- class flax_extra.operator.ShiftRight(n_positions: int = 1, pad_id: int = 0, axis: int = 0)
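Bases: object
Shifts a sequence to the right by inserting padding at its start; trailing elements are dropped so that the length along the axis is unchanged.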
>>> from jax import numpy as jnp
>>> from flax_extra import operator as op
>>> shift_right = op.ShiftRight(axis=0, n_positions=1)
>>> shift_right(jnp.array([1, 2, 3]))
DeviceArray([0, 1, 2], dtype=int32)
- axis: int = 0
the axis along which the sequence is shifted.
- n_positions: int = 1
the number of positions to shift by.
- pad_id: int = 0
the padding value (for example, a padding token ID) inserted at the vacated positions.
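Assuming ShiftRight keeps the length along axis unchanged (as the example [1, 2, 3] -> [0, 1, 2] suggests), an equivalent computation prepends n_positions copies of pad_id along axis with jnp.pad and then truncates back to the original length. The function shift_right below is a hypothetical sketch, not the library's implementation.

```python
import jax.numpy as jnp


def shift_right(x, n_positions=1, pad_id=0, axis=0):
    """Hypothetical sketch: shifts x right along `axis`, filling with `pad_id`."""
    # Prepend n_positions copies of pad_id along `axis`; no padding on other axes.
    pad_width = [(0, 0)] * x.ndim
    pad_width[axis] = (n_positions, 0)
    padded = jnp.pad(x, pad_width, constant_values=pad_id)
    # Truncate back to the original length, dropping the last n_positions elements.
    index = [slice(None)] * x.ndim
    index[axis] = slice(0, x.shape[axis])
    return padded[tuple(index)]


print(shift_right(jnp.array([1, 2, 3])).tolist())  # [0, 1, 2]
batch = jnp.array([[1, 2, 3], [4, 5, 6]])
print(shift_right(batch, n_positions=2, pad_id=9, axis=1).tolist())
# [[9, 9, 1], [9, 9, 4]]
```

Because both the padding list and the slice list are indexed with axis, negative values such as axis=-1 work as well.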