Source code for redex.combinator._residual

"""The residual combinator."""

from redex import util
from redex.function import Fn, FnIter
from redex import function as fn
from redex.combinator._serial import serial, Serial
from redex.combinator._branch import branch
from redex.combinator._identity import identity
from redex.combinator._fold import add


def residual(*children: FnIter, shortcut: Fn = identity(n_in=1)) -> Serial:
    """Builds a residual (skip-connection) combinator.

    The result adds together the outputs of two branches: a main branch
    built from ``children`` and a ``shortcut`` branch.

    >>> from redex import combinator as cb
    >>> residual = cb.residual(cb.serial(cb.add(), cb.add()))
    >>> residual(1, 2, 3) == 1 + 2 + 3 + 1
    True

    Args:
        children: functions forming the main branch. After flattening, a
            single function is used directly; several are chained with
            ``serial``.
        shortcut: the skip-connection branch. Defaults to an identity
            function taking one input.

    Returns:
        a ``serial`` combinator that runs both branches through ``branch``
        and sums their two outputs with a two-input ``add``.

    Raises:
        ValueError: if the main branch or the shortcut branch does not
            output exactly one value according to its inferred signature.
    """
    fns = util.flatten(children)
    # Avoid wrapping a lone function in a redundant ``serial``.
    main = fns[0] if len(fns) == 1 else serial(*fns)

    # The final ``add(n_in=2)`` consumes one value from each branch, so both
    # must produce exactly one output. Main is checked before shortcut.
    for role, branch_fn in (("main", main), ("shortcut", shortcut)):
        n_out = fn.infer_signature(branch_fn).n_out
        if n_out != 1:
            raise ValueError(
                f"The {role} branch of the residual must output exactly one value. "
                f"`{fn.infer_name(branch_fn)}` outputs `{n_out}` values."
            )

    return serial(branch(main, shortcut), add(n_in=2))