# Source code for flax_extra.operator._shift_right
r"""The shift-right operator."""
from typing import List
from dataclasses import dataclass
from jax import numpy as jnp
# Short alias for the JAX NumPy array type, used in the annotations of this module.
Array = jnp.ndarray
@dataclass
class ShiftRight:
    r"""Inserts padding to shift the sequence.

    ``n_positions`` copies of ``pad_id`` are inserted at the beginning of
    ``axis`` and the same number of trailing elements is dropped, so the
    result has the same shape as the input.

    >>> from jax import numpy as jnp
    >>> from flax_extra import operator as op
    >>> shift_right = op.ShiftRight(axis=0, n_positions=1)
    >>> shift_right(jnp.array([1, 2, 3]))
    DeviceArray([0, 1, 2], dtype=int32)
    """

    n_positions: int = 1
    r"""a number of positions to shift."""

    pad_id: int = 0
    r"""a padding identifier to insert."""

    axis: int = 0
    r"""the operation will be performed along this axis."""

    def __call__(self, inputs: Array) -> Array:
        r"""Shifts ``inputs`` to the right along ``axis``.

        Args:
            inputs: an array with at least one dimension.

        Returns:
            an array of the same shape as ``inputs`` whose first
            ``n_positions`` entries along ``axis`` equal ``pad_id`` and
            whose remaining entries are the leading entries of ``inputs``.

        Raises:
            IndexError: if ``axis`` is out of range for ``inputs``.
        """
        length = inputs.shape[self.axis]
        n_dims = len(inputs.shape)

        # Pad only at the front of the shifted axis; all other axes are
        # left untouched.
        pad_width = [(0, 0)] * n_dims
        pad_width[self.axis] = (self.n_positions, 0)
        padded = jnp.pad(
            inputs,
            pad_width=pad_width,
            mode="constant",
            constant_values=self.pad_id,
        )

        # Keep the first `length` entries along the shifted axis. A static
        # slice is used instead of gathering with an index array, since the
        # kept range is always the contiguous prefix.
        window = [slice(None)] * n_dims
        window[self.axis] = slice(0, length)
        return padded[tuple(window)]