flax_extra.combinator package

Combinator functions compose other functions.

class flax_extra.combinator.Combinator(signature: redex.function.Signature)

Bases: redex.function.FineCallable

The base class for combinators.

signature: redex.function.Signature

a signature of the function.

class flax_extra.combinator.Concatenate(signature: redex.function.Signature, axis: int)

Bases: redex.combinator._base.Combinator

The concatenate combinator.

axis: int

an axis on which to concatenate arrays.

class flax_extra.combinator.Drop(signature: redex.function.Signature)

Bases: redex.combinator._base.Combinator

The drop combinator.

signature: redex.function.Signature

a signature of the function.

class flax_extra.combinator.Dup(signature: redex.function.Signature)

Bases: redex.combinator._base.Combinator

The duplicate combinator.

signature: redex.function.Signature

a signature of the function.

class flax_extra.combinator.Foldl(signature: redex.function.Signature, func: Callable[[Any, Any], Any])

Bases: redex.combinator._base.Combinator

The left folding combinator.

func: Callable[[Any, Any], Any]

a function of two arguments.

class flax_extra.combinator.Identity(signature: redex.function.Signature)

Bases: redex.combinator._base.Combinator

The identity combinator.

signature: redex.function.Signature

a signature of the function.

class flax_extra.combinator.Parallel(signature: redex.function.Signature, children: List[Callable[[...], Any]], children_signatures: List[redex.function.Signature])

Bases: redex.combinator._base.Combinator

The parallel combinator.

children: List[Callable[[...], Any]]

the child functions to compose.

children_signatures: List[redex.function.Signature]

signatures of the child functions.

class flax_extra.combinator.Select(signature: redex.function.Signature, indices: List[int])

Bases: redex.combinator._base.Combinator

The select combinator.

indices: List[int]

a sequence of 0-based indices relative to the top of the stack.

class flax_extra.combinator.Serial(signature: redex.function.Signature, children: List[Callable[[...], Any]], children_signatures: List[redex.function.Signature])

Bases: redex.combinator._base.Combinator

The serial combinator.

children: List[Callable[[...], Any]]

the child functions to compose.

children_signatures: List[redex.function.Signature]

signatures of the child functions.

flax_extra.combinator.add(n_in: int = 2) → redex.combinator._fold.Foldl

Creates an addition combinator.

>>> from redex import combinator as cb
>>> add = cb.add()
>>> add(4, 2) == 6
True

Parameters

n_in – a number of inputs.

Returns

a combinator.

flax_extra.combinator.branch(*children: Union[Callable[[...], Any], Iterable[Callable[[...], Any]]]) → redex.combinator._serial.Serial

Creates a branch combinator.

The combinator runs multiple branches of the given functions, each operating on a copy of the inputs. Each branch is a sequence of functions applied serially.

>>> import operator as op
>>> from redex import combinator as cb
>>> branch = cb.branch(cb.serial(op.add, op.add), op.add)
>>> branch(1, 2, 3) == (1 + 2 + 3, 1 + 2)
True

Parameters

children – a sequence of functions.

Returns

a combinator.
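To make the "copy of inputs" behavior concrete, here is a minimal pure-Python sketch, not the redex implementation. It assumes each branch reads as many leading inputs as it has required positional parameters; `mini_branch` and `n_required` are hypothetical helpers, and the sketch ignores any inputs that no branch reads.

```python
import inspect
from typing import Any, Callable, Tuple


def n_required(func: Callable[..., Any]) -> int:
    """Counts positional parameters that have no default value."""
    return sum(
        1
        for p in inspect.signature(func).parameters.values()
        if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        and p.default is inspect.Parameter.empty
    )


def mini_branch(*children: Callable[..., Any]) -> Callable[..., Tuple[Any, ...]]:
    """Hypothetical branch: every child reads its own copy of the leading inputs."""

    def branch(*inputs: Any) -> Tuple[Any, ...]:
        # Inputs are shared across children rather than consumed by them,
        # so every child starts reading from the first input.
        return tuple(child(*inputs[: n_required(child)]) for child in children)

    return branch


def add2(a, b):
    return a + b


def add3(a, b, c):
    return a + b + c


print(mini_branch(add3, add2)(1, 2, 3))  # (6, 3), matching the doctest above
```

Contrast this with parallel, where each child reads a separate span of inputs instead of a shared copy.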

flax_extra.combinator.concatenate(axis: int = -1, n_in: int = 2) → Callable[[...], jax._src.numpy.lax_numpy.ndarray]

Creates a concatenate combinator.

Concatenates the input arrays along the given axis.

>>> from jax import numpy as jnp
>>> from flax_extra import combinator as cb
>>> concatenate = cb.concatenate()
>>> concatenate(jnp.array([1, 2]), jnp.array([3]))
DeviceArray([1, 2, 3], dtype=int32)

Parameters
  • axis – an axis on which to concatenate arrays. Defaults to -1 (the last axis).

  • n_in – a number of inputs.

Returns

a combinator.
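The axis parameter only matters once the inputs have more than one dimension. The sketch below calls numpy.concatenate directly rather than the combinator; jax.numpy.concatenate follows the same axis semantics. It shows the shapes that the default axis=-1 and axis=0 would produce for 2-D inputs.

```python
import numpy as np

a = np.ones((2, 3))
b = np.zeros((2, 1))
c = np.zeros((1, 3))

# axis=-1 (the default of concatenate): join along the last axis;
# all other dimensions (here the leading 2) must match.
last = np.concatenate([a, b], axis=-1)
print(last.shape)  # (2, 4)

# axis=0: append rows; the trailing dimension (here 3) must match.
rows = np.concatenate([a, c], axis=0)
print(rows.shape)  # (3, 3)
```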

flax_extra.combinator.div(n_in: int = 2) → redex.combinator._fold.Foldl

Creates a division combinator.

>>> from redex import combinator as cb
>>> div = cb.div()
>>> div(4, 2) == 2
True

Parameters

n_in – a number of inputs.

Returns

a combinator.

flax_extra.combinator.drop(n_in: int = 1) → redex.combinator._drop.Drop

Creates a drop combinator.

Drops the first n_in inputs (the top of the stack).

>>> from redex import combinator as cb
>>> drop = cb.drop()
>>> drop(1, 2) == 2
True

Parameters

n_in – a number of inputs.

Returns

a combinator.
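A minimal pure-Python sketch of the drop behavior, assuming, as the doctest drop(1, 2) == 2 suggests, that the inputs after the first n_in are passed through and a single remaining value is returned bare. `mini_drop` is a hypothetical helper, not part of flax_extra or redex.

```python
from typing import Any, Callable


def mini_drop(n_in: int = 1) -> Callable[..., Any]:
    """Hypothetical drop: discards the first n_in inputs, passes the rest through."""

    def drop(*inputs: Any) -> Any:
        rest = inputs[n_in:]
        # A single remaining value is returned bare, as in the doctest above.
        return rest[0] if len(rest) == 1 else rest

    return drop


print(mini_drop()(1, 2))         # 2
print(mini_drop(2)(1, 2, 3, 4))  # (3, 4)
```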

flax_extra.combinator.dup(n_in: int = 1) → redex.combinator._dup.Dup

Creates a duplicate combinator.

The combinator makes a copy of its inputs.

>>> from redex import combinator as cb
>>> dup = cb.dup()
>>> dup(1) == (1, 1)
True

Parameters

n_in – a number of inputs.

Returns

a combinator.

flax_extra.combinator.fold(func: Callable[[Any, Any], Any], n_in: int = 2) → redex.combinator._fold.Foldl

Creates a left folding combinator.

flax_extra.combinator.foldl(func: Callable[[Any, Any], Any], n_in: int = 2) → redex.combinator._fold.Foldl

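Creates a left folding combinator.

The combinator reduces its n_in inputs from left to right with func, for example func(func(x1, x2), x3) for three inputs.

Parameters
  • func – a function of two arguments.

  • n_in – a number of inputs.

Returns

a combinator.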
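The nesting order of a left fold is easiest to see with a func that records its calls. This is a minimal sketch built on functools.reduce, not the redex implementation; `mini_foldl` and `show` are hypothetical helpers, and for simplicity the sketch requires exactly n_in inputs.

```python
import functools
from typing import Any, Callable


def mini_foldl(func: Callable[[Any, Any], Any], n_in: int = 2) -> Callable[..., Any]:
    """Hypothetical left fold over exactly n_in inputs."""

    def fold(*inputs: Any) -> Any:
        if len(inputs) != n_in:
            raise TypeError(f"expected {n_in} inputs, got {len(inputs)}")
        # reduce calls func(acc, x) left to right, starting with acc = inputs[0].
        return functools.reduce(func, inputs)

    return fold


def show(acc: Any, x: Any) -> str:
    """Renders each call so the nesting of the fold becomes visible."""
    return f"f({acc}, {x})"


print(mini_foldl(show, n_in=4)("a", "b", "c", "d"))  # f(f(f(a, b), c), d)
```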

flax_extra.combinator.identity(n_in: int = 1) → redex.combinator._identity.Identity

Creates an identity combinator.

The combinator returns its inputs unchanged.

>>> from redex import combinator as cb
>>> identity = cb.identity()
>>> identity(1) == 1
True

Parameters

n_in – a number of inputs.

Returns

a combinator.

flax_extra.combinator.mul(n_in: int = 2) → redex.combinator._fold.Foldl

Creates a multiplication combinator.

>>> from redex import combinator as cb
>>> mul = cb.mul()
>>> mul(4, 2) == 8
True

Parameters

n_in – a number of inputs.

Returns

a combinator.

flax_extra.combinator.parallel(*children: Union[Callable[[...], Any], Iterable[Callable[[...], Any]]]) → redex.combinator._parallel.Parallel

Creates a parallel combinator.

The combinator applies functions in parallel to its inputs. Each function consumes its own span of consecutive inputs, and the size of each span is the number of required arguments of that function.

>>> import operator as op
>>> from redex import combinator as cb
>>> parallel = cb.parallel(op.add, op.add)
>>> parallel(1, 2, 3, 4) == (1 + 2, 3 + 4)
True

Parameters

children – a sequence of functions.

Returns

a combinator.
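The span rule can be sketched in pure Python by counting each child's required positional parameters with inspect.signature. This is an illustration under that assumption, not the redex implementation: `mini_parallel` and `n_required` are hypothetical helpers, and inputs beyond the last span are ignored here.

```python
import inspect
from typing import Any, Callable, Tuple


def n_required(func: Callable[..., Any]) -> int:
    """Counts positional parameters that have no default value."""
    return sum(
        1
        for p in inspect.signature(func).parameters.values()
        if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        and p.default is inspect.Parameter.empty
    )


def mini_parallel(*children: Callable[..., Any]) -> Callable[..., Tuple[Any, ...]]:
    """Hypothetical parallel: children consume consecutive, non-overlapping spans."""

    def parallel(*inputs: Any) -> Tuple[Any, ...]:
        outputs, start = [], 0
        for child in children:
            n = n_required(child)
            # The child's span starts right after the previous child's span.
            outputs.append(child(*inputs[start : start + n]))
            start += n
        return tuple(outputs)

    return parallel


def add2(a, b):
    return a + b


def neg(a):
    return -a


print(mini_parallel(add2, add2)(1, 2, 3, 4))  # (3, 7)
print(mini_parallel(neg, add2)(1, 2, 3))      # (-1, 5)
```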

flax_extra.combinator.residual(*children: Union[Callable[[...], Any], Iterable[Callable[[...], Any]]], shortcut: Callable[[...], Any] = Identity(signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),)))) → redex.combinator._serial.Serial

Creates a residual combinator.

The combinator computes the sum of two branches: main and shortcut.

>>> from redex import combinator as cb
>>> residual = cb.residual(cb.serial(cb.add(), cb.add()))
>>> residual(1, 2, 3) == 1 + 2 + 3 + 1
True

Parameters
  • children – the main branch, a sequence of functions.

  • shortcut – a skip connection. Defaults to the identity function.

Returns

a combinator.
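The doctest above computes (1 + 2 + 3) from the main branch plus 1 from the identity shortcut. A minimal pure-Python sketch of that arithmetic, assuming the main branch reads the leading main_n_in inputs and the shortcut reads only the first input (as the default Identity with n_in=1 does). `mini_residual` is a hypothetical helper with an explicit arity argument, not the redex API.

```python
from typing import Any, Callable


def mini_residual(
    main: Callable[..., Any],
    main_n_in: int,
    shortcut: Callable[[Any], Any] = lambda x: x,
) -> Callable[..., Any]:
    """Hypothetical residual: main(leading main_n_in inputs) + shortcut(first input)."""

    def residual(*inputs: Any) -> Any:
        main_out = main(*inputs[:main_n_in])
        # The default shortcut is the identity on the first input.
        skip_out = shortcut(inputs[0])
        return main_out + skip_out

    return residual


def add3(a, b, c):
    return a + b + c


print(mini_residual(add3, main_n_in=3)(1, 2, 3))                   # 7
print(mini_residual(add3, 3, shortcut=lambda x: 10 * x)(1, 2, 3))  # 16
```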

flax_extra.combinator.select(indices: List[int], n_in: Optional[int] = None) → redex.combinator._select.Select

Creates a select combinator.

The combinator reorders or copies its inputs.

>>> from redex import combinator as cb
>>> select = cb.select([0, 0, 1, 1])
>>> select(1, 2, 3, 4) == (1, 1, 2, 2, 3, 4)
True

Parameters
  • indices – a sequence of 0-based indices relative to the top of the stack.

  • n_in – a number of items to pop from the stack and replace with the items specified by indices. Defaults to max(indices) + 1.

Returns

a combinator.
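The parameter description translates directly into a few lines of pure Python: pop the top n_in items, push the items named by indices, and keep everything below. `mini_select` is a hypothetical helper that follows that description, not the redex implementation. The third call shows that select([0, 0]) gives the same result as dup() on a single input.

```python
from typing import Any, Callable, List, Optional, Tuple


def mini_select(indices: List[int], n_in: Optional[int] = None) -> Callable[..., Tuple[Any, ...]]:
    """Hypothetical select following the parameter description of select."""
    n = max(indices) + 1 if n_in is None else n_in

    def select(*inputs: Any) -> Tuple[Any, ...]:
        # Push the items named by indices in place of the top n items,
        # then keep the untouched items below them.
        picked = tuple(inputs[i] for i in indices)
        return picked + tuple(inputs[n:])

    return select


print(mini_select([0, 0, 1, 1])(1, 2, 3, 4))  # (1, 1, 2, 2, 3, 4)
print(mini_select([1, 0])(1, 2, 3))           # (2, 1, 3): swap the top two
print(mini_select([0, 0])(5))                 # (5, 5): same as dup()
print(mini_select([0], n_in=2)(1, 2, 3))      # (1, 3): keep the first, drop the second
```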

flax_extra.combinator.serial(*children: Union[Callable[[...], Any], Iterable[Callable[[...], Any]]]) → redex.combinator._serial.Serial

Creates a serial combinator.

The combinator applies functions in series (function composition).

>>> import operator as op
>>> from redex import combinator as cb
>>> serial = cb.serial(op.add, op.add, op.add)
>>> serial(1, 2, 3, 4) == 1 + 2 + 3 + 4
True

Parameters

children – a sequence of functions.

Returns

a combinator.
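The doctest serial(op.add, op.add, op.add)(1, 2, 3, 4) == 10 can be read as stack manipulation. Each function pops as many items as it has required arguments and pushes its result back, so the stack goes [1, 2, 3, 4] → [3, 3, 4] → [6, 4] → [10]. The sketch below models that under the assumption that unread inputs stay on the stack and tuple results count as multiple outputs; `mini_serial` and `n_required` are hypothetical helpers, not the redex implementation.

```python
import inspect
from typing import Any, Callable


def n_required(func: Callable[..., Any]) -> int:
    """Counts positional parameters that have no default value."""
    return sum(
        1
        for p in inspect.signature(func).parameters.values()
        if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        and p.default is inspect.Parameter.empty
    )


def mini_serial(*children: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical serial: threads a stack of values through each child in turn."""

    def serial(*inputs: Any) -> Any:
        stack = list(inputs)
        for child in children:
            n = n_required(child)
            # Pop n arguments from the top of the stack and call the child.
            result = child(*stack[:n])
            # Push the result back; a tuple result counts as several outputs.
            outputs = list(result) if isinstance(result, tuple) else [result]
            stack = outputs + stack[n:]
        return stack[0] if len(stack) == 1 else tuple(stack)

    return serial


def add2(a, b):
    return a + b


print(mini_serial(add2, add2, add2)(1, 2, 3, 4))  # 10
print(mini_serial(add2, add2)(1, 2, 3, 4))        # (6, 4): the 4 is passed through
```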

flax_extra.combinator.sub(n_in: int = 2) → redex.combinator._fold.Foldl

Creates a subtraction combinator.

>>> from redex import combinator as cb
>>> sub = cb.sub()
>>> sub(4, 2) == 2
True

Parameters

n_in – a number of inputs.

Returns

a combinator.
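Unlike add and mul, subtraction and division are not associative, so the left-to-right order of the fold determines the result once n_in exceeds 2. A minimal sketch, assuming sub and div reduce their inputs as a left fold (consistent with their Foldl return type and the doctests above) and that div uses true division; it calls functools.reduce directly instead of the combinators.

```python
import functools
import operator

# Left fold: (10 - 3) - 2 == 5, not 10 - (3 - 2) == 9.
left_sub = functools.reduce(operator.sub, (10, 3, 2))
print(left_sub)  # 5

# Same for division: (8 / 2) / 2 == 2.0, not 8 / (2 / 2) == 8.0.
left_div = functools.reduce(operator.truediv, (8, 2, 2))
print(left_div)  # 2.0
```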