flax_extra.combinator package
Combinator functions compose other functions.
- class flax_extra.combinator.Combinator(signature: redex.function.Signature)
Bases: redex.function.FineCallable
The base class for combinators.
- signature: redex.function.Signature
a signature of the function.
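The `signature` attribute records how many inputs a combinator consumes and how many outputs it produces. The sketch below is a simplified, hypothetical model of that convention, not redex's implementation. It assumes the inputs form a stack whose top is the first argument, that a function pops `n_in` items and pushes its outputs, and that unconsumed items pass through unchanged (as the `drop` and `select` doctests suggest). `SimpleSignature` and `apply_on_stack` are illustrative names, not part of flax_extra or redex.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass(frozen=True)
class SimpleSignature:
    """Hypothetical stand-in for redex.function.Signature (only n_in/n_out)."""

    n_in: int
    n_out: int


def apply_on_stack(
    func: Callable[..., Any], sig: SimpleSignature, stack: Tuple[Any, ...]
) -> Tuple[Any, ...]:
    """Pop sig.n_in items from the top of the stack and push func's outputs."""
    if len(stack) < sig.n_in:
        raise ValueError(f"expected at least {sig.n_in} inputs, got {len(stack)}")
    args, rest = stack[: sig.n_in], stack[sig.n_in :]
    outputs = func(*args)
    # Assumption: a function with one output returns a bare value, not a tuple.
    if sig.n_out == 1:
        outputs = (outputs,)
    # Outputs go on top; unconsumed items pass through unchanged.
    return tuple(outputs) + rest


print(apply_on_stack(lambda a, b: a + b, SimpleSignature(n_in=2, n_out=1), (1, 2, 3)))  # (3, 3)
```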
- class flax_extra.combinator.Concatenate(signature: redex.function.Signature, axis: int)
Bases: redex.combinator._base.Combinator
The concatenate combinator.
- axis: int
an axis on which to concatenate arrays.
- class flax_extra.combinator.Drop(signature: redex.function.Signature)
Bases: redex.combinator._base.Combinator
The drop combinator.
- signature: redex.function.Signature
a signature of the function.
- class flax_extra.combinator.Dup(signature: redex.function.Signature)
Bases: redex.combinator._base.Combinator
The duplicate combinator.
- signature: redex.function.Signature
a signature of the function.
- class flax_extra.combinator.Foldl(signature: redex.function.Signature, func: Callable[[Any, Any], Any])
Bases: redex.combinator._base.Combinator
The left folding combinator.
- func: Callable[[Any, Any], Any]
a function of two arguments.
- class flax_extra.combinator.Identity(signature: redex.function.Signature)
Bases: redex.combinator._base.Combinator
The identity combinator.
- signature: redex.function.Signature
a signature of the function.
- class flax_extra.combinator.Parallel(signature: redex.function.Signature, children: List[Callable[[...], Any]], children_signatures: List[redex.function.Signature])
Bases: redex.combinator._base.Combinator
The parallel combinator.
- children: List[Callable[[...], Any]]
the functions being combined.
- children_signatures: List[redex.function.Signature]
signatures of the child functions.
- class flax_extra.combinator.Select(signature: redex.function.Signature, indices: List[int])
Bases: redex.combinator._base.Combinator
The select combinator.
- indices: List[int]
a sequence of 0-based indices relative to the top of the stack.
- class flax_extra.combinator.Serial(signature: redex.function.Signature, children: List[Callable[[...], Any]], children_signatures: List[redex.function.Signature])
Bases: redex.combinator._base.Combinator
The serial combinator.
- children: List[Callable[[...], Any]]
the functions being combined.
- children_signatures: List[redex.function.Signature]
signatures of the child functions.
- flax_extra.combinator.add(n_in: int = 2) → redex.combinator._fold.Foldl
Creates an addition combinator.
>>> from redex import combinator as cb
>>> add = cb.add()
>>> add(4, 2) == 6
True
- Parameters
n_in – a number of inputs.
- Returns
a combinator.
- flax_extra.combinator.branch(*children: Union[Callable[[...], Any], Iterable[Callable[[...], Any]]]) → redex.combinator._serial.Serial
Creates a branch combinator.
The combinator combines multiple branches of the given functions, and each branch operates on its own copy of the inputs. Each branch is a sequence of functions applied serially.
>>> import operator as op
>>> from redex import combinator as cb
>>> branch = cb.branch(cb.serial(op.add, op.add), op.add)
>>> branch(1, 2, 3) == (1 + 2 + 3, 1 + 2)
True
- Parameters
children – a sequence of functions.
- Returns
a combinator.
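The key property of a branch is that every branch reads a copy of the same inputs instead of consuming them from a shared stack. A minimal sketch of that idea follows, assuming each branch is given as a hypothetical `(function, n_in)` pair and ignoring how extra inputs pass through; `simple_branch` and `add3` are illustrative names, not redex API.

```python
import operator as op
from typing import Any, Callable, Tuple


def simple_branch(
    *branches: Tuple[Callable[..., Any], int]
) -> Callable[..., Tuple[Any, ...]]:
    """Hypothetical sketch: each branch is a (function, n_in) pair.

    Every branch reads its own copy of the first n_in inputs, so no branch
    consumes inputs that another branch needs. Results are returned in order.
    """

    def run(*inputs: Any) -> Tuple[Any, ...]:
        return tuple(func(*inputs[:n_in]) for func, n_in in branches)

    return run


def add3(a: int, b: int, c: int) -> int:
    # Plays the role of cb.serial(op.add, op.add), which needs three inputs.
    return a + b + c


branch = simple_branch((add3, 3), (op.add, 2))
print(branch(1, 2, 3))  # (6, 3)
```

This reproduces the doctest above: the first branch sums all three inputs and the second sums only the first two.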
- flax_extra.combinator.concatenate(axis: int = -1, n_in: int = 2) → Callable[[...], jax._src.numpy.lax_numpy.ndarray]
Creates a concatenate combinator.
Concatenates the input arrays along the given axis.
>>> from jax import numpy as jnp
>>> from flax_extra import combinator as cb
>>> concatenate = cb.concatenate()
>>> concatenate(jnp.array([1, 2]), jnp.array([3]))
DeviceArray([1, 2, 3], dtype=int32)
- Parameters
axis – an axis on which to concatenate arrays.
n_in – a number of inputs.
- Returns
a combinator.
- flax_extra.combinator.div(n_in: int = 2) → redex.combinator._fold.Foldl
Creates a division combinator.
>>> from redex import combinator as cb
>>> div = cb.div()
>>> div(4, 2) == 2
True
- Parameters
n_in – a number of inputs.
- Returns
a combinator.
- flax_extra.combinator.drop(n_in: int = 1) → redex.combinator._drop.Drop
Creates a drop combinator.
Drops the top n_in inputs; the remaining inputs are passed through.
>>> from redex import combinator as cb
>>> drop = cb.drop()
>>> drop(1, 2) == 2
True
- Parameters
n_in – a number of inputs.
- Returns
a combinator.
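The doctest shows the first (top) input being discarded while the rest survive. The sketch below models that behaviour under two assumptions: `n_in` counts inputs from the top of the stack, and a single surviving value is returned unwrapped, as in `drop(1, 2) == 2`. `simple_drop` is a hypothetical name, not the redex implementation.

```python
from typing import Any, Callable


def simple_drop(n_in: int = 1) -> Callable[..., Any]:
    """Hypothetical sketch: discard the top n_in inputs, pass the rest through."""

    def run(*inputs: Any) -> Any:
        rest = inputs[n_in:]
        # Mirrors the doctest above: one remaining value is returned unwrapped.
        return rest[0] if len(rest) == 1 else rest

    return run


print(simple_drop()(1, 2))         # 2
print(simple_drop(2)(1, 2, 3, 4))  # (3, 4)
```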
- flax_extra.combinator.dup(n_in: int = 1) → redex.combinator._dup.Dup
Creates a duplicate combinator.
The combinator makes a copy of its inputs.
>>> from redex import combinator as cb
>>> dup = cb.dup()
>>> dup(1) == (1, 1)
True
- Parameters
n_in – a number of inputs.
- Returns
a combinator.
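For the default `n_in=1`, duplication pushes a second copy of the top input. The sketch below covers only that default case and assumes that any further inputs pass through after the two copies; how redex handles `n_in > 1` is not shown here. `simple_dup` is a hypothetical name.

```python
from typing import Any, Callable, Tuple


def simple_dup() -> Callable[..., Tuple[Any, ...]]:
    """Hypothetical sketch of dup() with the default n_in=1."""

    def run(*inputs: Any) -> Tuple[Any, ...]:
        top, rest = inputs[0], inputs[1:]
        # Push a second copy of the top input; other inputs pass through.
        return (top, top) + rest

    return run


print(simple_dup()(1))     # (1, 1)
print(simple_dup()(1, 2))  # (1, 1, 2)
```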
- flax_extra.combinator.fold(func: Callable[[Any, Any], Any], n_in: int = 2) → redex.combinator._fold.Foldl
Creates a left folding combinator.
- flax_extra.combinator.foldl(func: Callable[[Any, Any], Any], n_in: int = 2) → redex.combinator._fold.Foldl
Creates a left folding combinator.
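A left fold applies a two-argument function pairwise from the left: `((x0 f x1) f x2) f ...`. The order matters for non-associative operations such as subtraction and division. The sketch below uses `functools.reduce`. It assumes that `add`, `sub`, `mul`, and `div` behave like folds over `operator.add`, `operator.sub`, `operator.mul`, and `operator.truediv`, and that extra inputs pass through. Their `Foldl` return type and the doctests are consistent with that, but the use of true division is an assumption. `simple_foldl` is a hypothetical name.

```python
import functools
import operator as op
from typing import Any, Callable


def simple_foldl(func: Callable[[Any, Any], Any], n_in: int = 2) -> Callable[..., Any]:
    """Hypothetical sketch: left-fold the top n_in inputs with func."""

    def run(*inputs: Any) -> Any:
        # ((x0 func x1) func x2) func ... over the first n_in inputs.
        result = functools.reduce(func, inputs[:n_in])
        rest = inputs[n_in:]
        # Assumed pass-through: extra inputs follow the folded result.
        return (result,) + rest if rest else result

    return run


sub3 = simple_foldl(op.sub, n_in=3)
div3 = simple_foldl(op.truediv, n_in=3)
print(sub3(10, 3, 2))  # 5, i.e. (10 - 3) - 2
print(div3(8, 2, 2))   # 2.0, i.e. (8 / 2) / 2
```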
- flax_extra.combinator.identity(n_in: int = 1) → redex.combinator._identity.Identity
Creates an identity combinator, which returns its arguments unchanged.
>>> from redex import combinator as cb
>>> identity = cb.identity()
>>> identity(1) == 1
True
- Parameters
n_in – a number of inputs.
- Returns
a combinator.
- flax_extra.combinator.mul(n_in: int = 2) → redex.combinator._fold.Foldl
Creates a multiplication combinator.
>>> from redex import combinator as cb
>>> mul = cb.mul()
>>> mul(4, 2) == 8
True
- Parameters
n_in – a number of inputs.
- Returns
a combinator.
- flax_extra.combinator.parallel(*children: Union[Callable[[...], Any], Iterable[Callable[[...], Any]]]) → redex.combinator._parallel.Parallel
Creates a parallel combinator.
The combinator applies functions in parallel to its inputs. Each function consumes a contiguous span of inputs, and the size of that span is the number of required arguments of the function.
>>> import operator as op
>>> from redex import combinator as cb
>>> parallel = cb.parallel(op.add, op.add)
>>> parallel(1, 2, 3, 4) == (1 + 2, 3 + 4)
True
- Parameters
children – a sequence of functions.
- Returns
a combinator.
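The span rule can be made concrete by counting each child's required positional parameters with `inspect.signature` and slicing the inputs accordingly. This is a sketch, not redex's implementation. It assumes every child returns a single value and ignores extra inputs. Plain Python functions are used instead of `operator` builtins to keep signature inspection straightforward. `n_required`, `simple_parallel`, `add`, and `neg` are hypothetical names.

```python
import inspect
from typing import Any, Callable, Tuple


def n_required(func: Callable[..., Any]) -> int:
    """Count positional parameters that have no default value."""
    positional = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    return sum(
        1
        for p in inspect.signature(func).parameters.values()
        if p.kind in positional and p.default is inspect.Parameter.empty
    )


def simple_parallel(*children: Callable[..., Any]) -> Callable[..., Tuple[Any, ...]]:
    """Hypothetical sketch: each child consumes the next n_required(child) inputs."""

    def run(*inputs: Any) -> Tuple[Any, ...]:
        outputs, start = [], 0
        for child in children:
            span = n_required(child)
            outputs.append(child(*inputs[start : start + span]))
            start += span
        return tuple(outputs)

    return run


def add(a: int, b: int) -> int:
    return a + b


def neg(x: int) -> int:
    return -x


parallel = simple_parallel(add, neg, add)
print(parallel(1, 2, 3, 4, 5))  # (3, -3, 9)
```

Spans do not need to be equal: here `neg` consumes one input while each `add` consumes two.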
- flax_extra.combinator.residual(*children: Union[Callable[[...], Any], Iterable[Callable[[...], Any]]], shortcut: Callable[[...], Any] = Identity(signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),)))) → redex.combinator._serial.Serial
Creates a residual combinator.
The combinator computes the sum of two branches: main and shortcut.
>>> from redex import combinator as cb
>>> residual = cb.residual(cb.serial(cb.add(), cb.add()))
>>> residual(1, 2, 3) == 1 + 2 + 3 + 1
True
- Parameters
children – a main sequence of functions.
shortcut – a skip connection. Defaults to identity function.
- Returns
a combinator.
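A residual connection evaluates the main branch and the shortcut on copies of the same inputs, then adds the two results. In the sketch below, the shortcut reads only the first input, matching the default `Identity` shortcut whose signature has `n_in=1`. The main branch's input count is passed explicitly as a hypothetical `main_n_in` argument. `simple_residual` is an illustrative name, not redex API.

```python
from typing import Any, Callable


def simple_residual(
    main: Callable[..., Any],
    main_n_in: int,
    shortcut: Callable[[Any], Any] = lambda x: x,
) -> Callable[..., Any]:
    """Hypothetical sketch: main branch on main_n_in inputs plus a shortcut on the first."""

    def run(*inputs: Any) -> Any:
        # Both branches read copies of the same inputs; their results are summed.
        return main(*inputs[:main_n_in]) + shortcut(inputs[0])

    return run


residual = simple_residual(lambda a, b, c: a + b + c, main_n_in=3)
print(residual(1, 2, 3))  # 7, i.e. (1 + 2 + 3) + 1
```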
- flax_extra.combinator.select(indices: List[int], n_in: Optional[int] = None) → redex.combinator._select.Select
Creates a select combinator.
The combinator reorders or copies inputs.
>>> from redex import combinator as cb
>>> select = cb.select([0, 0, 1, 1])
>>> select(1, 2, 3, 4) == (1, 1, 2, 2, 3, 4)
True
- Parameters
indices – a sequence of 0-based indices relative to the top of the stack.
n_in – a number of items to pop from the stack, and replace with those specified by indices. If not specified, the value will be calculated as max(indices) + 1.
- Returns
a combinator.
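The parameter description above translates almost directly into code: pop `n_in` items (by default `max(indices) + 1`), push the popped items at the given indices, and leave the rest of the stack in place. The sketch below models this with a tuple whose first element is the top of the stack. `simple_select` is a hypothetical name, not the redex implementation.

```python
from typing import Any, Callable, List, Optional, Tuple


def simple_select(
    indices: List[int], n_in: Optional[int] = None
) -> Callable[..., Tuple[Any, ...]]:
    """Hypothetical sketch: pop n_in items, push the items at `indices`."""
    if n_in is None:
        n_in = max(indices) + 1  # default described for the n_in parameter

    def run(*inputs: Any) -> Tuple[Any, ...]:
        popped, rest = inputs[:n_in], inputs[n_in:]
        return tuple(popped[i] for i in indices) + rest

    return run


print(simple_select([0, 0, 1, 1])(1, 2, 3, 4))  # (1, 1, 2, 2, 3, 4)
print(simple_select([1, 0])(1, 2, 3))           # (2, 1, 3): swap the top two
```

Passing an explicit `n_in` larger than `max(indices) + 1` discards the popped items that are not selected. For example, `simple_select([0], n_in=2)(1, 2, 3)` yields `(1, 3)`.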
- flax_extra.combinator.serial(*children: Union[Callable[[...], Any], Iterable[Callable[[...], Any]]]) → redex.combinator._serial.Serial
Creates a serial combinator.
The combinator applies functions in series (function composition).
>>> import operator as op
>>> from redex import combinator as cb
>>> serial = cb.serial(op.add, op.add, op.add)
>>> serial(1, 2, 3, 4) == 1 + 2 + 3 + 4
True
- Parameters
children – a sequence of functions.
- Returns
a combinator.
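The doctest works because serial composition threads a stack through the functions. Each function pops the inputs it needs from the top and pushes its result back, so later functions see earlier results alongside untouched inputs. The sketch below makes this explicit. It assumes each step is a hypothetical `(function, n_in)` pair with a single output. `simple_serial` is an illustrative name, not redex API.

```python
import operator as op
from typing import Any, Callable, Tuple


def simple_serial(*steps: Tuple[Callable[..., Any], int]) -> Callable[..., Any]:
    """Hypothetical sketch: each step is a (function, n_in) pair.

    A step pops n_in items from the top of the stack and pushes its single
    result back on top; the next step then sees the updated stack.
    """

    def run(*inputs: Any) -> Any:
        stack = list(inputs)
        for func, n_in in steps:
            args, stack = stack[:n_in], stack[n_in:]
            stack.insert(0, func(*args))
        return stack[0] if len(stack) == 1 else tuple(stack)

    return run


serial = simple_serial((op.add, 2), (op.add, 2), (op.add, 2))
print(serial(1, 2, 3, 4))  # 10: add(1, 2) -> 3, add(3, 3) -> 6, add(6, 4) -> 10
```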
- flax_extra.combinator.sub(n_in: int = 2) → redex.combinator._fold.Foldl
Creates a subtraction combinator.
>>> from redex import combinator as cb
>>> sub = cb.sub()
>>> sub(4, 2) == 2
True
- Parameters
n_in – a number of inputs.
- Returns
a combinator.