Source code for flax_extra.layer._fourier_positional_encoding

r"""Fourier positional encoding."""

from typing import Any, Callable, List, Sequence, Optional, Union
from functools import reduce
import operator
import jax.numpy as jnp
from jax.core import NamedShape
from jax.random import KeyArray
from flax import linen as nn

Array = jnp.ndarray
Shape = Optional[Union[Sequence[int], NamedShape]]
Dtype = Any
Positions = List[int]
InitFn = Callable[[KeyArray, Shape, Dtype], Array]


def _build_fourier_encodings(
    spatial_encodings: Array,
    seqshape: tuple[int, ...],
    n_bands: int,
    use_sine_only: bool = False,
) -> Array:
    r"""Generates Fourier positional encoding vectors with linear spacing.

    The function maps spatial coordinates to the surface of a higher dimensional
    hypersphere with a set of sinusoids.

    Args:
        spatial_encodings: an unbatched vectors representing multi-dimensional
            spatial positions, :math:`\in \sR^{\nSeqLen \times d}`.
        seqshape: a shape of the sequence.
        n_bands: a number of frequency bands to use.
        use_sine_only: whether to use a single phase (i.e. sine only) or two phase
            (i.e. sine and cosine) for each frequency band.

    Returns:
        an unbatched vectors of Fourier features,
            :math:`\sR^{\nSeqLen \times d}`.
    """
    seqlen, d_spatial_encoding = spatial_encodings.shape

    min_frequency = 1.0
    frequency_bands = jnp.stack(
        [
            jnp.linspace(min_frequency, seqlen / 2.0, num=n_bands, endpoint=True)
            for seqlen in seqshape
        ],
        axis=0,
    )

    scaled_frequency_bands = spatial_encodings[:, :, None] * frequency_bands[None, :, :]
    d_encoding = reduce(operator.mul, scaled_frequency_bands.shape[1:])
    frequencies = jnp.reshape(scaled_frequency_bands, (-1, d_encoding))
    assert (
        d_spatial_encoding * n_bands == d_encoding
    ), f"An invalid dimention of the Fourier positional encoding {d_encoding}."

    if use_sine_only:
        encodings = jnp.sin(jnp.pi * frequencies)
        assert encodings.shape == (
            seqlen,
            d_spatial_encoding * n_bands,
        ), f"An invalid shape of the sine-only Fourier positional encodings {encodings.shape}."
    else:
        encodings = jnp.concatenate(
            [
                jnp.sin(jnp.pi * frequencies),
                jnp.cos(jnp.pi * frequencies),
            ],
            axis=-1,
        )
        assert encodings.shape == (
            seqlen,
            2 * d_spatial_encoding * n_bands,
        ), f"An invalid shape of the sine-cosine Fourier positional encodings {encodings.shape}."

    return encodings  # type: ignore


def _build_spatial_positional_encodings(
    seqshape: tuple[Any, ...],
    bounds: tuple[float, float],
) -> Array:
    r"""Generates encoding vectors for multi-dimensional spatial positions.

    A positional encoding vector at each position may be seen as a
    multi-dimensional representation of the position across all spatial
    dimensions (or similarly, as a coordinate in the
    :math:`[\textrm{lower_bound} \times \textrm{upper_bound}]^{d}` space).

    Args:
        seqshape: a shape of the sequence.
        lower_bound: a minimal value for a position.
        upper_bound: a maximum value for a position.

    Returns:
        an unbatched vectors representing multi-dimensional spatial positions,
            :math:`\sR^{\nSeqLen \times d}`.
    """
    lower_bound, upper_bound = bounds
    points = [
        jnp.linspace(
            lower_bound,
            upper_bound,
            num=seqlen,
            endpoint=True,
            dtype=jnp.float32,
        )
        for seqlen in seqshape
    ]
    coordinate_grids = jnp.meshgrid(*points, indexing="ij")
    coordinates = jnp.stack(coordinate_grids, axis=-1)
    assert coordinates.shape == (*seqshape, len(seqshape)), (
        f"An invalid shape of the coordinate grid {coordinates.shape} "
        "for spatial positional encoding."
    )

    seqlen = reduce(operator.mul, seqshape)
    coordinates = jnp.reshape(coordinates, (seqlen, -1))
    return coordinates  # type: ignore


def _scale_spatial_positional_coordinates(
    coordinates: Array,
    seqshape: tuple[int, ...],
    bounds: tuple[float, float],
) -> Array:
    lower_bound, upper_bound = bounds
    distance = abs(lower_bound) + abs(upper_bound)
    return lower_bound + coordinates * distance / jnp.array(seqshape)  # type: ignore


def _to_spatial_positional_encodings(
    output_positions: Positions,
    seqshape: tuple[int, ...],
    bounds: tuple[float, float],
) -> Array:
    return _scale_spatial_positional_coordinates(
        jnp.stack(jnp.unravel_index(output_positions, shape=seqshape), axis=1),
        seqshape=seqshape,
        bounds=bounds,
    )


[docs]class FourierPositionEncoding(nn.Module): r"""Computes a vector of Fourier (or sinusoidal) features that may be seen as a multi-dimensional representation of the position across all its spatial dimensions (or similarly, as a single encoding of different time axes). .. math:: \begin{aligned} & \textrm{FourierPositionEncoding}( \\ & \quad m \in \sN \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times T \times d} \end{aligned} \begin{aligned} & \textrm{FourierPositionEncoding}( \\ & \quad m \in \sN \\ & \quad t \in \sN^{\nSeqLen^{\prime}} \\ & ) \\ & \rightarrow \sR^{\nBatchSize \times T \times T^{\prime} \times d} \end{aligned} Output dimension may vary depending on parameters of the module: - For a single phase (i.e. sine only), `use_sine_only=True`. - With spatial coordinates get concatenated: :math:`d = n_{bands} \cdot len(seqshape) + len(seqshape)` - Otherwise: :math:`d = n_{bands} \cdot len(seqshape)` - For two phase (i.e. sine and cosine) for each frequency band. - With spatial coordinates get concatenated: :math:`d = n_{bands} \cdot len(seqshape) \cdot 2 + len(seqshape)` - Otherwise: :math:`d = n_{bands} \cdot len(seqshape) \cdot 2` Args: batch_size: a batch size of the sequence. output_positions: a subset of positions (i.e. time steps) within the sequence positional encoding will be calculated. Returns: positional encoding vectors. """ seqshape: tuple[int, ...] r"""a shape of the sequence.""" n_bands: int r"""a number of frequency bands to use.""" spatial_coordinate_bounds: tuple[float, float] = (-1.0, 1.0) r"""lower and upper bound for spatial coordinates.""" use_spatial_coordinates: bool = True r"""whether to concatenate the spatial coordinates to the Fourier features.""" use_sine_only: bool = False r"""whether to use a single phase (i.e. sine only) or two phase (i.e. sine and cosine) for each frequency band.""" @nn.compact def __call__( # type: ignore[override] # pylint: disable=arguments-differ self, batch_size: int, output_positions: Optional[Positions] = None, ) -> Array: if output_positions is None: spatial_encodings = _build_spatial_positional_encodings( seqshape=self.seqshape, bounds=self.spatial_coordinate_bounds, ) else: spatial_encodings = _to_spatial_positional_encodings( output_positions=output_positions, seqshape=self.seqshape, bounds=self.spatial_coordinate_bounds, ) fourier_encodings = _build_fourier_encodings( spatial_encodings, n_bands=self.n_bands, seqshape=self.seqshape, use_sine_only=self.use_sine_only, ) if self.use_spatial_coordinates: fourier_encodings = jnp.concatenate( [spatial_encodings, fourier_encodings], axis=-1, ) fourier_encodings = jnp.broadcast_to( fourier_encodings[None], (batch_size, *fourier_encodings.shape), ) return fourier_encodings