# Source code for flax_extra.operator._reshape

r"""The array reshaping operator."""

from dataclasses import dataclass
from jax import numpy as jnp
from flax import linen as nn

Array = jnp.ndarray


@dataclass
class ReshapeBatch:
    r"""Reshapes an array while keeping its leading (batch) axis intact.

    The size of the first axis of the input is carried over unchanged and
    the remaining axes are rearranged into ``shape``.

    >>> from jax import numpy as jnp
    >>> from flax_extra import operator as op
    >>> reshape = op.ReshapeBatch(shape=(1, 6))
    >>> reshape(jnp.ones((2, 2, 3))).shape
    (2, 1, 6)
    """

    shape: tuple[int, ...]
    r"""the target shape of the array, not counting its batch dimension."""

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        r"""Reshapes ``inputs`` to ``(inputs.shape[0], *self.shape)``.

        Args:
            inputs: an array whose first axis is treated as the batch axis.

        Returns:
            the reshaped array with the first axis size left as it was.
        """
        # The leading axis size is taken from the input itself, so the
        # operator works for any batch size.
        target_shape = (inputs.shape[0], *self.shape)
        return jnp.reshape(inputs, target_shape)  # type: ignore