Source code for redex.combinator._parallel

"""The parallel combinator."""

from typing import List
from functools import reduce
from dataclasses import replace
from redex import util
from redex import function as fn
from redex.function import Fn, FnIter, Signature
from redex.stack import constrained_call, stackmethod, Stack
from redex.combinator._base import Combinator

# pylint: disable=too-few-public-methods


# pylint: disable=too-few-public-methods
[docs]class Parallel(Combinator): """The parallel combinator.""" children: List[Fn] """composite functions.""" children_signatures: List[Signature] """signatures of the composite functions.""" @stackmethod def __call__(self, stack: Stack) -> Stack: outputs = Stack() for i, child in enumerate(self.children): signature = self.children_signatures[i] n_lower, n_upper = signature.index_bounds outputs += constrained_call(child, stack[n_lower:n_upper], signature) return outputs + stack[self.signature.n_in :]
[docs]def parallel(*children: FnIter) -> Parallel: """Creates a parallel combinator. The combinator applies functions in parallel to its inputs. Each function consumes a span of inputs. The span sizes are determined by a number of required arguments of these functions. >>> import operator as op >>> from redex import combinator as cb >>> parallel = cb.parallel(op.add, op.add) >>> parallel(1, 2, 3, 4) == (1 + 2, 3 + 4) True Args: children: a sequence of functions. Returns: a combinator. """ flat_children = util.flatten(children) signature, children_signatures = _estimate_parallel_signatures(flat_children) return Parallel( signature=signature, children=flat_children, children_signatures=children_signatures, )
_Initializer = tuple[int, int, List[Signature]] def _estimate_parallel_signatures( children: List[Fn], ) -> tuple[Signature, List[Signature]]: def count(acc: _Initializer, child: Fn) -> _Initializer: in_total, out_total, signatures = acc signature = replace(fn.infer_signature(child), start_index=in_total) return ( in_total + signature.n_in, out_total + signature.n_out, signatures + [signature], ) initializer: _Initializer = (0, 0, []) in_total, out_total, children_signatures = reduce(count, children, initializer) return Signature(n_in=in_total, n_out=out_total), children_signatures