Source code for flax_extra.combinator._concatenate
r"""The concatenate combinator."""
from typing import Callable, cast
from redex.function import Signature
from redex.stack import stackmethod, verify_stack_size, Stack
from redex.combinator._base import Combinator
from jax import numpy as jnp
Array = jnp.ndarray
# pylint: disable=too-few-public-methods
[docs]class Concatenate(Combinator):
r"""The concatenate combinator."""
axis: int
r"""an axis on which to concatenate arrays."""
@stackmethod
def __call__(self, stack: Stack) -> Stack:
verify_stack_size(self, stack, self.signature)
n_in = self.signature.n_in
concatenated = jnp.concatenate(stack[:n_in], self.axis)
return (concatenated, *stack[self.signature.n_in :])
[docs]def concatenate(axis: int = -1, n_in: int = 2) -> Callable[..., Array]:
r"""Creates a concatenate combinator.
Concatenates input arrays on desired axis.
>>> from jax import numpy as jnp
>>> from flax_extra import combinator as cb
>>> concatenate = cb.concatenate()
>>> concatenate(jnp.array([1, 2]), jnp.array([3]))
DeviceArray([1, 2, 3], dtype=int32)
Args:
n_in: a number of inputs.
Returns:
a combinator.
"""
return cast(
Callable[..., Array],
Concatenate(axis=axis, signature=Signature(n_in=n_in, n_out=1)),
)