# Source code for flax_extra.operator._rearrange
r"""The array rearranging operator."""
from dataclasses import dataclass, field
from typing import Mapping
from jax import numpy as jnp
from flax import linen as nn
import einops
# Alias for the JAX array type, used in the annotations below.
Array = jnp.ndarray
@dataclass
class Rearrange:
    r"""Rearranges the shape of an array according to the pattern.

    An instance is a callable that stores a pattern together with optional
    bindings and forwards both to ``einops.rearrange`` on every call.

    >>> from jax import numpy as jnp
    >>> from flax_extra import operator as op
    >>> rearrange = op.Rearrange(pattern="(a da) db -> a (da db)", bindings=dict(da=2))
    >>> rearrange(jnp.ones((4,3))).shape
    (2, 6)
    """

    pattern: str
    r"""a rearrangement pattern."""

    bindings: Mapping[str, int] = field(default_factory=dict)
    r"""bindings for dimensions specified in the pattern."""

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        r"""Applies the stored rearrangement to the inputs.

        Args:
            inputs: an array to rearrange.

        Returns:
            the result of ``einops.rearrange`` called on ``inputs`` with the
            stored pattern and with the bindings as keyword arguments.
        """
        # Each binding is passed to einops as a separate keyword argument.
        return einops.rearrange(inputs, pattern=self.pattern, **self.bindings)  # type: ignore