diff --git a/gem/__init__.py b/gem/__init__.py deleted file mode 100644 index f1e77203..00000000 --- a/gem/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from gem.gem import * # noqa -from gem.optimise import select_expression # noqa diff --git a/gem/coffee.py b/gem/coffee.py deleted file mode 100644 index f766a489..00000000 --- a/gem/coffee.py +++ /dev/null @@ -1,192 +0,0 @@ -"""This module contains an implementation of the COFFEE optimisation -algorithm operating on a GEM representation. - -This file is NOT for code generation as a COFFEE AST. -""" - -from collections import OrderedDict -import itertools -import logging - -import numpy - -from gem.gem import IndexSum, one -from gem.optimise import make_sum, make_product -from gem.refactorise import Monomial -from gem.utils import groupby - - -__all__ = ['optimise_monomial_sum'] - - -def monomial_sum_to_expression(monomial_sum): - """Convert a monomial sum to a GEM expression. - - :arg monomial_sum: an iterable of :class:`Monomial`s - - :returns: GEM expression - """ - indexsums = [] # The result is summation of indexsums - # Group monomials according to their sum indices - groups = groupby(monomial_sum, key=lambda m: frozenset(m.sum_indices)) - # Create IndexSum's from each monomial group - for _, monomials in groups: - sum_indices = monomials[0].sum_indices - products = [make_product(monomial.atomics + (monomial.rest,)) for monomial in monomials] - indexsums.append(IndexSum(make_sum(products), sum_indices)) - return make_sum(indexsums) - - -def index_extent(factor, linear_indices): - """Compute the product of the extents of linear indices of a GEM expression - - :arg factor: GEM expression - :arg linear_indices: set of linear indices - - :returns: product of extents of linear indices - """ - return numpy.prod([i.extent for i in factor.free_indices if i in linear_indices]) - - -def find_optimal_atomics(monomials, linear_indices): - """Find optimal atomic common subexpressions, which produce least number of - terms in the resultant IndexSum when factorised. - - :arg monomials: A list of :class:`Monomial`s, all of which should have - the same sum indices - :arg linear_indices: tuple of linear indices - - :returns: list of atomic GEM expressions - """ - atomics = tuple(OrderedDict.fromkeys(itertools.chain(*(monomial.atomics for monomial in monomials)))) - - def cost(solution): - extent = sum(map(lambda atomic: index_extent(atomic, linear_indices), solution)) - # Prefer shorter solutions, but larger extents - return (len(solution), -extent) - - optimal_solution = set(atomics) # pessimal but feasible solution - solution = set() - - max_it = 1 << 12 - it = iter(range(max_it)) - - def solve(idx): - while idx < len(monomials) and solution.intersection(monomials[idx].atomics): - idx += 1 - - if idx < len(monomials): - if len(solution) < len(optimal_solution): - for atomic in monomials[idx].atomics: - solution.add(atomic) - solve(idx + 1) - solution.remove(atomic) - else: - if cost(solution) < cost(optimal_solution): - optimal_solution.clear() - optimal_solution.update(solution) - next(it) - - try: - solve(0) - except StopIteration: - logger = logging.getLogger('tsfc') - logger.warning("Solution to ILP problem may not be optimal: search " - "interrupted after examining %d solutions.", max_it) - - return tuple(atomic for atomic in atomics if atomic in optimal_solution) - - -def factorise_atomics(monomials, optimal_atomics, linear_indices): - """Group and factorise monomials using a list of atomics as common - subexpressions. 
Create new monomials for each group and optimise them recursively. - - :arg monomials: an iterable of :class:`Monomial`s, all of which should have - the same sum indices - :arg optimal_atomics: list of tuples of atomics to be used as common subexpression - :arg linear_indices: tuple of linear indices - - :returns: an iterable of :class:`Monomials`s after factorisation - """ - if not optimal_atomics or len(monomials) <= 1: - return monomials - - # Group monomials with respect to each optimal atomic - def group_key(monomial): - for oa in optimal_atomics: - if oa in monomial.atomics: - return oa - assert False, "Expect at least one optimal atomic per monomial." - factor_group = groupby(monomials, key=group_key) - - # We should not drop monomials - assert sum(len(ms) for _, ms in factor_group) == len(monomials) - - sum_indices = next(iter(monomials)).sum_indices - new_monomials = [] - for oa, monomials in factor_group: - # Create new MonomialSum for the factorised out terms - sub_monomials = [] - for monomial in monomials: - atomics = list(monomial.atomics) - atomics.remove(oa) # remove common factor - sub_monomials.append(Monomial((), tuple(atomics), monomial.rest)) - # Continue to factorise the remaining expression - sub_monomials = optimise_monomials(sub_monomials, linear_indices) - if len(sub_monomials) == 1: - # Factorised part is a product, we add back the common atomics then - # add to new MonomialSum directly rather than forming a product node - # Retaining the monomial structure enables applying associativity - # when forming GEM nodes later. - sub_monomial, = sub_monomials - new_monomials.append( - Monomial(sum_indices, (oa,) + sub_monomial.atomics, sub_monomial.rest)) - else: - # Factorised part is a summation, we need to create a new GEM node - # and multiply with the common factor - node = monomial_sum_to_expression(sub_monomials) - # If the free indices of the new node intersect with linear indices, - # add to the new monomial as `atomic`, otherwise add as `rest`. - # Note: we might want to continue to factorise with the new atomics - # by running optimise_monoials twice. - if set(linear_indices) & set(node.free_indices): - new_monomials.append(Monomial(sum_indices, (oa, node), one)) - else: - new_monomials.append(Monomial(sum_indices, (oa, ), node)) - return new_monomials - - -def optimise_monomial_sum(monomial_sum, linear_indices): - """Choose optimal common atomic subexpressions and factorise a - :class:`MonomialSum` object to create a GEM expression. - - :arg monomial_sum: a :class:`MonomialSum` object - :arg linear_indices: tuple of linear indices - - :returns: factorised GEM expression - """ - groups = groupby(monomial_sum, key=lambda m: frozenset(m.sum_indices)) - new_monomials = [] - for _, monomials in groups: - new_monomials.extend(optimise_monomials(monomials, linear_indices)) - return monomial_sum_to_expression(new_monomials) - - -def optimise_monomials(monomials, linear_indices): - """Choose optimal common atomic subexpressions and factorise an iterable - of monomials. 
- - :arg monomials: a list of :class:`Monomial`s, all of which should have - the same sum indices - :arg linear_indices: tuple of linear indices - - :returns: an iterable of factorised :class:`Monomials`s - """ - assert len(set(frozenset(m.sum_indices) for m in monomials)) <= 1, \ - "All monomials required to have same sum indices for factorisation" - - result = [m for m in monomials if not m.atomics] # skipped monomials - active_monomials = [m for m in monomials if m.atomics] - optimal_atomics = find_optimal_atomics(active_monomials, linear_indices) - result += factorise_atomics(active_monomials, optimal_atomics, linear_indices) - return result diff --git a/gem/flop_count.py b/gem/flop_count.py deleted file mode 100644 index b9595e81..00000000 --- a/gem/flop_count.py +++ /dev/null @@ -1,197 +0,0 @@ -""" -This file contains all the necessary functions to accurately count the -total number of floating point operations for a given script. -""" - -import gem.gem as gem -import gem.impero as imp -from functools import singledispatch -import numpy -import math - - -@singledispatch -def statement(tree, temporaries): - raise NotImplementedError - - -@statement.register(imp.Block) -def statement_block(tree, temporaries): - flops = sum(statement(child, temporaries) for child in tree.children) - return flops - - -@statement.register(imp.For) -def statement_for(tree, temporaries): - extent = tree.index.extent - assert extent is not None - child, = tree.children - flops = statement(child, temporaries) - return flops * extent - - -@statement.register(imp.Initialise) -def statement_initialise(tree, temporaries): - return 0 - - -@statement.register(imp.Accumulate) -def statement_accumulate(tree, temporaries): - flops = expression_flops(tree.indexsum.children[0], temporaries) - return flops + 1 - - -@statement.register(imp.Return) -def statement_return(tree, temporaries): - flops = expression_flops(tree.expression, temporaries) - return flops + 1 - - -@statement.register(imp.ReturnAccumulate) -def statement_returnaccumulate(tree, temporaries): - flops = expression_flops(tree.indexsum.children[0], temporaries) - return flops + 1 - - -@statement.register(imp.Evaluate) -def statement_evaluate(tree, temporaries): - flops = expression_flops(tree.expression, temporaries, top=True) - return flops - - -@singledispatch -def flops(expr, temporaries): - raise NotImplementedError(f"Don't know how to count flops of {type(expr)}") - - -@flops.register(gem.Failure) -def flops_failure(expr, temporaries): - raise ValueError("Not expecting a Failure node") - - -@flops.register(gem.Variable) -@flops.register(gem.Identity) -@flops.register(gem.Delta) -@flops.register(gem.Zero) -@flops.register(gem.Literal) -@flops.register(gem.Index) -@flops.register(gem.VariableIndex) -def flops_zero(expr, temporaries): - # Initial set up of these Gem nodes are of 0 floating point operations. - return 0 - - -@flops.register(gem.LogicalNot) -@flops.register(gem.LogicalAnd) -@flops.register(gem.LogicalOr) -@flops.register(gem.ListTensor) -def flops_zeroplus(expr, temporaries): - # These nodes contribute 0 floating point operations, but their children may not. - return 0 + sum(expression_flops(child, temporaries) - for child in expr.children) - - -@flops.register(gem.Product) -def flops_product(expr, temporaries): - # Multiplication by -1 is not a flop. 
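# --- Usage sketch for the monomial optimiser in gem/coffee.py (deleted above). ---
# Illustrative only: Monomial is the namedtuple-style container imported by that
# module from gem.refactorise; the variable names and extents are made up here.
from gem import Index, Indexed, Variable
from gem.coffee import optimise_monomials, monomial_sum_to_expression
from gem.refactorise import Monomial

i, j, q = Index('i', 3), Index('j', 3), Index('q', 4)
w = Indexed(Variable('w', (4,)), (q,))        # "rest" factor, no linear indices
u = Indexed(Variable('u', (4, 3)), (q, i))    # atomics carrying the linear indices
v = Indexed(Variable('v', (4, 3)), (q, j))
t = Indexed(Variable('t', (4, 3)), (q, j))

# sum_q u*v*w + sum_q u*t*w: both monomials share the atomic u, which the
# optimiser pulls out as a common factor before rebuilding a GEM expression.
monomials = [Monomial((q,), (u, v), w), Monomial((q,), (u, t), w)]
optimised = optimise_monomials(monomials, linear_indices=(i, j))
expr = monomial_sum_to_expression(optimised)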
- a, b = expr.children - if isinstance(a, gem.Literal) and a.value == -1: - return expression_flops(b, temporaries) - elif isinstance(b, gem.Literal) and b.value == -1: - return expression_flops(a, temporaries) - else: - return 1 + sum(expression_flops(child, temporaries) - for child in expr.children) - - -@flops.register(gem.Sum) -@flops.register(gem.Division) -@flops.register(gem.Comparison) -@flops.register(gem.MathFunction) -@flops.register(gem.MinValue) -@flops.register(gem.MaxValue) -def flops_oneplus(expr, temporaries): - return 1 + sum(expression_flops(child, temporaries) - for child in expr.children) - - -@flops.register(gem.Power) -def flops_power(expr, temporaries): - base, exponent = expr.children - base_flops = expression_flops(base, temporaries) - if isinstance(exponent, gem.Literal): - exponent = exponent.value - if exponent > 0 and exponent == math.floor(exponent): - return base_flops + int(math.ceil(math.log2(exponent))) - else: - return base_flops + 5 # heuristic - else: - return base_flops + 5 # heuristic - - -@flops.register(gem.Conditional) -def flops_conditional(expr, temporaries): - condition, then, else_ = (expression_flops(child, temporaries) - for child in expr.children) - return condition + max(then, else_) - - -@flops.register(gem.Indexed) -@flops.register(gem.FlexiblyIndexed) -def flops_indexed(expr, temporaries): - aggregate = sum(expression_flops(child, temporaries) - for child in expr.children) - # Average flops per entry - return aggregate / numpy.prod(expr.children[0].shape, dtype=int) - - -@flops.register(gem.IndexSum) -def flops_indexsum(expr, temporaries): - raise ValueError("Not expecting IndexSum") - - -@flops.register(gem.Inverse) -def flops_inverse(expr, temporaries): - n, _ = expr.shape - # 2n^3 + child flop count - return 2*n**3 + sum(expression_flops(child, temporaries) - for child in expr.children) - - -@flops.register(gem.Solve) -def flops_solve(expr, temporaries): - n, m = expr.shape - # 2mn + inversion cost of A + children flop count - return 2*n*m + 2*n**3 + sum(expression_flops(child, temporaries) - for child in expr.children) - - -@flops.register(gem.ComponentTensor) -def flops_componenttensor(expr, temporaries): - raise ValueError("Not expecting ComponentTensor") - - -def expression_flops(expression, temporaries, top=False): - """An approximation to flops required for each expression. - - :arg expression: GEM expression. - :arg temporaries: Expressions that are assigned to temporaries - :arg top: are we at the root? - :returns: flop count for the expression - """ - if not top and expression in temporaries: - return 0 - else: - return flops(expression, temporaries) - - -def count_flops(impero_c): - """An approximation to flops required for a scheduled impero_c tree. - - :arg impero_c: a :class:`~.Impero_C` object. - :returns: approximate flop count for the tree. - """ - try: - return statement(impero_c.tree, set(impero_c.temporaries)) - except (ValueError, NotImplementedError): - return 0 diff --git a/gem/gem.py b/gem/gem.py deleted file mode 100644 index 95e8f2f5..00000000 --- a/gem/gem.py +++ /dev/null @@ -1,1046 +0,0 @@ -"""GEM is the intermediate language of TSFC for describing -tensor-valued mathematical expressions and tensor operations. -It is similar to Einstein's notation. - -Its design was heavily inspired by UFL, with some major differences: - - GEM has got nothing FEM-specific. - - In UFL free indices are just unrolled shape, thus UFL is very - restrictive about operations on expressions with different sets of - free indices. 
GEM is much more relaxed about free indices. - -Similarly to UFL, all GEM nodes have 'shape' and 'free_indices' -attributes / properties. Unlike UFL, however, index extents live on -the Index objects in GEM, not on all the nodes that have those free -indices. -""" - -from abc import ABCMeta -from itertools import chain -from operator import attrgetter -from numbers import Integral, Number - -import numpy -from numpy import asarray - -from gem.node import Node as NodeBase, traversal - - -__all__ = ['Node', 'Identity', 'Literal', 'Zero', 'Failure', - 'Variable', 'Sum', 'Product', 'Division', 'Power', - 'MathFunction', 'MinValue', 'MaxValue', 'Comparison', - 'LogicalNot', 'LogicalAnd', 'LogicalOr', 'Conditional', - 'Index', 'VariableIndex', 'Indexed', 'ComponentTensor', - 'IndexSum', 'ListTensor', 'Concatenate', 'Delta', - 'index_sum', 'partial_indexed', 'reshape', 'view', - 'indices', 'as_gem', 'FlexiblyIndexed', - 'Inverse', 'Solve', 'extract_type'] - - -class NodeMeta(type): - """Metaclass of GEM nodes. - - When a GEM node is constructed, this metaclass automatically - collects its free indices if 'free_indices' has not been set yet. - """ - - def __call__(self, *args, **kwargs): - # Create and initialise object - obj = super(NodeMeta, self).__call__(*args, **kwargs) - - # Set free_indices if not set already - if not hasattr(obj, 'free_indices'): - obj.free_indices = unique(chain(*[c.free_indices - for c in obj.children])) - - return obj - - -class Node(NodeBase, metaclass=NodeMeta): - """Abstract GEM node class.""" - - __slots__ = ('free_indices',) - - def is_equal(self, other): - """Common subexpression eliminating equality predicate. - - When two (sub)expressions are equal, the children of one - object are reassigned to the children of the other, so some - duplicated subexpressions are eliminated. 
- """ - result = NodeBase.is_equal(self, other) - if result: - self.children = other.children - return result - - def __getitem__(self, indices): - try: - indices = tuple(indices) - except TypeError: - indices = (indices, ) - return Indexed(self, indices) - - def __add__(self, other): - return componentwise(Sum, self, as_gem(other)) - - def __radd__(self, other): - return as_gem(other).__add__(self) - - def __sub__(self, other): - return componentwise( - Sum, self, - componentwise(Product, Literal(-1), as_gem(other))) - - def __rsub__(self, other): - return as_gem(other).__sub__(self) - - def __mul__(self, other): - return componentwise(Product, self, as_gem(other)) - - def __rmul__(self, other): - return as_gem(other).__mul__(self) - - def __matmul__(self, other): - other = as_gem(other) - if not self.shape and not other.shape: - return Product(self, other) - elif not (self.shape and other.shape): - raise ValueError("Both objects must have shape for matmul") - elif self.shape[-1] != other.shape[0]: - raise ValueError(f"Mismatching shapes {self.shape} and {other.shape} in matmul") - *i, k = indices(len(self.shape)) - _, *j = indices(len(other.shape)) - expr = Product(Indexed(self, tuple(i) + (k, )), - Indexed(other, (k, ) + tuple(j))) - return ComponentTensor(IndexSum(expr, (k, )), tuple(i) + tuple(j)) - - def __rmatmul__(self, other): - return as_gem(other).__matmul__(self) - - @property - def T(self): - i = indices(len(self.shape)) - return ComponentTensor(Indexed(self, i), tuple(reversed(i))) - - def __truediv__(self, other): - other = as_gem(other) - if other.shape: - raise ValueError("Denominator must be scalar") - return componentwise(Division, self, other) - - def __rtruediv__(self, other): - return as_gem(other).__truediv__(self) - - -class Terminal(Node): - """Abstract class for terminal GEM nodes.""" - - __slots__ = () - - children = () - - is_equal = NodeBase.is_equal - - -class Scalar(Node): - """Abstract class for scalar-valued GEM nodes.""" - - __slots__ = () - - shape = () - - -class Failure(Terminal): - """Abstract class for failure GEM nodes.""" - - __slots__ = ('shape', 'exception') - __front__ = ('shape', 'exception') - - def __init__(self, shape, exception): - self.shape = shape - self.exception = exception - - -class Constant(Terminal): - """Abstract base class for constant types. 
- - Convention: - - array: numpy array of values - - value: float or complex value (scalars only) - """ - __slots__ = () - - -class Zero(Constant): - """Symbolic zero tensor""" - - __slots__ = ('shape',) - __front__ = ('shape',) - - def __init__(self, shape=()): - self.shape = shape - - @property - def value(self): - assert not self.shape - return 0.0 - - -class Identity(Constant): - """Identity matrix""" - - __slots__ = ('dim',) - __front__ = ('dim',) - - def __init__(self, dim): - self.dim = dim - - @property - def shape(self): - return (self.dim, self.dim) - - @property - def array(self): - return numpy.eye(self.dim) - - -class Literal(Constant): - """Tensor-valued constant""" - - __slots__ = ('array',) - __front__ = ('array',) - - def __new__(cls, array): - array = asarray(array) - return super(Literal, cls).__new__(cls) - - def __init__(self, array): - array = asarray(array) - try: - self.array = array.astype(float, casting="safe") - except TypeError: - self.array = array.astype(complex) - - def is_equal(self, other): - if type(self) is not type(other): - return False - if self.shape != other.shape: - return False - return tuple(self.array.flat) == tuple(other.array.flat) - - def get_hash(self): - return hash((type(self), self.shape, tuple(self.array.flat))) - - @property - def value(self): - assert self.shape == () - return self.array.dtype.type(self.array) - - @property - def shape(self): - return self.array.shape - - -class Variable(Terminal): - """Symbolic variable tensor""" - - __slots__ = ('name', 'shape') - __front__ = ('name', 'shape') - - def __init__(self, name, shape): - self.name = name - self.shape = shape - - -class Sum(Scalar): - __slots__ = ('children',) - - def __new__(cls, a, b): - assert not a.shape - assert not b.shape - - # Constant folding - if isinstance(a, Zero): - return b - elif isinstance(b, Zero): - return a - - if isinstance(a, Constant) and isinstance(b, Constant): - return Literal(a.value + b.value) - - self = super(Sum, cls).__new__(cls) - self.children = a, b - return self - - -class Product(Scalar): - __slots__ = ('children',) - - def __new__(cls, a, b): - assert not a.shape - assert not b.shape - - # Constant folding - if isinstance(a, Zero) or isinstance(b, Zero): - return Zero() - - if a == one: - return b - if b == one: - return a - - if isinstance(a, Constant) and isinstance(b, Constant): - return Literal(a.value * b.value) - - self = super(Product, cls).__new__(cls) - self.children = a, b - return self - - -class Division(Scalar): - __slots__ = ('children',) - - def __new__(cls, a, b): - assert not a.shape - assert not b.shape - - # Constant folding - if isinstance(b, Zero): - raise ValueError("division by zero") - if isinstance(a, Zero): - return Zero() - - if b == one: - return a - - if isinstance(a, Constant) and isinstance(b, Constant): - return Literal(a.value / b.value) - - self = super(Division, cls).__new__(cls) - self.children = a, b - return self - - -class Power(Scalar): - __slots__ = ('children',) - - def __new__(cls, base, exponent): - assert not base.shape - assert not exponent.shape - - # Constant folding - if isinstance(base, Zero): - if isinstance(exponent, Zero): - raise ValueError("cannot solve 0^0") - return Zero() - elif isinstance(exponent, Zero): - return one - - if isinstance(base, Constant) and isinstance(exponent, Constant): - return Literal(base.value ** exponent.value) - - self = super(Power, cls).__new__(cls) - self.children = base, exponent - return self - - -class MathFunction(Scalar): - __slots__ = ('name', 
'children') - __front__ = ('name',) - - def __new__(cls, name, *args): - assert isinstance(name, str) - assert all(arg.shape == () for arg in args) - - if name in {'conj', 'real', 'imag'}: - arg, = args - if isinstance(arg, Zero): - return arg - - self = super(MathFunction, cls).__new__(cls) - self.name = name - self.children = args - return self - - -class MinValue(Scalar): - __slots__ = ('children',) - - def __init__(self, a, b): - assert not a.shape - assert not b.shape - - self.children = a, b - - -class MaxValue(Scalar): - __slots__ = ('children',) - - def __init__(self, a, b): - assert not a.shape - assert not b.shape - - self.children = a, b - - -class Comparison(Scalar): - __slots__ = ('operator', 'children') - __front__ = ('operator',) - - def __init__(self, op, a, b): - assert not a.shape - assert not b.shape - - if op not in [">", ">=", "==", "!=", "<", "<="]: - raise ValueError("invalid operator") - - self.operator = op - self.children = a, b - - -class LogicalNot(Scalar): - __slots__ = ('children',) - - def __init__(self, expression): - assert not expression.shape - - self.children = expression, - - -class LogicalAnd(Scalar): - __slots__ = ('children',) - - def __init__(self, a, b): - assert not a.shape - assert not b.shape - - self.children = a, b - - -class LogicalOr(Scalar): - __slots__ = ('children',) - - def __init__(self, a, b): - assert not a.shape - assert not b.shape - - self.children = a, b - - -class Conditional(Node): - __slots__ = ('children', 'shape') - - def __new__(cls, condition, then, else_): - assert not condition.shape - assert then.shape == else_.shape == () - - # If both branches are the same, just return one of them. In - # particular, this will help constant-fold zeros. - if then == else_: - return then - - self = super(Conditional, cls).__new__(cls) - self.children = condition, then, else_ - self.shape = then.shape - return self - - -class IndexBase(metaclass=ABCMeta): - """Abstract base class for indices.""" - pass - - -IndexBase.register(int) - - -class Index(IndexBase): - """Free index""" - - # Not true object count, just for naming purposes - _count = 0 - - __slots__ = ('name', 'extent', 'count') - - def __init__(self, name=None, extent=None): - self.name = name - Index._count += 1 - self.count = Index._count - self.extent = extent - - def set_extent(self, value): - # Set extent, check for consistency - if self.extent is None: - self.extent = value - elif self.extent != value: - raise ValueError("Inconsistent index extents!") - - def __str__(self): - if self.name is None: - return "i_%d" % self.count - return self.name - - def __repr__(self): - if self.name is None: - return "Index(%r)" % self.count - return "Index(%r)" % self.name - - def __lt__(self, other): - # Allow sorting of free indices in Python 3 - return id(self) < id(other) - - def __getstate__(self): - return self.name, self.extent, self.count - - def __setstate__(self, state): - self.name, self.extent, self.count = state - - -class VariableIndex(IndexBase): - """An index that is constant during a single execution of the - kernel, but whose value is not known at compile time.""" - - __slots__ = ('expression',) - - def __init__(self, expression): - assert isinstance(expression, Node) - assert not expression.free_indices - assert not expression.shape - self.expression = expression - - def __eq__(self, other): - if self is other: - return True - if type(self) is not type(other): - return False - return self.expression == other.expression - - def __ne__(self, other): - return not 
self.__eq__(other) - - def __hash__(self): - return hash((VariableIndex, self.expression)) - - def __str__(self): - return str(self.expression) - - def __repr__(self): - return "VariableIndex(%r)" % (self.expression,) - - def __reduce__(self): - return VariableIndex, (self.expression,) - - -class Indexed(Scalar): - __slots__ = ('children', 'multiindex') - __back__ = ('multiindex',) - - def __new__(cls, aggregate, multiindex): - # Accept numpy or any integer, but cast to int. - multiindex = tuple(int(i) if isinstance(i, Integral) else i - for i in multiindex) - - # Set index extents from shape - assert len(aggregate.shape) == len(multiindex) - for index, extent in zip(multiindex, aggregate.shape): - assert isinstance(index, IndexBase) - if isinstance(index, Index): - index.set_extent(extent) - elif isinstance(index, int) and not (0 <= index < extent): - raise IndexError("Invalid literal index") - - # Empty multiindex - if not multiindex: - return aggregate - - # Zero folding - if isinstance(aggregate, Zero): - return Zero() - - # All indices fixed - if all(isinstance(i, int) for i in multiindex): - if isinstance(aggregate, Constant): - return Literal(aggregate.array[multiindex]) - elif isinstance(aggregate, ListTensor): - return aggregate.array[multiindex] - - self = super(Indexed, cls).__new__(cls) - self.children = (aggregate,) - self.multiindex = multiindex - - new_indices = tuple(i for i in multiindex if isinstance(i, Index)) - self.free_indices = unique(aggregate.free_indices + new_indices) - - return self - - def index_ordering(self): - """Running indices in the order of indexing in this node.""" - return tuple(i for i in self.multiindex if isinstance(i, Index)) - - -class FlexiblyIndexed(Scalar): - """Flexible indexing of :py:class:`Variable`s to implement views and - reshapes (splitting dimensions only).""" - - __slots__ = ('children', 'dim2idxs') - __back__ = ('dim2idxs',) - - def __init__(self, variable, dim2idxs): - """Construct a flexibly indexed node. 
- - :arg variable: a node that has a shape - :arg dim2idxs: describes the mapping of indices - - For example, if ``variable`` is rank two, and ``dim2idxs`` is - - ((1, ((i, 12), (j, 4), (k, 1))), (0, ())) - - then this corresponds to the indexing: - - variable[1 + i*12 + j*4 + k][0] - - """ - assert variable.shape - assert len(variable.shape) == len(dim2idxs) - - dim2idxs_ = [] - free_indices = [] - for dim, (offset, idxs) in zip(variable.shape, dim2idxs): - offset_ = offset - idxs_ = [] - last = 0 - for idx in idxs: - index, stride = idx - if isinstance(index, Index): - assert index.extent is not None - free_indices.append(index) - idxs_.append((index, stride)) - last += (index.extent - 1) * stride - elif isinstance(index, int): - offset_ += index * stride - else: - raise ValueError("Unexpected index type for flexible indexing") - - if dim is not None and offset_ + last >= dim: - raise ValueError("Offset {0} and indices {1} exceed dimension {2}".format(offset, idxs, dim)) - - dim2idxs_.append((offset_, tuple(idxs_))) - - self.children = (variable,) - self.dim2idxs = tuple(dim2idxs_) - self.free_indices = unique(free_indices) - - def index_ordering(self): - """Running indices in the order of indexing in this node.""" - return tuple(index - for _, idxs in self.dim2idxs - for index, _ in idxs - if isinstance(index, Index)) - - -class ComponentTensor(Node): - __slots__ = ('children', 'multiindex', 'shape') - __back__ = ('multiindex',) - - def __new__(cls, expression, multiindex): - assert not expression.shape - - # Empty multiindex - if not multiindex: - return expression - - # Collect shape - shape = tuple(index.extent for index in multiindex) - assert all(s >= 0 for s in shape) - - # Zero folding - if isinstance(expression, Zero): - return Zero(shape) - - self = super(ComponentTensor, cls).__new__(cls) - self.children = (expression,) - self.multiindex = multiindex - self.shape = shape - - # Collect free indices - assert set(multiindex) <= set(expression.free_indices) - self.free_indices = unique(set(expression.free_indices) - set(multiindex)) - - return self - - -class IndexSum(Scalar): - __slots__ = ('children', 'multiindex') - __back__ = ('multiindex',) - - def __new__(cls, summand, multiindex): - # Sum zeros - assert not summand.shape - if isinstance(summand, Zero): - return summand - - # Unroll singleton sums - unroll = tuple(index for index in multiindex if index.extent <= 1) - if unroll: - assert numpy.prod([index.extent for index in unroll]) == 1 - summand = Indexed(ComponentTensor(summand, unroll), - (0,) * len(unroll)) - multiindex = tuple(index for index in multiindex - if index not in unroll) - - # No indices case - multiindex = tuple(multiindex) - if not multiindex: - return summand - - self = super(IndexSum, cls).__new__(cls) - self.children = (summand,) - self.multiindex = multiindex - - # Collect shape and free indices - assert set(multiindex) <= set(summand.free_indices) - self.free_indices = unique(set(summand.free_indices) - set(multiindex)) - - return self - - -class ListTensor(Node): - __slots__ = ('array',) - - def __new__(cls, array): - array = asarray(array) - assert numpy.prod(array.shape) - - # Handle children with shape - child_shape = array.flat[0].shape - assert all(elem.shape == child_shape for elem in array.flat) - - if child_shape: - # Destroy structure - direct_array = numpy.empty(array.shape + child_shape, dtype=object) - for alpha in numpy.ndindex(array.shape): - for beta in numpy.ndindex(child_shape): - direct_array[alpha + beta] = Indexed(array[alpha], 
beta) - array = direct_array - - # Constant folding - if all(isinstance(elem, Constant) for elem in array.flat): - return Literal(numpy.vectorize(attrgetter('value'))(array)) - - self = super(ListTensor, cls).__new__(cls) - self.array = array - return self - - @property - def children(self): - return tuple(self.array.flat) - - @property - def shape(self): - return self.array.shape - - def __reduce__(self): - return type(self), (self.array,) - - def reconstruct(self, *args): - return ListTensor(asarray(args).reshape(self.array.shape)) - - def __repr__(self): - return "ListTensor(%r)" % self.array.tolist() - - def is_equal(self, other): - """Common subexpression eliminating equality predicate.""" - if type(self) is not type(other): - return False - if (self.array == other.array).all(): - self.array = other.array - return True - return False - - def get_hash(self): - return hash((type(self), self.shape, self.children)) - - -class Concatenate(Node): - """Flattens and concatenates GEM expressions by shape. - - Similar to what UFL MixedElement does to value shape. For - example, if children have shapes (2, 2), (), and (3,) then the - concatenated expression has shape (8,). - """ - __slots__ = ('children',) - - def __new__(cls, *children): - if all(isinstance(child, Zero) for child in children): - size = int(sum(numpy.prod(child.shape, dtype=int) for child in children)) - return Zero((size,)) - - self = super(Concatenate, cls).__new__(cls) - self.children = children - return self - - @property - def shape(self): - return (int(sum(numpy.prod(child.shape, dtype=int) for child in self.children)),) - - -class Delta(Scalar, Terminal): - __slots__ = ('i', 'j') - __front__ = ('i', 'j') - - def __new__(cls, i, j): - assert isinstance(i, IndexBase) - assert isinstance(j, IndexBase) - - # \delta_{i,i} = 1 - if i == j: - return one - - # Fixed indices - if isinstance(i, int) and isinstance(j, int): - return Literal(int(i == j)) - - self = super(Delta, cls).__new__(cls) - self.i = i - self.j = j - # Set up free indices - free_indices = tuple(index for index in (i, j) if isinstance(index, Index)) - self.free_indices = tuple(unique(free_indices)) - return self - - -class Inverse(Node): - """The inverse of a square matrix.""" - __slots__ = ('children', 'shape') - - def __new__(cls, tensor): - assert len(tensor.shape) == 2 - assert tensor.shape[0] == tensor.shape[1] - - # Invert 1x1 matrix - if tensor.shape == (1, 1): - multiindex = (Index(), Index()) - return ComponentTensor(Division(one, Indexed(tensor, multiindex)), multiindex) - - self = super(Inverse, cls).__new__(cls) - self.children = (tensor,) - self.shape = tensor.shape - - return self - - -class Solve(Node): - """Solution of a square matrix equation with (potentially) multiple right hand sides. - - Represents the X obtained by solving AX = B. - """ - __slots__ = ('children', 'shape') - - def __init__(self, A, B): - # Shape requirements - assert B.shape - assert len(A.shape) == 2 - assert A.shape[0] == A.shape[1] - assert A.shape[0] == B.shape[0] - - self.children = (A, B) - self.shape = A.shape[1:] + B.shape[1:] - - -def unique(indices): - """Sorts free indices and eliminates duplicates. - - :arg indices: iterable of indices - :returns: sorted tuple of unique free indices - """ - return tuple(sorted(set(indices), key=id)) - - -def index_sum(expression, indices): - """Eliminates indices from the free indices of an expression by - summing over them. 
Skips any index that is not a free index of - the expression.""" - multiindex = tuple(index for index in indices - if index in expression.free_indices) - return IndexSum(expression, multiindex) - - -def partial_indexed(tensor, indices): - """Generalised indexing into a tensor by eating shape off the front. - The number of indices may be less than or equal to the rank of the tensor, - so the result may have a non-empty shape. - - :arg tensor: tensor-valued GEM expression - :arg indices: indices, at most as many as the rank of the tensor - :returns: a potentially tensor-valued expression - """ - if len(indices) == 0: - return tensor - elif len(indices) < len(tensor.shape): - rank = len(tensor.shape) - len(indices) - shape_indices = tuple(Index() for i in range(rank)) - return ComponentTensor( - Indexed(tensor, indices + shape_indices), - shape_indices) - elif len(indices) == len(tensor.shape): - return Indexed(tensor, indices) - else: - raise ValueError("More indices than rank!") - - -def strides_of(shape): - """Calculate cumulative strides from per-dimension capacities. - - For example: - - [2, 3, 4] ==> [12, 4, 1] - - """ - temp = numpy.flipud(numpy.cumprod(numpy.flipud(list(shape)[1:]))) - return list(temp) + [1] - - -def decompose_variable_view(expression): - """Extract information from a shaped node. - Decompose ComponentTensor + FlexiblyIndexed.""" - if (isinstance(expression, (Variable, Inverse, Solve))): - variable = expression - indexes = tuple(Index(extent=extent) for extent in expression.shape) - dim2idxs = tuple((0, ((index, 1),)) for index in indexes) - elif (isinstance(expression, ComponentTensor) and - not isinstance(expression.children[0], FlexiblyIndexed)): - variable = expression - indexes = expression.multiindex - dim2idxs = tuple((0, ((index, 1),)) for index in indexes) - elif isinstance(expression, ComponentTensor) and isinstance(expression.children[0], FlexiblyIndexed): - variable = expression.children[0].children[0] - indexes = expression.multiindex - dim2idxs = expression.children[0].dim2idxs - else: - raise ValueError("Cannot handle {} objects.".format(type(expression).__name__)) - - return variable, dim2idxs, indexes - - -def reshape(expression, *shapes): - """Reshape a variable (splitting indices only). - - :arg expression: view of a :py:class:`Variable` - :arg shapes: one shape tuple for each dimension of the variable. - """ - variable, dim2idxs, indexes = decompose_variable_view(expression) - assert len(indexes) == len(shapes) - shape_of = dict(zip(indexes, shapes)) - - dim2idxs_ = [] - indices = [[] for _ in range(len(indexes))] - for offset, idxs in dim2idxs: - idxs_ = [] - for idx in idxs: - index, stride = idx - assert isinstance(index, Index) - dim = index.extent - shape = shape_of[index] - if dim is not None and numpy.prod(shape) != dim: - raise ValueError("Shape {} does not match extent {}.".format(shape, dim)) - strides = strides_of(shape) - for extent, stride_ in zip(shape, strides): - index_ = Index(extent=extent) - idxs_.append((index_, stride_ * stride)) - indices[indexes.index(index)].append(index_) - dim2idxs_.append((offset, tuple(idxs_))) - - expr = FlexiblyIndexed(variable, tuple(dim2idxs_)) - return ComponentTensor(expr, tuple(chain.from_iterable(indices))) - - -def view(expression, *slices): - """View a part of a shaped object. - - :arg expression: a node that has a shape - :arg slices: one slice object for each dimension of the expression. 
- """ - variable, dim2idxs, indexes = decompose_variable_view(expression) - assert len(indexes) == len(slices) - slice_of = dict(zip(indexes, slices)) - - dim2idxs_ = [] - indices = [None] * len(slices) - for offset, idxs in dim2idxs: - offset_ = offset - idxs_ = [] - for idx in idxs: - index, stride = idx - assert isinstance(index, Index) - dim = index.extent - s = slice_of[index] - start = s.start or 0 - stop = s.stop or dim - if stop is None: - raise ValueError("Unknown extent!") - if dim is not None and stop > dim: - raise ValueError("Slice exceeds dimension extent!") - step = s.step or 1 - offset_ += start * stride - extent = 1 + (stop - start - 1) // step - index_ = Index(extent=extent) - indices[indexes.index(index)] = index_ - idxs_.append((index_, step * stride)) - dim2idxs_.append((offset_, tuple(idxs_))) - - expr = FlexiblyIndexed(variable, tuple(dim2idxs_)) - return ComponentTensor(expr, tuple(indices)) - - -# Static one object for quicker constant folding -one = Literal(1) - - -# Syntax sugar -def indices(n): - """Make some :class:`Index` objects. - - :arg n: The number of indices to make. - :returns: A tuple of `n` :class:`Index` objects. - """ - return tuple(Index() for _ in range(n)) - - -def componentwise(op, *exprs): - """Apply gem op to exprs component-wise and wrap up in a ComponentTensor. - - :arg op: function that returns a gem Node. - :arg exprs: expressions to apply op to. - :raises ValueError: if the expressions have mismatching shapes. - :returns: New gem Node constructed from op. - - Each expression must either have the same shape, or else be - scalar. Shaped expressions are indexed, the op is applied to the - scalar expressions and the result is wrapped up in a ComponentTensor. - - """ - shapes = set(e.shape for e in exprs) - if len(shapes - {()}) > 1: - raise ValueError("expressions must have matching shape (or else be scalar)") - shape = max(shapes) - i = indices(len(shape)) - exprs = tuple(Indexed(e, i) if e.shape else e for e in exprs) - return ComponentTensor(op(*exprs), i) - - -def as_gem(expr): - """Attempt to convert an expression into GEM. - - :arg expr: The expression. - :returns: A GEM representation of the expression. - :raises ValueError: if conversion was not possible. - """ - if isinstance(expr, Node): - return expr - elif isinstance(expr, Number): - return Literal(expr) - else: - raise ValueError("Do not know how to convert %r to GEM" % expr) - - -def extract_type(expressions, klass): - """Collects objects of type klass in expressions.""" - return tuple(node for node in traversal(expressions) if isinstance(node, klass)) diff --git a/gem/impero.py b/gem/impero.py deleted file mode 100644 index c909e1bf..00000000 --- a/gem/impero.py +++ /dev/null @@ -1,159 +0,0 @@ -"""Impero is a helper AST for generating C code (or equivalent, -e.g. COFFEE) from GEM. An Impero expression is a proper tree, not -directed acyclic graph (DAG). Impero is a helper AST, not a -standalone language; it is incomplete without GEM as its terminals -refer to nodes from GEM expressions. - -Trivia: - - Impero helps translating GEM into an imperative language. - - Byzantine units in Age of Empires II sometimes say 'Impero?' - (Command?) after clicking on them. 
-""" - -from abc import ABCMeta, abstractmethod - -from gem.node import Node as NodeBase - - -class Node(NodeBase): - """Base class of all Impero nodes""" - - __slots__ = () - - -class Terminal(Node, metaclass=ABCMeta): - """Abstract class for terminal Impero nodes""" - - __slots__ = () - - children = () - - @abstractmethod - def loop_shape(self, free_indices): - """Gives the loop shape, an ordering of indices for an Impero - terminal. - - :arg free_indices: a callable mapping of GEM expressions to - ordered free indices. - """ - pass - - -class Evaluate(Terminal): - """Assign the value of a GEM expression to a temporary.""" - - __slots__ = ('expression',) - __front__ = ('expression',) - - def __init__(self, expression): - self.expression = expression - - def loop_shape(self, free_indices): - return free_indices(self.expression) - - -class Initialise(Terminal): - """Initialise an :class:`gem.IndexSum`.""" - - __slots__ = ('indexsum',) - __front__ = ('indexsum',) - - def __init__(self, indexsum): - self.indexsum = indexsum - - def loop_shape(self, free_indices): - return free_indices(self.indexsum) - - -class Accumulate(Terminal): - """Accumulate terms into an :class:`gem.IndexSum`.""" - - __slots__ = ('indexsum',) - __front__ = ('indexsum',) - - def __init__(self, indexsum): - self.indexsum = indexsum - - def loop_shape(self, free_indices): - return free_indices(self.indexsum.children[0]) - - -class Noop(Terminal): - """No-op terminal. Does not generate code, but wraps a GEM - expression to have a loop shape, thus affects loop fusion.""" - - __slots__ = ('expression',) - __front__ = ('expression',) - - def __init__(self, expression): - self.expression = expression - - def loop_shape(self, free_indices): - return free_indices(self.expression) - - -class Return(Terminal): - """Save value of GEM expression into an lvalue. Used to "return" - values from a kernel.""" - - __slots__ = ('variable', 'expression') - __front__ = ('variable', 'expression') - - def __init__(self, variable, expression): - assert set(variable.free_indices) >= set(expression.free_indices) - - self.variable = variable - self.expression = expression - - def loop_shape(self, free_indices): - return free_indices(self.variable) - - -class ReturnAccumulate(Terminal): - """Accumulate an :class:`gem.IndexSum` directly into a return - variable.""" - - __slots__ = ('variable', 'indexsum') - __front__ = ('variable', 'indexsum') - - def __init__(self, variable, indexsum): - assert set(variable.free_indices) == set(indexsum.free_indices) - - self.variable = variable - self.indexsum = indexsum - - def loop_shape(self, free_indices): - return free_indices(self.indexsum.children[0]) - - -class Block(Node): - """An ordered set of Impero expressions. Corresponds to a curly - braces block in C.""" - - __slots__ = ('children',) - - def __init__(self, statements): - self.children = tuple(statements) - - -class For(Node): - """For loop with an index which stores its extent, and a loop body - expression which is usually a :class:`Block`.""" - - __slots__ = ('index', 'children') - __front__ = ('index',) - - def __new__(cls, index, statement): - # In case of an empty loop, create a Noop instead. - # Related: https://github.com/coneoproject/COFFEE/issues/98 - assert isinstance(statement, Block) - if not statement.children: - # This "works" because the loop_shape of this node is not - # asked any more. 
- return Noop(None) - else: - return super(For, cls).__new__(cls) - - def __init__(self, index, statement): - self.index = index - self.children = (statement,) diff --git a/gem/impero_utils.py b/gem/impero_utils.py deleted file mode 100644 index 31f9565b..00000000 --- a/gem/impero_utils.py +++ /dev/null @@ -1,322 +0,0 @@ -"""Utilities for building an Impero AST from an ordered list of -terminal Impero operations, and for building any additional data -required for straightforward C code generation. - -What this module does is independent of the generated code target. -""" - -import collections -from functools import singledispatch -from itertools import chain, groupby - -from gem.node import traversal, collect_refcount -from gem import gem, impero as imp, optimise, scheduling - - -# ImperoC is named tuple for C code generation. -# -# Attributes: -# tree - Impero AST describing the loop structure and operations -# temporaries - List of GEM expressions which have assigned temporaries -# declare - Where to declare temporaries to get correct C code -# indices - Indices for declarations and referencing values -ImperoC = collections.namedtuple('ImperoC', ['tree', 'temporaries', 'declare', 'indices']) - - -class NoopError(Exception): - """No operations in the kernel.""" - pass - - -def preprocess_gem(expressions, replace_delta=True, remove_componenttensors=True): - """Lower GEM nodes that cannot be translated to C directly.""" - if remove_componenttensors: - expressions = optimise.remove_componenttensors(expressions) - if replace_delta: - expressions = optimise.replace_delta(expressions) - return expressions - - -def compile_gem(assignments, prefix_ordering, remove_zeros=False, - emit_return_accumulate=True): - """Compiles GEM to Impero. - - :arg assignments: list of (return variable, expression DAG root) pairs - :arg prefix_ordering: outermost loop indices - :arg remove_zeros: remove zero assignment to return variables - :arg emit_return_accumulate: emit ReturnAccumulate nodes (see - :func:`~.scheduling.emit_operations`)? If False, - split into Accumulate/Return pairs. Set to False if the - output tensor of kernels is not guaranteed to be zero on entry. 
- """ - # Remove zeros - if remove_zeros: - def nonzero(assignment): - variable, expression = assignment - return not isinstance(expression, gem.Zero) - assignments = list(filter(nonzero, assignments)) - - # Just the expressions - expressions = [expression for variable, expression in assignments] - - # Collect indices in a deterministic order - indices = list(collections.OrderedDict.fromkeys(chain.from_iterable( - node.index_ordering() - for node in traversal(expressions) - if isinstance(node, (gem.Indexed, gem.FlexiblyIndexed)) - ))) - - # Build ordered index map - index_ordering = make_prefix_ordering(indices, prefix_ordering) - apply_ordering = make_index_orderer(index_ordering) - - get_indices = lambda expr: apply_ordering(expr.free_indices) - - # Build operation ordering - ops = scheduling.emit_operations(assignments, get_indices, emit_return_accumulate) - - # Empty kernel - if len(ops) == 0: - raise NoopError() - - # Drop unnecessary temporaries - ops = inline_temporaries(expressions, ops) - - # Build Impero AST - tree = make_loop_tree(ops, get_indices) - - # Collect temporaries - temporaries = collect_temporaries(tree) - - # Determine declarations - declare, indices = place_declarations(tree, temporaries, get_indices) - - # Prepare ImperoC (Impero AST + other data for code generation) - return ImperoC(tree, temporaries, declare, indices) - - -def make_prefix_ordering(indices, prefix_ordering): - """Creates an ordering of ``indices`` which starts with those - indices in ``prefix_ordering``.""" - # Need to return deterministically ordered indices - return tuple(prefix_ordering) + tuple(k for k in indices if k not in prefix_ordering) - - -def make_index_orderer(index_ordering): - """Returns a function which given a set of indices returns those - indices in the order as they appear in ``index_ordering``.""" - idx2pos = {idx: pos for pos, idx in enumerate(index_ordering)} - - def apply_ordering(indices): - return tuple(sorted(indices, key=lambda i: idx2pos[i])) - return apply_ordering - - -def inline_temporaries(expressions, ops): - """Inline temporaries which could be inlined without blowing up - the code. - - :arg expressions: a multi-root GEM expression DAG, used for - reference counting - :arg ops: ordered list of Impero terminals - :returns: a filtered ``ops``, without the unnecessary - :class:`impero.Evaluate`s - """ - refcount = collect_refcount(expressions) - - candidates = set() # candidates for inlining - for op in ops: - if isinstance(op, imp.Evaluate): - expr = op.expression - if expr.shape == () and refcount[expr] == 1: - candidates.add(expr) - - # Prevent inlining that pulls expressions into inner loops - for node in traversal(expressions): - for child in node.children: - if child in candidates and set(child.free_indices) < set(node.free_indices): - candidates.remove(child) - - # Filter out candidates - return [op for op in ops if not (isinstance(op, imp.Evaluate) and op.expression in candidates)] - - -def collect_temporaries(tree): - """Collects GEM expressions to assign to temporaries from a list - of Impero terminals.""" - result = [] - for node in traversal((tree,)): - # IndexSum temporaries should be added either at Initialise or - # at Accumulate. The difference is only in ordering - # (numbering). We chose Accumulate here. 
- if isinstance(node, imp.Accumulate): - result.append(node.indexsum) - elif isinstance(node, imp.Evaluate): - result.append(node.expression) - return result - - -def make_loop_tree(ops, get_indices, level=0): - """Creates an Impero AST with loops from a list of operations and - their respective free indices. - - :arg ops: a list of Impero terminal nodes - :arg get_indices: callable mapping from GEM nodes to an ordering - of free indices - :arg level: depth of loop nesting - :returns: Impero AST with loops, without declarations - """ - keyfunc = lambda op: op.loop_shape(get_indices)[level:level+1] - statements = [] - for first_index, op_group in groupby(ops, keyfunc): - if first_index: - inner_block = make_loop_tree(op_group, get_indices, level+1) - statements.append(imp.For(first_index[0], inner_block)) - else: - statements.extend(op_group) - # Remove no-op terminals from the tree - statements = [s for s in statements if not isinstance(s, imp.Noop)] - return imp.Block(statements) - - -def place_declarations(tree, temporaries, get_indices): - """Determines where and how to declare temporaries for an Impero AST. - - :arg tree: Impero AST to determine the declarations for - :arg temporaries: list of GEM expressions which are assigned to - temporaries - :arg get_indices: callable mapping from GEM nodes to an ordering - of free indices - """ - numbering = {t: n for n, t in enumerate(temporaries)} - assert len(numbering) == len(temporaries) - - # Collect the total number of temporary references - total_refcount = collections.Counter() - for node in traversal((tree,)): - if isinstance(node, imp.Terminal): - total_refcount.update(temp_refcount(numbering, node)) - assert set(total_refcount) == set(temporaries) - - # Result - declare = {} - indices = {} - - @singledispatch - def recurse(expr, loop_indices): - """Visit an Impero AST to collect declarations. - - :arg expr: Impero tree node - :arg loop_indices: loop indices (in order) from the outer - loops surrounding ``expr`` - :returns: :class:`collections.Counter` with the reference - counts for each temporary in the subtree whose root - is ``expr`` - """ - return AssertionError("unsupported expression type %s" % type(expr)) - - @recurse.register(imp.Terminal) - def recurse_terminal(expr, loop_indices): - return temp_refcount(numbering, expr) - - @recurse.register(imp.For) - def recurse_for(expr, loop_indices): - return recurse(expr.children[0], loop_indices + (expr.index,)) - - @recurse.register(imp.Block) - def recurse_block(expr, loop_indices): - # Temporaries declared at the beginning of the block are - # collected here - declare[expr] = [] - - # Collect reference counts for the block - refcount = collections.Counter() - for statement in expr.children: - refcount.update(recurse(statement, loop_indices)) - - # Visit :class:`collections.Counter` in deterministic order - for e in sorted(refcount.keys(), key=lambda t: numbering[t]): - if refcount[e] == total_refcount[e]: - # If all references are within this block, then this - # block is the right place to declare the temporary. - assert loop_indices == get_indices(e)[:len(loop_indices)] - indices[e] = get_indices(e)[len(loop_indices):] - if indices[e]: - # Scalar-valued temporaries are not declared until - # their value is assigned. This does not really - # matter, but produces a more compact and nicer to - # read C code. - declare[expr].append(e) - # Remove expression from the ``refcount`` so it will - # not be declared again. 
- del refcount[e] - return refcount - - # Populate result - remainder = recurse(tree, ()) - assert not remainder - - # Set in ``declare`` for Impero terminals whether they should - # declare the temporary that they are writing to. - for node in traversal((tree,)): - if isinstance(node, imp.Terminal): - declare[node] = False - if isinstance(node, imp.Evaluate): - e = node.expression - elif isinstance(node, imp.Initialise): - e = node.indexsum - else: - continue - - if len(indices[e]) == 0: - declare[node] = True - - return declare, indices - - -def temp_refcount(temporaries, op): - """Collects the number of times temporaries are referenced when - generating code for an Impero terminal. - - :arg temporaries: set of temporaries - :arg op: Impero terminal - :returns: :class:`collections.Counter` object mapping some of - elements from ``temporaries`` to the number of times - they will referenced from ``op`` - """ - counter = collections.Counter() - - def recurse(o): - """Traverses expression until reaching temporaries, counting - temporary references.""" - if o in temporaries: - counter[o] += 1 - else: - for c in o.children: - recurse(c) - - def recurse_top(o): - """Traverses expression until reaching temporaries, counting - temporary references. Always descends into children at least - once, even when the root is a temporary.""" - if o in temporaries: - counter[o] += 1 - for c in o.children: - recurse(c) - - if isinstance(op, imp.Initialise): - counter[op.indexsum] += 1 - elif isinstance(op, imp.Accumulate): - recurse_top(op.indexsum) - elif isinstance(op, imp.Evaluate): - recurse_top(op.expression) - elif isinstance(op, imp.Return): - recurse(op.expression) - elif isinstance(op, imp.ReturnAccumulate): - recurse(op.indexsum.children[0]) - elif isinstance(op, imp.Noop): - pass - else: - raise AssertionError("unhandled operation: %s" % type(op)) - - return counter diff --git a/gem/interpreter.py b/gem/interpreter.py deleted file mode 100644 index 13eeb44a..00000000 --- a/gem/interpreter.py +++ /dev/null @@ -1,362 +0,0 @@ -""" -An interpreter for GEM trees. -""" -import numpy -import operator -from collections import OrderedDict -from functools import singledispatch -import itertools - -from gem import gem, node -from gem.optimise import replace_delta - -__all__ = ("evaluate", ) - - -class Result(object): - """An array object that tracks which axes of the array correspond to - gem free indices (and what those free indices are). - - :arg arr: The array. - :arg fids: The free indices. - - The first ``len(fids)`` axes of the provided array correspond to - the free indices, the remaining axes are the shape of each entry. - """ - def __init__(self, arr, fids=None): - self.arr = arr - self.fids = fids if fids is not None else () - - def broadcast(self, fids): - """Given some free indices, return a broadcasted array which - contains extra dimensions that correspond to indices in fids - that are not in ``self.fids``. - - Note that inserted dimensions will have length one. - - :arg fids: The free indices for broadcasting. 
- """ - # Select free indices - axes = tuple(self.fids.index(fi) for fi in fids if fi in self.fids) - assert len(axes) == len(self.fids) - # Add shape - axes += tuple(range(len(self.fids), self.arr.ndim)) - # Move axes, insert extra axes - arr = numpy.transpose(self.arr, axes) - for i, fi in enumerate(fids): - if fi not in self.fids: - arr = numpy.expand_dims(arr, axis=i) - return arr - - def filter(self, idx, fids): - """Given an index tuple and some free indices, return a - "filtered" index tuple which removes entries that correspond - to indices in fids that are not in ``self.fids``. - - :arg idx: The index tuple to filter. - :arg fids: The free indices for the index tuple. - """ - return tuple(idx[fids.index(i)] for i in self.fids) + idx[len(fids):] - - def __getitem__(self, idx): - return self.arr[tuple(idx)] - - def __setitem__(self, idx, val): - self.arr[idx] = val - - @property - def tshape(self): - """The total shape of the result array.""" - return self.arr.shape - - @property - def fshape(self): - """The shape of the free index part of the result array.""" - return self.tshape[:len(self.fids)] - - @property - def shape(self): - """The shape of the shape part of the result array.""" - return self.tshape[len(self.fids):] - - def __repr__(self): - return "Result(%r, %r)" % (self.arr, self.fids) - - def __str__(self): - return repr(self) - - @classmethod - def empty(cls, *children, **kwargs): - """Build an empty Result object. - - :arg children: The children used to determine the shape and - free indices. - :kwarg dtype: The data type of the result array. - """ - dtype = kwargs.get("dtype", float) - assert all(children[0].shape == c.shape for c in children) - fids = [] - for f in itertools.chain(*(c.fids for c in children)): - if f not in fids: - fids.append(f) - shape = tuple(i.extent for i in fids) + children[0].shape - return cls(numpy.empty(shape, dtype=dtype), tuple(fids)) - - -@singledispatch -def _evaluate(expression, self): - """Evaluate an expression using a provided callback handler. - - :arg expression: The expression to evaluation. - :arg self: The callback handler (should provide bindings). - """ - raise ValueError("Unhandled node type %s" % type(expression)) - - -@_evaluate.register(gem.Zero) -def _evaluate_zero(e, self): - """Zeros produce an array of zeros.""" - return Result(numpy.zeros(e.shape, dtype=float)) - - -@_evaluate.register(gem.Failure) -def _evaluate_failure(e, self): - """Failure nodes produce NaNs.""" - return Result(numpy.full(e.shape, numpy.nan, dtype=float)) - - -@_evaluate.register(gem.Constant) -def _evaluate_constant(e, self): - """Constants return their array.""" - return Result(e.array) - - -@_evaluate.register(gem.Delta) -def _evaluate_delta(e, self): - """Lower delta and evaluate.""" - e, = replace_delta((e,)) - return self(e) - - -@_evaluate.register(gem.Variable) -def _evaluate_variable(e, self): - """Look up variables in the provided bindings.""" - try: - val = self.bindings[e] - except KeyError: - raise ValueError("Binding for %s not found" % e) - if val.shape != e.shape: - raise ValueError("Binding for %s has wrong shape. %s, not %s." 
% - (e, val.shape, e.shape)) - return Result(val) - - -@_evaluate.register(gem.Power) -@_evaluate.register(gem.Division) -@_evaluate.register(gem.Product) -@_evaluate.register(gem.Sum) -def _evaluate_operator(e, self): - op = {gem.Product: operator.mul, - gem.Division: operator.truediv, - gem.Sum: operator.add, - gem.Power: operator.pow}[type(e)] - - a, b = [self(o) for o in e.children] - result = Result.empty(a, b) - fids = result.fids - result.arr = op(a.broadcast(fids), b.broadcast(fids)) - return result - - -@_evaluate.register(gem.MathFunction) -def _evaluate_mathfunction(e, self): - ops = [self(o) for o in e.children] - result = Result.empty(*ops) - names = { - "abs": abs, - "log": numpy.log, - "real": operator.attrgetter("real"), - "imag": operator.attrgetter("imag"), - "conj": operator.methodcaller("conjugate"), - } - op = names[e.name] - for idx in numpy.ndindex(result.tshape): - result[idx] = op(*(o[o.filter(idx, result.fids)] for o in ops)) - return result - - -@_evaluate.register(gem.MaxValue) -@_evaluate.register(gem.MinValue) -def _evaluate_minmaxvalue(e, self): - ops = [self(o) for o in e.children] - result = Result.empty(*ops) - op = {gem.MinValue: min, - gem.MaxValue: max}[type(e)] - for idx in numpy.ndindex(result.tshape): - result[idx] = op(*(o[o.filter(idx, result.fids)] for o in ops)) - return result - - -@_evaluate.register(gem.Comparison) -def _evaluate_comparison(e, self): - ops = [self(o) for o in e.children] - op = {">": operator.gt, - ">=": operator.ge, - "==": operator.eq, - "!=": operator.ne, - "<": operator.lt, - "<=": operator.le}[e.operator] - result = Result.empty(*ops, dtype=bool) - for idx in numpy.ndindex(result.tshape): - result[idx] = op(*(o[o.filter(idx, result.fids)] for o in ops)) - return result - - -@_evaluate.register(gem.LogicalNot) -def _evaluate_logicalnot(e, self): - val = self(e.children[0]) - assert val.arr.dtype == numpy.dtype("bool") - result = Result.empty(val, bool) - for idx in numpy.ndindex(result.tshape): - result[idx] = not val[val.filter(idx, result.fids)] - return result - - -@_evaluate.register(gem.LogicalAnd) -def _evaluate_logicaland(e, self): - a, b = [self(o) for o in e.children] - assert a.arr.dtype == numpy.dtype("bool") - assert b.arr.dtype == numpy.dtype("bool") - result = Result.empty(a, b, bool) - for idx in numpy.ndindex(result.tshape): - result[idx] = a[a.filter(idx, result.fids)] and \ - b[b.filter(idx, result.fids)] - return result - - -@_evaluate.register(gem.LogicalOr) -def _evaluate_logicalor(e, self): - a, b = [self(o) for o in e.children] - assert a.arr.dtype == numpy.dtype("bool") - assert b.arr.dtype == numpy.dtype("bool") - result = Result.empty(a, b, dtype=bool) - for idx in numpy.ndindex(result.tshape): - result[idx] = a[a.filter(idx, result.fids)] or \ - b[b.filter(idx, result.fids)] - return result - - -@_evaluate.register(gem.Conditional) -def _evaluate_conditional(e, self): - cond, then, else_ = [self(o) for o in e.children] - assert cond.arr.dtype == numpy.dtype("bool") - result = Result.empty(cond, then, else_) - for idx in numpy.ndindex(result.tshape): - if cond[cond.filter(idx, result.fids)]: - result[idx] = then[then.filter(idx, result.fids)] - else: - result[idx] = else_[else_.filter(idx, result.fids)] - return result - - -@_evaluate.register(gem.Indexed) -def _evaluate_indexed(e, self): - """Indexing maps shape to free indices""" - val = self(e.children[0]) - fids = tuple(i for i in e.multiindex if isinstance(i, gem.Index)) - - idx = [] - # First pick up all the existing free indices - for _ in 
val.fids: - idx.append(slice(None)) - # Now grab the shape axes - for i in e.multiindex: - if isinstance(i, gem.Index): - # Free index, want entire extent - idx.append(slice(None)) - elif isinstance(i, gem.VariableIndex): - # Variable index, evaluate inner expression - result, = self(i.expression) - assert not result.tshape - idx.append(result[()]) - else: - # Fixed index, just pick that value - idx.append(i) - assert len(idx) == len(val.tshape) - return Result(val[idx], val.fids + fids) - - -@_evaluate.register(gem.ComponentTensor) -def _evaluate_componenttensor(e, self): - """Component tensors map free indices to shape.""" - val = self(e.children[0]) - axes = [] - fids = [] - # First grab the free indices that aren't bound - for a, f in enumerate(val.fids): - if f not in e.multiindex: - axes.append(a) - fids.append(f) - # Now the bound free indices - for i in e.multiindex: - axes.append(val.fids.index(i)) - # Now the existing shape - axes.extend(range(len(val.fshape), len(val.tshape))) - return Result(numpy.transpose(val.arr, axes=axes), - tuple(fids)) - - -@_evaluate.register(gem.IndexSum) -def _evaluate_indexsum(e, self): - """Index sums reduce over the given axis.""" - val = self(e.children[0]) - idx = tuple(map(val.fids.index, e.multiindex)) - rfids = tuple(fi for fi in val.fids if fi not in e.multiindex) - return Result(val.arr.sum(axis=idx), rfids) - - -@_evaluate.register(gem.ListTensor) -def _evaluate_listtensor(e, self): - """List tensors just turn into arrays.""" - ops = [self(o) for o in e.children] - tmp = Result.empty(*ops) - arrs = [numpy.broadcast_to(o.broadcast(tmp.fids), tmp.fshape) for o in ops] - arrs = numpy.moveaxis(numpy.asarray(arrs), 0, -1).reshape(tmp.fshape + e.shape) - return Result(arrs, tmp.fids) - - -@_evaluate.register(gem.Concatenate) -def _evaluate_concatenate(e, self): - """Concatenate nodes flatten and concatenate shapes.""" - ops = [self(o) for o in e.children] - fids = tuple(OrderedDict.fromkeys(itertools.chain(*(o.fids for o in ops)))) - fshape = tuple(i.extent for i in fids) - arrs = [] - for o in ops: - # Create temporary with correct shape - arr = numpy.empty(fshape + o.shape) - # Broadcast for extra free indices - arr[:] = o.broadcast(fids) - # Flatten shape - arr = arr.reshape(arr.shape[:arr.ndim-len(o.shape)] + (-1,)) - arrs.append(arr) - arrs = numpy.concatenate(arrs, axis=-1) - return Result(arrs, fids) - - -def evaluate(expressions, bindings=None): - """Evaluate some GEM expressions given variable bindings. - - :arg expressions: A single GEM expression, or iterable of - expressions to evaluate. - :kwarg bindings: An optional dict mapping GEM :class:`gem.Variable` - nodes to data. - :returns: a list of the evaluated expressions. - """ - try: - exprs = tuple(expressions) - except TypeError: - exprs = (expressions, ) - mapper = node.Memoizer(_evaluate) - mapper.bindings = bindings if bindings is not None else {} - return list(map(mapper, exprs)) diff --git a/gem/node.py b/gem/node.py deleted file mode 100644 index 31d99b9e..00000000 --- a/gem/node.py +++ /dev/null @@ -1,248 +0,0 @@ -"""Generic abstract node class and utility functions for creating -expression DAG languages.""" - -import collections - - -class Node(object): - """Abstract node class. - - Nodes are not meant to be modified. - - A node can reference other nodes; they are called children. A node - might contain data, or reference other objects which are not - themselves nodes; they are not called children. 
- - Both the children (if any) and non-child data (if any) are - required to create a node, or determine the equality of two - nodes. For reconstruction, however, only the new children are - necessary. - """ - - __slots__ = ('hash_value',) - - # Non-child data as the first arguments of the constructor. - # To be (potentially) overridden by derived node classes. - __front__ = () - - # Non-child data as the last arguments of the constructor. - # To be (potentially) overridden by derived node classes. - __back__ = () - - def _cons_args(self, children): - """Constructs an argument list for the constructor with - non-child data from 'self' and children from 'children'. - - Internally used utility function. - """ - front_args = [getattr(self, name) for name in self.__front__] - back_args = [getattr(self, name) for name in self.__back__] - - return tuple(front_args) + tuple(children) + tuple(back_args) - - def __reduce__(self): - # Gold version: - return type(self), self._cons_args(self.children) - - def reconstruct(self, *args): - """Reconstructs the node with new children from - 'args'. Non-child data are copied from 'self'. - - Returns a new object. - """ - return type(self)(*self._cons_args(args)) - - def __repr__(self): - cons_args = self._cons_args(self.children) - return "%s(%s)" % (type(self).__name__, ", ".join(map(repr, cons_args))) - - def __eq__(self, other): - """Provides equality testing with quick positive and negative - paths based on :func:`id` and :meth:`__hash__`. - """ - if self is other: - return True - elif hash(self) != hash(other): - return False - else: - return self.is_equal(other) - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - """Provides caching for hash values.""" - try: - return self.hash_value - except AttributeError: - self.hash_value = self.get_hash() - return self.hash_value - - def is_equal(self, other): - """Equality predicate. - - This is the method to potentially override in derived classes, - not :meth:`__eq__` or :meth:`__ne__`. - """ - if type(self) is not type(other): - return False - self_consargs = self._cons_args(self.children) - other_consargs = other._cons_args(other.children) - return self_consargs == other_consargs - - def get_hash(self): - """Hash function. - - This is the method to potentially override in derived classes, - not :meth:`__hash__`. - """ - return hash((type(self),) + self._cons_args(self.children)) - - -def pre_traversal(expression_dags): - """Pre-order traversal of the nodes of expression DAGs.""" - seen = set() - lifo = [] - # Some roots might be same, but they must be visited only once. - # Keep the original ordering of roots, for deterministic code - # generation. - for root in expression_dags: - if root not in seen: - seen.add(root) - lifo.append(root) - - while lifo: - node = lifo.pop() - yield node - for child in reversed(node.children): - if child not in seen: - seen.add(child) - lifo.append(child) - - -def post_traversal(expression_dags): - """Post-order traversal of the nodes of expression DAGs.""" - seen = set() - lifo = [] - # Some roots might be same, but they must be visited only once. - # Keep the original ordering of roots, for deterministic code - # generation. 
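A minimal, self-contained sketch of the same stack-based pre-order traversal for arbitrary DAG nodes; ``Expr`` is a hypothetical stand-in for GEM nodes, assumed only to expose a ``children`` tuple.

from collections import namedtuple

Expr = namedtuple("Expr", ["name", "children"])

def pre_order(roots):
    seen = set()
    lifo = []
    # Duplicate roots are visited only once; the original ordering is kept.
    for root in roots:
        if root not in seen:
            seen.add(root)
            lifo.append(root)
    while lifo:
        node = lifo.pop()
        yield node
        # Reversed so that children come out left-to-right.
        for child in reversed(node.children):
            if child not in seen:
                seen.add(child)
                lifo.append(child)

a = Expr("a", ())
b = Expr("b", (a,))
root = Expr("root", (b, a))  # "a" is shared, so it is yielded only once
print([n.name for n in pre_order((root,))])  # ['root', 'b', 'a']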
- for root in expression_dags: - if root not in seen: - seen.add(root) - lifo.append((root, list(root.children))) - - while lifo: - node, deps = lifo[-1] - for i, dep in enumerate(deps): - if dep is not None and dep not in seen: - lifo.append((dep, list(dep.children))) - deps[i] = None - break - else: - yield node - seen.add(node) - lifo.pop() - - -# Default to the more efficient pre-order traversal -traversal = pre_traversal - - -def collect_refcount(expression_dags): - """Collects reference counts for a multi-root expression DAG.""" - result = collections.Counter(expression_dags) - for node in traversal(expression_dags): - result.update(node.children) - return result - - -def noop_recursive(function): - """No-op wrapper for functions with overridable recursive calls. - - :arg function: a function with parameters (value, rec), where - ``rec`` is expected to be a function used for - recursive calls. - :returns: a function with working recursion and nothing fancy - """ - def recursive(node): - return function(node, recursive) - return recursive - - -def noop_recursive_arg(function): - """No-op wrapper for functions with overridable recursive calls - and an argument. - - :arg function: a function with parameters (value, rec, arg), where - ``rec`` is expected to be a function used for - recursive calls. - :returns: a function with working recursion and nothing fancy - """ - def recursive(node, arg): - return function(node, recursive, arg) - return recursive - - -class Memoizer(object): - """Caching wrapper for functions with overridable recursive calls. - The lifetime of the cache is the lifetime of the object instance. - - :arg function: a function with parameters (value, rec), where - ``rec`` is expected to be a function used for - recursive calls. - :returns: a function with working recursion and caching - """ - def __init__(self, function): - self.cache = {} - self.function = function - - def __call__(self, node): - try: - return self.cache[node] - except KeyError: - result = self.function(node, self) - self.cache[node] = result - return result - - -class MemoizerArg(object): - """Caching wrapper for functions with overridable recursive calls - and an argument. The lifetime of the cache is the lifetime of the - object instance. - - :arg function: a function with parameters (value, rec, arg), where - ``rec`` is expected to be a function used for - recursive calls. 
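The caching wrapper above in miniature, exercised on a toy recursion; ``fib`` is only a hypothetical stand-in for a GEM visitor, but the (value, rec) calling convention is the same.

class SimpleMemoizer:
    def __init__(self, function):
        self.cache = {}
        self.function = function

    def __call__(self, key):
        try:
            return self.cache[key]
        except KeyError:
            result = self.cache[key] = self.function(key, self)
            return result

def fib(n, rec):
    # Recursive calls go through ``rec``, so each subproblem is solved once.
    return n if n < 2 else rec(n - 1) + rec(n - 2)

print(SimpleMemoizer(fib)(30))  # 832040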
- :returns: a function with working recursion and caching - """ - def __init__(self, function): - self.cache = {} - self.function = function - - def __call__(self, node, arg): - cache_key = (node, arg) - try: - return self.cache[cache_key] - except KeyError: - result = self.function(node, self, arg) - self.cache[cache_key] = result - return result - - -def reuse_if_untouched(node, self): - """Reuse if untouched recipe""" - new_children = list(map(self, node.children)) - if all(nc == c for nc, c in zip(new_children, node.children)): - return node - else: - return node.reconstruct(*new_children) - - -def reuse_if_untouched_arg(node, self, arg): - """Reuse if touched recipe propagating an extra argument""" - new_children = [self(child, arg) for child in node.children] - if all(nc == c for nc, c in zip(new_children, node.children)): - return node - else: - return node.reconstruct(*new_children) diff --git a/gem/optimise.py b/gem/optimise.py deleted file mode 100644 index 3194e6ef..00000000 --- a/gem/optimise.py +++ /dev/null @@ -1,683 +0,0 @@ -"""A set of routines implementing various transformations on GEM -expressions.""" - -from collections import OrderedDict, defaultdict -from functools import singledispatch, partial, reduce -from itertools import combinations, permutations, zip_longest - -import numpy - -from gem.utils import groupby -from gem.node import (Memoizer, MemoizerArg, reuse_if_untouched, - reuse_if_untouched_arg, traversal) -from gem.gem import (Node, Failure, Identity, Literal, Zero, - Product, Sum, Comparison, Conditional, Division, - Index, VariableIndex, Indexed, FlexiblyIndexed, - IndexSum, ComponentTensor, ListTensor, Delta, - partial_indexed, one) - - -@singledispatch -def literal_rounding(node, self): - """Perform FFC rounding of FIAT tabulation matrices on the literals of - a GEM expression. - - :arg node: root of the expression - :arg self: function for recursive calls - """ - raise AssertionError("cannot handle type %s" % type(node)) - - -literal_rounding.register(Node)(reuse_if_untouched) - - -@literal_rounding.register(Literal) -def literal_rounding_literal(node, self): - table = node.array - epsilon = self.epsilon - # Mimic the rounding applied at COFFEE formatting, which in turn - # mimics FFC formatting. - one_decimal = numpy.asarray(numpy.round(table, 1)) - one_decimal[numpy.logical_not(one_decimal)] = 0 # no minus zeros - return Literal(numpy.where(abs(table - one_decimal) < epsilon, one_decimal, table)) - - -def ffc_rounding(expression, epsilon): - """Perform FFC rounding of FIAT tabulation matrices on the literals of - a GEM expression. - - :arg expression: GEM expression - :arg epsilon: tolerance limit for rounding - """ - mapper = Memoizer(literal_rounding) - mapper.epsilon = epsilon - return mapper(expression) - - -@singledispatch -def _replace_division(node, self): - """Replace division with multiplication - - :param node: root of expression - :param self: function for recursive calls - """ - raise AssertionError("cannot handle type %s" % type(node)) - - -_replace_division.register(Node)(reuse_if_untouched) - - -@_replace_division.register(Division) -def _replace_division_division(node, self): - a, b = node.children - return Product(self(a), Division(one, self(b))) - - -def replace_division(expressions): - """Replace divisions with multiplications in expressions""" - mapper = Memoizer(_replace_division) - return list(map(mapper, expressions)) - - -@singledispatch -def replace_indices(node, self, subst): - """Replace free indices in a GEM expression. 
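A standalone numpy reproduction of the rounding rule in ``literal_rounding_literal``; the sample values and the helper name ``round_literals`` are made up.

import numpy

def round_literals(table, epsilon=1e-12):
    # Entries within epsilon of a one-decimal value are snapped to that value;
    # minus zeros are cleaned up; everything else is left untouched.
    one_decimal = numpy.round(table, 1)
    one_decimal[numpy.logical_not(one_decimal)] = 0.0
    return numpy.where(abs(table - one_decimal) < epsilon, one_decimal, table)

table = numpy.array([0.499999999999999, 1.0000000000000002, -1e-16, 0.123456])
print(round_literals(table))  # [0.5  1.  0.  0.123456]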
- - :arg node: root of the expression - :arg self: function for recursive calls - :arg subst: tuple of pairs; each pair is a substitution - rule with a free index to replace and an index to - replace with. - """ - raise AssertionError("cannot handle type %s" % type(node)) - - -replace_indices.register(Node)(reuse_if_untouched_arg) - - -@replace_indices.register(Delta) -def replace_indices_delta(node, self, subst): - substitute = dict(subst) - i = substitute.get(node.i, node.i) - j = substitute.get(node.j, node.j) - if i == node.i and j == node.j: - return node - else: - return Delta(i, j) - - -@replace_indices.register(Indexed) -def replace_indices_indexed(node, self, subst): - child, = node.children - substitute = dict(subst) - multiindex = tuple(substitute.get(i, i) for i in node.multiindex) - if isinstance(child, ComponentTensor): - # Indexing into ComponentTensor - # Inline ComponentTensor and augment the substitution rules - substitute.update(zip(child.multiindex, multiindex)) - return self(child.children[0], tuple(sorted(substitute.items()))) - else: - # Replace indices - new_child = self(child, subst) - if new_child == child and multiindex == node.multiindex: - return node - else: - return Indexed(new_child, multiindex) - - -@replace_indices.register(FlexiblyIndexed) -def replace_indices_flexiblyindexed(node, self, subst): - child, = node.children - assert not child.free_indices - - substitute = dict(subst) - dim2idxs = tuple( - (offset, tuple((substitute.get(i, i), s) for i, s in idxs)) - for offset, idxs in node.dim2idxs - ) - - if dim2idxs == node.dim2idxs: - return node - else: - return FlexiblyIndexed(child, dim2idxs) - - -def filtered_replace_indices(node, self, subst): - """Wrapper for :func:`replace_indices`. At each call removes - substitution rules that do not apply.""" - filtered_subst = tuple((k, v) for k, v in subst if k in node.free_indices) - return replace_indices(node, self, filtered_subst) - - -def remove_componenttensors(expressions): - """Removes all ComponentTensors in multi-root expression DAG.""" - mapper = MemoizerArg(filtered_replace_indices) - return [mapper(expression, ()) for expression in expressions] - - -@singledispatch -def _constant_fold_zero(node, self): - raise AssertionError("cannot handle type %s" % type(node)) - - -_constant_fold_zero.register(Node)(reuse_if_untouched) - - -@_constant_fold_zero.register(Literal) -def _constant_fold_zero_literal(node, self): - if (node.array == 0).all(): - # All zeros, make symbolic zero - return Zero(node.shape) - else: - return node - - -@_constant_fold_zero.register(ListTensor) -def _constant_fold_zero_listtensor(node, self): - new_children = list(map(self, node.children)) - if all(isinstance(nc, Zero) for nc in new_children): - return Zero(node.shape) - elif all(nc == c for nc, c in zip(new_children, node.children)): - return node - else: - return node.reconstruct(*new_children) - - -def constant_fold_zero(exprs): - """Produce symbolic zeros from Literals - - :arg exprs: An iterable of gem expressions. - :returns: A list of gem expressions where any Literal containing - only zeros is replaced by symbolic Zero of the appropriate - shape. - - We need a separate path for ListTensor so that its `reconstruct` - method will not be called when the new children are `Zero()`s; - otherwise Literal `0`s would be reintroduced. 
- """ - mapper = Memoizer(_constant_fold_zero) - return [mapper(e) for e in exprs] - - -def _select_expression(expressions, index): - """Helper function to select an expression from a list of - expressions with an index. This function expect sanitised input, - one should normally call :py:func:`select_expression` instead. - - :arg expressions: a list of expressions - :arg index: an index (free, fixed or variable) - :returns: an expression - """ - expr = expressions[0] - if all(e == expr for e in expressions): - return expr - - types = set(map(type, expressions)) - if types <= {Indexed, Zero}: - multiindex, = set(e.multiindex for e in expressions if isinstance(e, Indexed)) - # Shape only determined by free indices - shape = tuple(i.extent for i in multiindex if isinstance(i, Index)) - - def child(expression): - if isinstance(expression, Indexed): - return expression.children[0] - elif isinstance(expression, Zero): - return Zero(shape) - return Indexed(_select_expression(list(map(child, expressions)), index), multiindex) - - if types <= {Literal, Zero, Failure}: - return partial_indexed(ListTensor(expressions), (index,)) - - if types <= {ComponentTensor, Zero}: - shape, = set(e.shape for e in expressions) - multiindex = tuple(Index(extent=d) for d in shape) - children = remove_componenttensors([Indexed(e, multiindex) for e in expressions]) - return ComponentTensor(_select_expression(children, index), multiindex) - - if len(types) == 1: - cls, = types - if cls.__front__ or cls.__back__: - raise NotImplementedError("How to factorise {} expressions?".format(cls.__name__)) - assert all(len(e.children) == len(expr.children) for e in expressions) - assert len(expr.children) > 0 - - return expr.reconstruct(*[_select_expression(nth_children, index) - for nth_children in zip(*[e.children - for e in expressions])]) - - raise NotImplementedError("No rule for factorising expressions of this kind.") - - -def select_expression(expressions, index): - """Select an expression from a list of expressions with an index. - Semantically equivalent to - - partial_indexed(ListTensor(expressions), (index,)) - - but has a much more optimised implementation. - - :arg expressions: a list of expressions of the same shape - :arg index: an index (free, fixed or variable) - :returns: an expression of the same shape as the given expressions - """ - # Check arguments - shape = expressions[0].shape - assert all(e.shape == shape for e in expressions) - - # Sanitise input expressions - alpha = tuple(Index() for s in shape) - exprs = remove_componenttensors([Indexed(e, alpha) for e in expressions]) - - # Factor the expressions recursively and convert result - selected = _select_expression(exprs, index) - return ComponentTensor(selected, alpha) - - -def delta_elimination(sum_indices, factors): - """IndexSum-Delta cancellation. 
- - :arg sum_indices: free indices for contractions - :arg factors: product factors - :returns: optimised (sum_indices, factors) - """ - sum_indices = list(sum_indices) # copy for modification - - def substitute(expression, from_, to_): - if from_ not in expression.free_indices: - return expression - elif isinstance(expression, Delta): - mapper = MemoizerArg(filtered_replace_indices) - return mapper(expression, ((from_, to_),)) - else: - return Indexed(ComponentTensor(expression, (from_,)), (to_,)) - - delta_queue = [(f, index) - for f in factors if isinstance(f, Delta) - for index in (f.i, f.j) if index in sum_indices] - while delta_queue: - delta, from_ = delta_queue[0] - to_, = list({delta.i, delta.j} - {from_}) - - sum_indices.remove(from_) - - factors = [substitute(f, from_, to_) for f in factors] - - delta_queue = [(f, index) - for f in factors if isinstance(f, Delta) - for index in (f.i, f.j) if index in sum_indices] - - return sum_indices, factors - - -def associate(operator, operands): - """Apply associativity rules to construct an operation-minimal expression tree. - - For best performance give factors that have different set of free indices. - - :arg operator: associative binary operator - :arg operands: list of operands - - :returns: (reduced expression, # of floating-point operations) - """ - if len(operands) > 32: - # O(N^3) algorithm - raise NotImplementedError("Not expected such a complicated expression!") - - def count(pair): - """Operation count to reduce a pair of GEM expressions""" - a, b = pair - extents = [i.extent for i in set().union(a.free_indices, b.free_indices)] - return numpy.prod(extents, dtype=int) - - flops = 0 - while len(operands) > 1: - # Greedy algorithm: choose a pair of operands that are the - # cheapest to reduce. - a, b = min(combinations(operands, 2), key=count) - flops += count((a, b)) - # Remove chosen factors, append their product - operands.remove(a) - operands.remove(b) - operands.append(operator(a, b)) - result, = operands - return result, flops - - -def sum_factorise(sum_indices, factors): - """Optimise a tensor product through sum factorisation. - - :arg sum_indices: free indices for contractions - :arg factors: product factors - :returns: optimised GEM expression - """ - if len(factors) == 0 and len(sum_indices) == 0: - # Empty product - return one - - if len(sum_indices) > 6: - raise NotImplementedError("Too many indices for sum factorisation!") - - # Form groups by free indices - groups = groupby(factors, key=lambda f: f.free_indices) - groups = [reduce(Product, terms) for _, terms in groups] - - # Sum factorisation - expression = None - best_flops = numpy.inf - - # Consider all orderings of contraction indices - for ordering in permutations(sum_indices): - terms = groups[:] - flops = 0 - # Apply contraction index by index - for sum_index in ordering: - # Select terms that need to be part of the contraction - contract = [t for t in terms if sum_index in t.free_indices] - deferred = [t for t in terms if sum_index not in t.free_indices] - - # Optimise associativity - product, flops_ = associate(Product, contract) - term = IndexSum(product, (sum_index,)) - flops += flops_ + numpy.prod([i.extent for i in product.free_indices], dtype=int) - - # Replace the contracted terms with the result of the - # contraction. - terms = deferred + [term] - - # If some contraction indices were independent, then we may - # still have several terms at this point. 
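A self-contained sketch of the greedy heuristic in ``associate``, representing each operand only by its set of free indices; the index names and extents below are made up.

from itertools import combinations
import numpy

extent = {"i": 10, "j": 10, "q": 20}  # hypothetical index extents

def pair_cost(a, b):
    # Operation count of a pointwise product over the union of free indices.
    return numpy.prod([extent[k] for k in a | b], dtype=int)

def greedy_associate(operands):
    operands = list(operands)
    flops = 0
    while len(operands) > 1:
        # Combine the cheapest pair first.
        a, b = min(combinations(operands, 2), key=lambda p: pair_cost(*p))
        flops += pair_cost(a, b)
        operands.remove(a)
        operands.remove(b)
        operands.append(a | b)
    return operands[0], flops

# A_{qi} * B_{qj} * c_q: greedy order costs 200 + 2000 = 2200 operations,
# whereas combining A and B first would cost 2000 + 2000 = 4000.
print(greedy_associate([frozenset("qi"), frozenset("qj"), frozenset("q")]))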
- expr, flops_ = associate(Product, terms) - flops += flops_ - - if flops < best_flops: - expression = expr - best_flops = flops - - return expression - - -def make_sum(summands): - """Constructs an operation-minimal sum of GEM expressions.""" - groups = groupby(summands, key=lambda f: f.free_indices) - summands = [reduce(Sum, terms) for _, terms in groups] - result, flops = associate(Sum, summands) - return result - - -def make_product(factors, sum_indices=()): - """Constructs an operation-minimal (tensor) product of GEM expressions.""" - return sum_factorise(sum_indices, factors) - - -def make_rename_map(): - """Creates an rename map for reusing the same index renames.""" - return defaultdict(Index) - - -def make_renamer(rename_map): - r"""Creates a function for renaming indices when expanding products of - IndexSums, i.e. applying to following rule: - - (\sum_i a_i)*(\sum_i b_i) ===> \sum_{i,i'} a_i*b_{i'} - - :arg rename_map: An rename map for renaming indices the same way - as functions returned by other calls of this - function. - :returns: A function that takes an iterable of indices to rename, - and returns (renamed indices, applier), where applier is - a function that remap the free indices of GEM - expressions from the old to the new indices. - """ - def _renamer(rename_map, current_set, incoming): - renamed = [] - renames = [] - for i in incoming: - j = i - while j in current_set: - j = rename_map[j] - current_set.add(j) - renamed.append(j) - if i != j: - renames.append((i, j)) - - if renames: - def applier(expr): - pairs = [(i, j) for i, j in renames if i in expr.free_indices] - if pairs: - current, renamed = zip(*pairs) - return Indexed(ComponentTensor(expr, current), renamed) - else: - return expr - else: - applier = lambda expr: expr - - return tuple(renamed), applier - return partial(_renamer, rename_map, set()) - - -def traverse_product(expression, stop_at=None, rename_map=None): - """Traverses a product tree and collects factors, also descending into - tensor contractions (IndexSum). The nominators of divisions are - also broken up, but not the denominators. - - :arg expression: a GEM expression - :arg stop_at: Optional predicate on GEM expressions. If specified - and returns true for some subexpression, that - subexpression is not broken into further factors - even if it is a product-like expression. - :arg rename_map: an rename map for consistent index renaming - :returns: (sum_indices, terms) - - sum_indices: list of indices to sum over - - terms: list of product terms - """ - if rename_map is None: - rename_map = make_rename_map() - renamer = make_renamer(rename_map) - - sum_indices = [] - terms = [] - - stack = [expression] - while stack: - expr = stack.pop() - if stop_at is not None and stop_at(expr): - terms.append(expr) - elif isinstance(expr, IndexSum): - indices, applier = renamer(expr.multiindex) - sum_indices.extend(indices) - stack.extend(remove_componenttensors(map(applier, expr.children))) - elif isinstance(expr, Product): - stack.extend(reversed(expr.children)) - elif isinstance(expr, Division): - # Break up products in the dividend, but not in divisor. - dividend, divisor = expr.children - if dividend == one: - terms.append(expr) - else: - stack.append(Division(one, divisor)) - stack.append(dividend) - else: - terms.append(expr) - - return sum_indices, terms - - -def traverse_sum(expression, stop_at=None): - """Traverses a summation tree and collects summands. - - :arg expression: a GEM expression - :arg stop_at: Optional predicate on GEM expressions. 
If specified - and returns true for some subexpression, that - subexpression is not broken into further summands - even if it is an addition. - :returns: list of summand expressions - """ - stack = [expression] - result = [] - while stack: - expr = stack.pop() - if stop_at is not None and stop_at(expr): - result.append(expr) - elif isinstance(expr, Sum): - stack.extend(reversed(expr.children)) - else: - result.append(expr) - return result - - -def contraction(expression, ignore=None): - """Optimise the contractions of the tensor product at the root of - the expression, including: - - - IndexSum-Delta cancellation - - Sum factorisation - - :arg ignore: Optional set of indices to ignore when applying sum - factorisation (otherwise all summation indices will be - considered). Use this if your expression has many contraction - indices. - - This routine was designed with finite element coefficient - evaluation in mind. - """ - # Eliminate annoying ComponentTensors - expression, = remove_componenttensors([expression]) - - # Flatten product tree, eliminate deltas, sum factorise - def rebuild(expression): - sum_indices, factors = delta_elimination(*traverse_product(expression)) - factors = remove_componenttensors(factors) - if ignore is not None: - # TODO: This is a really blunt instrument and one might - # plausibly want the ignored indices to be contracted on - # the inside rather than the outside. - extra = tuple(i for i in sum_indices if i in ignore) - to_factor = tuple(i for i in sum_indices if i not in ignore) - return IndexSum(sum_factorise(to_factor, factors), extra) - else: - return sum_factorise(sum_indices, factors) - - # Sometimes the value shape is composed as a ListTensor, which - # could get in the way of decomposing factors. In particular, - # this is the case for H(div) and H(curl) conforming tensor - # product elements. So if ListTensors are used, they are pulled - # out to be outermost, so we can straightforwardly factorise each - # of its entries. 
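A worked numpy example (with made-up sizes) of the operation-count saving that sum factorisation targets on tensor-product expressions.

import numpy

n = 8                          # basis functions / quadrature points per direction
c = numpy.random.rand(n, n)    # coefficients c[i, j]
phi = numpy.random.rand(n, n)  # phi[i, q1]: 1D tabulation in the first direction
psi = numpy.random.rand(n, n)  # psi[j, q2]: 1D tabulation in the second direction

# Direct evaluation of u[q1, q2] = sum_{i,j} c[i, j] * phi[i, q1] * psi[j, q2]
# is conceptually an O(n^4) loop nest.
u_direct = numpy.einsum("ij,iq,jr->qr", c, phi, psi)

# Sum-factorised: contract i first, then j, via an O(n^3) intermediate.
tmp = numpy.einsum("ij,iq->qj", c, phi)
u_factorised = numpy.einsum("qj,jr->qr", tmp, psi)

print(numpy.allclose(u_direct, u_factorised))  # True

The gap between the two costs widens further in three dimensions, which is why the ordering of contraction indices is searched exhaustively above.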
- lt_fis = OrderedDict() # ListTensor free indices - for node in traversal((expression,)): - if isinstance(node, Indexed): - child, = node.children - if isinstance(child, ListTensor): - lt_fis.update(zip_longest(node.multiindex, ())) - lt_fis = tuple(index for index in lt_fis if index in expression.free_indices) - - if lt_fis: - # Rebuild each split component - tensor = ComponentTensor(expression, lt_fis) - entries = [Indexed(tensor, zeta) for zeta in numpy.ndindex(tensor.shape)] - entries = remove_componenttensors(entries) - return Indexed(ListTensor( - numpy.array(list(map(rebuild, entries))).reshape(tensor.shape) - ), lt_fis) - else: - # Rebuild whole expression at once - return rebuild(expression) - - -@singledispatch -def _replace_delta(node, self): - raise AssertionError("cannot handle type %s" % type(node)) - - -_replace_delta.register(Node)(reuse_if_untouched) - - -@_replace_delta.register(Delta) -def _replace_delta_delta(node, self): - i, j = node.i, node.j - - if isinstance(i, Index) or isinstance(j, Index): - if isinstance(i, Index) and isinstance(j, Index): - assert i.extent == j.extent - if isinstance(i, Index): - assert i.extent is not None - size = i.extent - if isinstance(j, Index): - assert j.extent is not None - size = j.extent - return Indexed(Identity(size), (i, j)) - else: - def expression(index): - if isinstance(index, int): - return Literal(index) - elif isinstance(index, VariableIndex): - return index.expression - else: - raise ValueError("Cannot convert running index to expression.") - e_i = expression(i) - e_j = expression(j) - return Conditional(Comparison("==", e_i, e_j), one, Zero()) - - -def replace_delta(expressions): - """Lowers all Deltas in a multi-root expression DAG.""" - mapper = Memoizer(_replace_delta) - return list(map(mapper, expressions)) - - -@singledispatch -def _unroll_indexsum(node, self): - """Unrolls IndexSums below a certain extent. - - :arg node: root of the expression - :arg self: function for recursive calls - """ - raise AssertionError("cannot handle type %s" % type(node)) - - -_unroll_indexsum.register(Node)(reuse_if_untouched) - - -@_unroll_indexsum.register(IndexSum) # noqa -def _(node, self): - unroll = tuple(filter(self.predicate, node.multiindex)) - if unroll: - # Unrolling - summand = self(node.children[0]) - shape = tuple(index.extent for index in unroll) - unrolled = reduce(Sum, - (Indexed(ComponentTensor(summand, unroll), alpha) - for alpha in numpy.ndindex(shape)), - Zero()) - return IndexSum(unrolled, tuple(index for index in node.multiindex - if index not in unroll)) - else: - return reuse_if_untouched(node, self) - - -def unroll_indexsum(expressions, predicate): - """Unrolls IndexSums below a specified extent. 
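A numpy analogue of what unrolling an ``IndexSum`` means; the array and the extent of the unrolled index are arbitrary.

import numpy

A = numpy.random.rand(5, 3)

# Rolled form: a reduction loop over an index of extent 3.
rolled = A.sum(axis=1)

# Unrolled form: the same sum written out term by term, as a chain of additions.
unrolled = A[:, 0] + A[:, 1] + A[:, 2]

print(numpy.allclose(rolled, unrolled))  # True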
- - :arg expressions: list of expression DAGs - :arg predicate: a predicate function on :py:class:`Index` objects - that tells whether to unroll a particular index - :returns: list of expression DAGs with some unrolled IndexSums - """ - mapper = Memoizer(_unroll_indexsum) - mapper.predicate = predicate - return list(map(mapper, expressions)) - - -def aggressive_unroll(expression): - """Aggressively unrolls all loop structures.""" - # Unroll expression shape - if expression.shape: - tensor = numpy.empty(expression.shape, dtype=object) - for alpha in numpy.ndindex(expression.shape): - tensor[alpha] = Indexed(expression, alpha) - expression, = remove_componenttensors((ListTensor(tensor),)) - - # Unroll summation - expression, = unroll_indexsum((expression,), predicate=lambda index: True) - expression, = remove_componenttensors((expression,)) - return expression diff --git a/gem/pprint.py b/gem/pprint.py deleted file mode 100644 index 9c245123..00000000 --- a/gem/pprint.py +++ /dev/null @@ -1,209 +0,0 @@ -"""Pretty-printing GEM expressions.""" -from collections import defaultdict -import itertools - -from functools import singledispatch - -from gem import gem -from gem.node import collect_refcount, post_traversal - - -class Context(object): - def __init__(self): - expr_counter = itertools.count(1) - self.expr_name = defaultdict(lambda: "${}".format(next(expr_counter))) - index_counter = itertools.count(1) - self.index_name = defaultdict(lambda: "i_{}".format(next(index_counter))) - self.index_names = set() - - def force_expression(self, expr): - assert isinstance(expr, gem.Node) - return self.expr_name[expr] - - def expression(self, expr): - assert isinstance(expr, gem.Node) - return self.expr_name.get(expr) - - def index(self, index): - assert isinstance(index, gem.Index) - if index.name is None: - name = self.index_name[index] - elif index.name not in self.index_names: - name = index.name - self.index_name[index] = name - else: - name_ = index.name - for i in itertools.count(1): - name = "{}~{}".format(name_, i) - if name not in self.index_names: - break - self.index_names.add(name) - return name - - -global_context = Context() - - -def pprint(expression_dags, context=global_context): - refcount = collect_refcount(expression_dags) - - def force(node): - if isinstance(node, gem.Variable): - return False - if node.shape: - return True - if isinstance(node, (gem.Constant, gem.Indexed, gem.FlexiblyIndexed)): - return False - return refcount[node] > 1 - - for node in post_traversal(expression_dags): - if force(node): - context.force_expression(node) - - name = context.expression(node) - if name is not None: - print(make_decl(node, name, context), '=', to_str(node, context, top=True)) - - for i, root in enumerate(expression_dags): - name = "#%d" % (i + 1) - print(make_decl(root, name, context), '=', to_str(root, context)) - - -def make_decl(node, name, ctx): - result = name - if node.shape: - result += '[' + ','.join(map(repr, node.shape)) + ']' - if node.free_indices: - result += '{' + ','.join(map(ctx.index, node.free_indices)) + '}' - return result - - -def to_str(expr, ctx, prec=None, top=False): - if not top and ctx.expression(expr): - result = ctx.expression(expr) - if expr.free_indices: - result += '{' + ','.join(map(ctx.index, expr.free_indices)) + '}' - return result - else: - return _to_str(expr, ctx, prec=prec) - - -@singledispatch -def _to_str(node, ctx, prec): - raise AssertionError("GEM node expected") - - -@_to_str.register(gem.Node) -def _to_str_node(node, ctx, prec): - front_args 
= [repr(getattr(node, name)) for name in node.__front__] - back_args = [repr(getattr(node, name)) for name in node.__back__] - children = [to_str(child, ctx) for child in node.children] - return "%s(%s)" % (type(node).__name__, ", ".join(front_args + children + back_args)) - - -@_to_str.register(gem.Zero) -def _to_str_zero(node, ctx, prec): - assert not node.shape - return "%g" % node.value - - -@_to_str.register(gem.Literal) -def _to_str_literal(node, ctx, prec): - if node.shape: - return repr(node.array.tolist()) - else: - return "%g" % node.value - - -@_to_str.register(gem.Variable) -def _to_str_variable(node, ctx, prec): - return node.name - - -@_to_str.register(gem.ListTensor) -def _to_str_listtensor(node, ctx, prec): - def recurse_rank(array): - if len(array.shape) > 1: - return '[' + ', '.join(map(recurse_rank, array)) + ']' - else: - return '[' + ', '.join(to_str(item, ctx) for item in array) + ']' - - return recurse_rank(node.array) - - -@_to_str.register(gem.Indexed) -def _to_str_indexed(node, ctx, prec): - child, = node.children - result = to_str(child, ctx) - dimensions = [] - for index in node.multiindex: - if isinstance(index, gem.Index): - dimensions.append(ctx.index(index)) - elif isinstance(index, int): - dimensions.append(str(index)) - else: - dimensions.append(to_str(index.expression, ctx)) - result += '[' + ','.join(dimensions) + ']' - return result - - -@_to_str.register(gem.FlexiblyIndexed) -def _to_str_flexiblyindexed(node, ctx, prec): - child, = node.children - result = to_str(child, ctx) - dimensions = [] - for offset, idxs in node.dim2idxs: - parts = [] - if offset: - parts.append(str(offset)) - for index, stride in idxs: - index_name = ctx.index(index) - assert stride - if stride == 1: - parts.append(index_name) - else: - parts.append(index_name + "*" + str(stride)) - if parts: - dimensions.append(' + '.join(parts)) - else: - dimensions.append('0') - if dimensions: - result += '[' + ','.join(dimensions) + ']' - return result - - -@_to_str.register(gem.IndexSum) -def _to_str_indexsum(node, ctx, prec): - result = 'Sum_{' + ','.join(map(ctx.index, node.multiindex)) + '} ' + to_str(node.children[0], ctx, prec=2) - if prec is not None and prec > 2: - result = '({})'.format(result) - return result - - -@_to_str.register(gem.ComponentTensor) -def _to_str_componenttensor(node, ctx, prec): - return to_str(node.children[0], ctx) + '|' + ','.join(ctx.index(i) for i in node.multiindex) - - -@_to_str.register(gem.Sum) -def _to_str_sum(node, ctx, prec): - children = [to_str(child, ctx, prec=1) for child in node.children] - result = " + ".join(children) - if prec is not None and prec > 1: - result = "({})".format(result) - return result - - -@_to_str.register(gem.Product) -def _to_str_product(node, ctx, prec): - children = [to_str(child, ctx, prec=3) for child in node.children] - result = "*".join(children) - if prec is not None and prec > 3: - result = "({})".format(result) - return result - - -@_to_str.register(gem.MathFunction) -def _to_str_mathfunction(node, ctx, prec): - child, = node.children - return node.name + "(" + to_str(child, ctx) + ")" diff --git a/gem/refactorise.py b/gem/refactorise.py deleted file mode 100644 index 2ca6e4cc..00000000 --- a/gem/refactorise.py +++ /dev/null @@ -1,304 +0,0 @@ -"""Data structures and algorithms for generic expansion and -refactorisation.""" - -from collections import Counter, OrderedDict, defaultdict, namedtuple -from functools import singledispatch -from itertools import product -from sys import intern - -from gem.node import 
Memoizer, traversal -from gem.gem import (Node, Conditional, Zero, Product, Sum, Indexed, - ListTensor, one, MathFunction) -from gem.optimise import (remove_componenttensors, sum_factorise, - traverse_product, traverse_sum, unroll_indexsum, - make_rename_map, make_renamer) - - -# Refactorisation labels - -ATOMIC = intern('atomic') -"""Label: the expression need not be broken up into smaller parts""" - -COMPOUND = intern('compound') -"""Label: the expression must be broken up into smaller parts""" - -OTHER = intern('other') -"""Label: the expression is irrelevant with regards to refactorisation""" - - -Monomial = namedtuple('Monomial', ['sum_indices', 'atomics', 'rest']) -"""Monomial type, representation of a tensor product with some -distinguished factors (called atomics). - -- sum_indices: indices to sum over -- atomics: tuple of expressions classified as ATOMIC -- rest: a single expression classified as OTHER - -A :py:class:`Monomial` is a structured description of the expression: - -.. code-block:: python - - IndexSum(reduce(Product, atomics, rest), sum_indices) - -""" - - -class MonomialSum(object): - """Represents a sum of :py:class:`Monomial`s. - - The set of :py:class:`Monomial` summands are represented as a - mapping from a pair of unordered ``sum_indices`` and unordered - ``atomics`` to a ``rest`` GEM expression. This representation - makes it easier to merge similar monomials. - """ - def __init__(self): - # (unordered sum_indices, unordered atomics) -> rest - self.monomials = defaultdict(Zero) - - # We shall retain ordering for deterministic code generation: - # - # (unordered sum_indices, unordered atomics) -> - # (ordered sum_indices, ordered atomics) - self.ordering = OrderedDict() - - def __len__(self): - return len(self.ordering) - - def add(self, sum_indices, atomics, rest): - """Updates the :py:class:`MonomialSum` adding a new monomial.""" - sum_indices = tuple(sum_indices) - sum_indices_set = frozenset(sum_indices) - # Sum indices cannot have duplicates - assert len(sum_indices) == len(sum_indices_set) - - atomics = tuple(atomics) - atomics_set = frozenset(Counter(atomics).items()) - - assert isinstance(rest, Node) - - key = (sum_indices_set, atomics_set) - self.monomials[key] = Sum(self.monomials[key], rest) - self.ordering.setdefault(key, (sum_indices, atomics)) - - def __iter__(self): - """Iteration yields :py:class:`Monomial` objects""" - for key, (sum_indices, atomics) in self.ordering.items(): - rest = self.monomials[key] - yield Monomial(sum_indices, atomics, rest) - - @staticmethod - def sum(*args): - """Sum of multiple :py:class:`MonomialSum`s""" - result = MonomialSum() - for arg in args: - assert isinstance(arg, MonomialSum) - # Optimised implementation: no need to decompose and - # reconstruct key. 
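A toy stand-in for the ``MonomialSum`` bookkeeping, using strings for atomics and plain numbers for the rest factors; the key construction mirrors ``MonomialSum.add`` above.

from collections import Counter, defaultdict

monomials = defaultdict(float)  # (sum indices, atomics) -> accumulated rest

def add(sum_indices, atomics, rest):
    # Unordered key, as in MonomialSum.add: monomials that differ only in the
    # ordering of their atomics are merged.
    key = (frozenset(sum_indices), frozenset(Counter(atomics).items()))
    monomials[key] += rest

add(("i",), ("A_i", "B_i"), 2.0)
add(("i",), ("B_i", "A_i"), 3.0)  # same monomial up to ordering: merged
add((), ("C",), 1.0)

print(len(monomials))  # 2 distinct monomials
key = (frozenset({"i"}), frozenset(Counter(("A_i", "B_i")).items()))
print(monomials[key])  # 5.0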
- for key, rest in arg.monomials.items(): - result.monomials[key] = Sum(result.monomials[key], rest) - for key, value in arg.ordering.items(): - result.ordering.setdefault(key, value) - return result - - @staticmethod - def product(*args, **kwargs): - """Product of multiple :py:class:`MonomialSum`s""" - rename_map = kwargs.pop('rename_map', None) - if rename_map is None: - rename_map = make_rename_map() - if kwargs: - raise ValueError("Unrecognised keyword argument: " + kwargs.pop()) - - result = MonomialSum() - for monomials in product(*args): - renamer = make_renamer(rename_map) - sum_indices = [] - atomics = [] - rest = one - for s, a, r in monomials: - s_, applier = renamer(s) - sum_indices.extend(s_) - atomics.extend(map(applier, a)) - rest = Product(applier(r), rest) - result.add(sum_indices, atomics, rest) - return result - - -class FactorisationError(Exception): - """Raised when factorisation fails to achieve some desired form.""" - pass - - -@singledispatch -def _collect_monomials(expression, self): - """Refactorises an expression into a sum-of-products form, using - distributivity rules (i.e. a*(b + c) -> a*b + a*c). Expansion - proceeds until all "compound" expressions are broken up. - - :arg expression: a GEM expression to refactorise - :arg self: function for recursive calls - - :returns: :py:class:`MonomialSum` - - :raises FactorisationError: Failed to break up some "compound" - expressions with expansion. - """ - # Phase 1: Collect and categorise product terms - def stop_at(expr): - # Break up compounds only - return self.classifier(expr) != COMPOUND - common_indices, terms = traverse_product(expression, stop_at=stop_at) - common_indices = tuple(common_indices) - - common_atomics = [] - common_others = [] - compounds = [] - for term in terms: - label = self.classifier(term) - if label == ATOMIC: - common_atomics.append(term) - elif label == COMPOUND: - compounds.append(term) - elif label == OTHER: - common_others.append(term) - else: - raise ValueError("Classifier returned illegal value.") - common_atomics = tuple(common_atomics) - - # Phase 2: Attempt to break up compound terms into summands - sums = [] - for expr in compounds: - summands = traverse_sum(expr, stop_at=stop_at) - if len(summands) <= 1 and not isinstance(expr, (Conditional, MathFunction)): - # Compound term is not an addition, avoid infinite - # recursion and fail gracefully raising an exception. - raise FactorisationError(expr) - # Recurse into each summand, concatenate their results - sums.append(MonomialSum.sum(*map(self, summands))) - - # Phase 3: Expansion - # - # Each element of ``sums`` is a MonomialSum. Expansion produces a - # series (representing a sum) of products of monomials. 
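A sketch of the expansion step using ``itertools.product``, with hypothetical summand names; this is the distributivity rule a*(b + c) -> a*b + a*c applied mechanically to every compound factor.

from itertools import product

# Each compound factor has already been broken into its summands.
factor_summands = [["a1", "a2"], ["b1", "b2"]]
expanded = ["*".join(choice) for choice in product(*factor_summands)]
print(expanded)  # ['a1*b1', 'a1*b2', 'a2*b1', 'a2*b2']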
- result = MonomialSum() - for s, a, r in MonomialSum.product(*sums, rename_map=self.rename_map): - renamer = make_renamer(self.rename_map) - renamer(common_indices) # update current_set - s_, applier = renamer(s) - - all_indices = common_indices + s_ - atomics = common_atomics + tuple(map(applier, a)) - - # All free indices that appear in atomic terms - atomic_indices = set().union(*[atomic.free_indices - for atomic in atomics]) - - # Sum indices that appear in atomic terms - # (will go to the result :py:class:`Monomial`) - sum_indices = tuple(index for index in all_indices - if index in atomic_indices) - - # Sum indices that do not appear in atomic terms - # (can factorise them over atomic terms immediately) - rest_indices = tuple(index for index in all_indices - if index not in atomic_indices) - - # Not really sum factorisation, but rather just an optimised - # way of building a product. - rest = sum_factorise(rest_indices, common_others + [applier(r)]) - - result.add(sum_indices, atomics, rest) - return result - - -@_collect_monomials.register(MathFunction) -def _collect_monomials_mathfunction(expression, self): - name = expression.name - if name in {"conj", "real", "imag"}: - # These are allowed to be applied to arguments, and hence must - # be dealt with specially. Just push the function onto each - # entry in the monomialsum of the child. - # NOTE: This presently assumes that the "atomics" part of a - # MonomialSum are real. This is true for the coffee, tensor, - # spectral modes: the atomics are indexed tabulation matrices - # (which are guaranteed real). - # If the classifier puts (potentially) complex expressions in - # atomics, then this code needs fixed. - child_ms, = map(self, expression.children) - result = MonomialSum() - for k, v in child_ms.monomials.items(): - result.monomials[k] = MathFunction(name, v) - result.ordering = child_ms.ordering.copy() - return result - else: - return _collect_monomials.dispatch(MathFunction.mro()[1])(expression, self) - - -@_collect_monomials.register(Conditional) -def _collect_monomials_conditional(expression, self): - """Refactorises a conditional expression into a sum-of-products form, - pulling only "atomics" out of conditional expressions. - - :arg expression: a GEM expression to refactorise - :arg self: function for recursive calls - - :returns: :py:class:`MonomialSum` - """ - condition, then, else_ = expression.children - # Recursively refactorise both branches to `MonomialSum`s - then_ms = self(then) - else_ms = self(else_) - - result = MonomialSum() - # For each set of atomics, create a new Conditional node. Atomics - # are considered safe to be pulled out of conditionals, but other - # expressions remain inside conditional branches. - zero = Zero() - for k in then_ms.monomials.keys() | else_ms.monomials.keys(): - _then = then_ms.monomials.get(k, zero) - _else = else_ms.monomials.get(k, zero) - result.monomials[k] = Conditional(condition, _then, _else) - - # Construct a deterministic ordering - result.ordering = then_ms.ordering.copy() - for k, v in else_ms.ordering.items(): - result.ordering.setdefault(k, v) - return result - - -def collect_monomials(expressions, classifier): - """Refactorises expressions into a sum-of-products form, using - distributivity rules (i.e. a*(b + c) -> a*b + a*c). Expansion - proceeds until all "compound" expressions are broken up. - - :arg expressions: GEM expressions to refactorise - :arg classifier: a function that can classify any GEM expression - as ``ATOMIC``, ``COMPOUND``, or ``OTHER``. 
This - classification drives the factorisation. - - :returns: list of :py:class:`MonomialSum`s - - :raises FactorisationError: Failed to break up some "compound" - expressions with expansion. - """ - # Get ComponentTensors out of the way - expressions = remove_componenttensors(expressions) - - # Get ListTensors out of the way - must_unroll = [] # indices to unroll - for node in traversal(expressions): - if isinstance(node, Indexed): - child, = node.children - if isinstance(child, ListTensor) and classifier(node) == COMPOUND: - must_unroll.extend(node.multiindex) - if must_unroll: - must_unroll = set(must_unroll) - expressions = unroll_indexsum(expressions, - predicate=lambda i: i in must_unroll) - expressions = remove_componenttensors(expressions) - - # Finally, refactorise expressions - mapper = Memoizer(_collect_monomials) - mapper.classifier = classifier - mapper.rename_map = make_rename_map() - return list(map(mapper, expressions)) diff --git a/gem/scheduling.py b/gem/scheduling.py deleted file mode 100644 index 831ee048..00000000 --- a/gem/scheduling.py +++ /dev/null @@ -1,197 +0,0 @@ -"""Schedules operations to evaluate a multi-root expression DAG, -forming an ordered list of Impero terminals.""" - -import collections -import functools - -from gem import gem, impero -from gem.node import collect_refcount - - -class OrderedDefaultDict(collections.OrderedDict): - """A dictionary that provides a default value and ordered iteration. - - :arg factory: The callable used to create the default value. - - See :class:`collections.OrderedDict` for description of the - remaining arguments. - """ - def __init__(self, factory, *args, **kwargs): - self.factory = factory - super(OrderedDefaultDict, self).__init__(*args, **kwargs) - - def __missing__(self, key): - val = self[key] = self.factory() - return val - - -class ReferenceStager(object): - """Provides staging for nodes in reference counted expression - DAGs. A callback function is called once the reference count is - exhausted.""" - - def __init__(self, reference_count, callback): - """Initialises a ReferenceStager. - - :arg reference_count: initial reference counts for all - expected nodes - :arg callback: function to call on each node when - reference count is exhausted - """ - self.waiting = reference_count.copy() - self.callback = callback - - def decref(self, o): - """Decreases the reference count of a node, and possibly - triggering a callback (when the reference count drops to - zero).""" - assert 1 <= self.waiting[o] - - self.waiting[o] -= 1 - if self.waiting[o] == 0: - self.callback(o) - - def empty(self): - """All reference counts exhausted?""" - return not any(self.waiting.values()) - - -class Queue(object): - """Special queue for operation scheduling. GEM / Impero nodes are - inserted when they are ready to be scheduled, i.e. any operation - which depends on the operation to be inserted must have been - scheduled already. This class implements a heuristic for ordering - operations within the constraints in a way which aims to achieve - maximum loop fusion to minimise the size of temporaries which need - to be introduced. - """ - def __init__(self, callback): - """Initialises a Queue. - - :arg callback: function called on each element "popped" from the queue - """ - # Must have deterministic iteration over the queue - self.queue = OrderedDefaultDict(list) - self.callback = callback - - def insert(self, indices, elem): - """Insert element into queue. 
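A minimal sketch of the reference-counting idea behind ``ReferenceStager``; the temporary names and use counts are made up.

ready = []
waiting = {"t0": 2, "t1": 1}  # temporaries and the number of uses still pending

def decref(name):
    # Fires the callback exactly once, when the last use has been handled.
    assert waiting[name] >= 1
    waiting[name] -= 1
    if waiting[name] == 0:
        ready.append(name)

decref("t0"); decref("t1"); decref("t0")
print(ready)  # ['t1', 't0']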
- - :arg indices: loop indices used by the scheduling heuristic - :arg elem: element to be scheduled - """ - self.queue[indices].append(elem) - - def process(self): - """Pops elements from the queue and calls the callback - function on them until the queue is empty. The callback - function can insert further elements into the queue. - """ - indices = () - while self.queue: - # Find innermost non-empty outer loop - while indices not in (i[:len(indices)] for i in self.queue.keys()): - indices = indices[:-1] - - # Pick a loop - for i in self.queue.keys(): - if i[:len(indices)] == indices: - indices = i - break - - while self.queue[indices]: - self.callback(self.queue[indices].pop()) - del self.queue[indices] - - -def handle(ops, push, decref, node): - """Helper function for scheduling""" - if isinstance(node, gem.Variable): - # Declared in the kernel header - pass - elif isinstance(node, gem.Constant): - # Constant literals inlined, unless tensor-valued - if node.shape: - ops.append(impero.Evaluate(node)) - elif isinstance(node, gem.Zero): # should rarely happen - assert not node.shape - elif isinstance(node, (gem.Indexed, gem.FlexiblyIndexed)): - # Indexing always inlined - decref(node.children[0]) - elif isinstance(node, gem.IndexSum): - ops.append(impero.Noop(node)) - push(impero.Accumulate(node)) - elif isinstance(node, gem.Node): - ops.append(impero.Evaluate(node)) - for child in node.children: - decref(child) - elif isinstance(node, impero.Initialise): - ops.append(node) - elif isinstance(node, impero.Accumulate): - ops.append(node) - push(impero.Initialise(node.indexsum)) - decref(node.indexsum.children[0]) - elif isinstance(node, impero.Return): - ops.append(node) - decref(node.expression) - elif isinstance(node, impero.ReturnAccumulate): - ops.append(node) - decref(node.indexsum.children[0]) - else: - raise AssertionError("no handler for node type %s" % type(node)) - - -def emit_operations(assignments, get_indices, emit_return_accumulate=True): - """Makes an ordering of operations to evaluate a multi-root - expression DAG. - - :arg assignments: Iterable of (variable, expression) pairs. - The value of expression is written into variable - upon execution. - :arg get_indices: mapping from GEM nodes to an ordering of free - indices - :arg emit_return_accumulate: emit ReturnAccumulate nodes? Set to - False if the output variables are not guaranteed - zero on entry to the kernel. 
- :returns: list of Impero terminals correctly ordered to evaluate - the assignments - """ - # Prepare reference counts - refcount = collect_refcount([e for v, e in assignments]) - - # Stage return operations - staging = [] - for variable, expression in assignments: - if emit_return_accumulate and \ - refcount[expression] == 1 and isinstance(expression, gem.IndexSum) \ - and set(variable.free_indices) == set(expression.free_indices): - staging.append(impero.ReturnAccumulate(variable, expression)) - refcount[expression] -= 1 - else: - staging.append(impero.Return(variable, expression)) - - # Prepare data structures - def push_node(node): - queue.insert(get_indices(node), node) - - def push_op(op): - queue.insert(op.loop_shape(get_indices), op) - - ops = [] - - stager = ReferenceStager(refcount, push_node) - queue = Queue(functools.partial(handle, ops, push_op, stager.decref)) - - # Enqueue return operations - for op in staging: - push_op(op) - - # Schedule operations - queue.process() - - # Assert that nothing left unprocessed - assert stager.empty() - - # Return - ops.reverse() - return ops diff --git a/gem/unconcatenate.py b/gem/unconcatenate.py deleted file mode 100644 index ce6e30b3..00000000 --- a/gem/unconcatenate.py +++ /dev/null @@ -1,270 +0,0 @@ -"""Utility functions for decomposing Concatenate nodes. - -The exported functions are flatten and unconcatenate. -- flatten: destroys the structure preserved within Concatenate nodes, - essentially reducing FInAT provided tabulations to what - FIAT could have provided, so old code can continue to work. -- unconcatenate: split up (variable, expression) pairs along - Concatenate nodes, thus recovering the structure - within them, yet eliminating the Concatenate nodes. - -Let us see an example on unconcatenate. Let us consider the form - - div(v) * dx - -where v is an RTCF7 test function. This means that the assembled -local vector has 8 * 7 + 7 * 8 = 112 entries. So the compilation of -the form starts with a single assignment pair [(v, e)]. v is now the -indexed return variable, something equivalent to - - Indexed(Variable('A', (112,)), (j,)) - -where j is the basis function index of the argument. e is just a GEM -quadrature expression with j as its only free index. This will -contain the tabulation of the RTCF7 element, which will cause -something like - - C_j := Indexed(Concatenate(A, B), (j,)) - -to appear as a subexpression in e. unconcatenate splits e along C_j -into e_1 and e_2 such that - - e_1 := e /. C_j -> A_{ja1,ja2}, and - e_2 := e /. C_j -> B_{jb1,jb2}. - -The split indices ja1, ja2, jb1, and jb2 have extents 8, 7, 7, and 8 -respectively (see the RTCF7 element construction above). So the -result of unconcatenate will be the list of pairs - - [(v_1, e_2), (v_2, e_2)] - -where v_1 is the first 56 entries of v, reshaped as an 8 x 7 matrix, -indexed with (ja1, ja2), and similarly, v_2 is the second 56 entries -of v, reshaped as a 7 x 8 matrix, indexed with (jb1, jb2). - -The unconcatenated form allows for sum factorisation of tensor product -elements as usual. This pair splitting is also applicable to -coefficient evaluation: take the local basis function coefficients as -the variable, the FInAT tabulation of the element as the expression, -and apply "matrix-vector multifunction" for each pair after -unconcatenation, and then add up the results. 
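A plain numpy sketch of the split described in this docstring, treating the 112-entry local vector as an ordinary array; the variable names are made up.

import numpy

A = numpy.arange(112.0)     # stand-in for the indexed return variable
v_1 = A[:56].reshape(8, 7)  # first block, indexed by (ja1, ja2)
v_2 = A[56:].reshape(7, 8)  # second block, indexed by (jb1, jb2)

print(v_1.shape, v_2.shape)                    # (8, 7) (7, 8)
print(v_1[0, 0] == A[0], v_2[6, 7] == A[111])  # True True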
-""" - -from functools import singledispatch -from itertools import chain - -import numpy - -from gem.node import Memoizer, reuse_if_untouched -from gem.gem import (ComponentTensor, Concatenate, FlexiblyIndexed, - Index, Indexed, Literal, Node, partial_indexed, - reshape, view) -from gem.optimise import remove_componenttensors -from gem.interpreter import evaluate - - -__all__ = ['flatten', 'unconcatenate'] - - -def find_group(expressions): - """Finds a full set of indexed Concatenate nodes with the same - free index, if any such node exists. - - Pre-condition: ComponentTensor nodes surrounding Concatenate nodes - must be removed. - - :arg expressions: a multi-root GEM expression DAG - :returns: a list of GEM nodes, or None - """ - free_indices = set().union(chain(*[e.free_indices for e in expressions])) - - # Result variables - index = None - nodes = [] - - # Sui generis pre-order traversal so that we can avoid going - # unnecessarily deep in the DAG. - seen = set() - lifo = [] - for root in expressions: - if root not in seen: - seen.add(root) - lifo.append(root) - - while lifo: - node = lifo.pop() - if not free_indices.intersection(node.free_indices): - continue - - if isinstance(node, Indexed): - child, = node.children - if isinstance(child, Concatenate): - i, = node.multiindex - assert i in free_indices - if (index or i) == i: - index = i - nodes.append(node) - # Skip adding children - continue - - for child in reversed(node.children): - if child not in seen: - seen.add(child) - lifo.append(child) - - return index and nodes - - -def split_variable(variable_ref, index, multiindices): - """Splits a flexibly indexed variable along a concatenation index. - - :param variable_ref: flexibly indexed variable to split - :param index: :py:class:`Concatenate` index to split along - :param multiindices: one multiindex for each split variable - - :returns: generator of split indexed variables - """ - assert isinstance(variable_ref, FlexiblyIndexed) - other_indices = list(variable_ref.index_ordering()) - other_indices.remove(index) - other_indices = tuple(other_indices) - data = ComponentTensor(variable_ref, (index,) + other_indices) - slices = [slice(None)] * len(other_indices) - shapes = [(other_index.extent,) for other_index in other_indices] - - offset = 0 - for multiindex in multiindices: - shape = tuple(index.extent for index in multiindex) - size = numpy.prod(shape, dtype=int) - slice_ = slice(offset, offset + size) - offset += size - - sub_ref = Indexed(reshape(view(data, slice_, *slices), - shape, *shapes), - multiindex + other_indices) - sub_ref, = remove_componenttensors((sub_ref,)) - yield sub_ref - - -def _replace_node(node, self): - """Replace subexpressions using a given mapping. - - :param node: root of expression - :param self: function for recursive calls - """ - assert isinstance(node, Node) - if self.cut(node): - return node - try: - return self.mapping[node] - except KeyError: - return reuse_if_untouched(node, self) - - -def replace_node(expression, mapping, cut=None): - """Replace subexpressions using a given mapping. - - :param expression: a GEM expression - :param mapping: a :py:class:`dict` containing the substitutions - :param cut: cutting predicate; if returns true, it is assumed that - no replacements would take place in the subexpression. - """ - mapper = Memoizer(_replace_node) - mapper.mapping = mapping - mapper.cut = cut or (lambda node: False) - return mapper(expression) - - -def _unconcatenate(cache, pairs): - # Tail-call recursive core of unconcatenate. 
-    # Assumes that input has already been sanitised.
-    concat_group = find_group([e for v, e in pairs])
-    if concat_group is None:
-        return pairs
-
-    # Get the index split
-    concat_ref = next(iter(concat_group))
-    assert isinstance(concat_ref, Indexed)
-    concat_expr, = concat_ref.children
-    index, = concat_ref.multiindex
-    assert isinstance(concat_expr, Concatenate)
-    try:
-        multiindices = cache[index]
-    except KeyError:
-        multiindices = tuple(tuple(Index(extent=d) for d in child.shape)
-                             for child in concat_expr.children)
-        cache[index] = multiindices
-
-    def cut(node):
-        """No need to rebuild expressions independent of the
-        relevant concatenation index."""
-        return index not in node.free_indices
-
-    # Build Concatenate node replacement mappings
-    mappings = [{} for i in range(len(multiindices))]
-    for concat_ref in concat_group:
-        concat_expr, = concat_ref.children
-        for i in range(len(multiindices)):
-            sub_ref = Indexed(concat_expr.children[i], multiindices[i])
-            sub_ref, = remove_componenttensors((sub_ref,))
-            mappings[i][concat_ref] = sub_ref
-
-    # Finally, split assignment pairs
-    split_pairs = []
-    for var, expr in pairs:
-        if index not in var.free_indices:
-            split_pairs.append((var, expr))
-        else:
-            for v, m in zip(split_variable(var, index, multiindices), mappings):
-                split_pairs.append((v, replace_node(expr, m, cut)))
-
-    # Run again, there may be other Concatenate groups
-    return _unconcatenate(cache, split_pairs)
-
-
-def unconcatenate(pairs, cache=None):
-    """Splits a list of (indexed variable, expression) pairs along
-    :py:class:`Concatenate` nodes embedded in the expressions.
-
-    :param pairs: list of (indexed variable, expression) pairs
-    :param cache: index splitting cache :py:class:`dict` (optional)
-
-    :returns: list of (indexed variable, expression) pairs
-    """
-    # Set up cache
-    if cache is None:
-        cache = {}
-
-    # Eliminate index renaming due to ComponentTensor nodes
-    exprs = remove_componenttensors([e for v, e in pairs])
-    pairs = [(v, e) for (v, _), e in zip(pairs, exprs)]
-
-    return _unconcatenate(cache, pairs)
-
-
-@singledispatch
-def _flatten(node, self):
-    """Replace Concatenate nodes with Literal nodes.
-
-    :arg node: root of the expression
-    :arg self: function for recursive calls
-    """
-    raise AssertionError("cannot handle type %s" % type(node))
-
-
-_flatten.register(Node)(reuse_if_untouched)
-
-
-@_flatten.register(Concatenate)
-def _flatten_concatenate(node, self):
-    result, = evaluate([node])
-    return partial_indexed(Literal(result.arr), result.fids)
-
-
-def flatten(expressions):
-    """Flatten Concatenate nodes, and destroy the structure they express.
-
-    :arg expressions: a multi-root expression DAG
-    """
-    mapper = Memoizer(_flatten)
-    return list(map(mapper, expressions))
diff --git a/gem/utils.py b/gem/utils.py
deleted file mode 100644
index 12e0e0f6..00000000
--- a/gem/utils.py
+++ /dev/null
@@ -1,108 +0,0 @@
-import collections
-
-
-# This is copied from PyOP2, and it is here to be available for both
-# FInAT and TSFC without depending on PyOP2.
-class cached_property(object):
-    """A read-only @property that is only evaluated once.
-    The value is cached on the object itself rather than the function or
-    class; this should prevent memory leakage."""
-    def __init__(self, fget, doc=None):
-        self.fget = fget
-        self.__doc__ = doc or fget.__doc__
-        self.__name__ = fget.__name__
-        self.__module__ = fget.__module__
-
-    def __get__(self, obj, cls):
-        if obj is None:
-            return self
-        obj.__dict__[self.__name__] = result = self.fget(obj)
-        return result
-
-
-def groupby(iterable, key=None):
-    """Groups objects by their keys.
-
-    :arg iterable: an iterable
-    :arg key: key function
-
-    :returns: list of (group key, list of group members) pairs
-    """
-    if key is None:
-        key = lambda x: x
-    groups = collections.OrderedDict()
-    for elem in iterable:
-        groups.setdefault(key(elem), []).append(elem)
-    return groups.items()
-
-
-def make_proxy_class(name, cls):
-    """Constructs a proxy class for a given class.
-
-    :arg name: name of the new proxy class
-    :arg cls: the wrapee class to create a proxy for
-    """
-    def __init__(self, wrapee):
-        self._wrapee = wrapee
-
-    def make_proxy_property(name):
-        def getter(self):
-            return getattr(self._wrapee, name)
-        return property(getter)
-
-    dct = {'__init__': __init__}
-    for attr in dir(cls):
-        if not attr.startswith('_'):
-            dct[attr] = make_proxy_property(attr)
-    return type(name, (), dct)
-
-
-# Implementation of dynamically scoped variables in Python.
-class UnsetVariableError(LookupError):
-    pass
-
-
-_unset = object()
-
-
-class DynamicallyScoped(object):
-    """A dynamically scoped variable."""
-
-    def __init__(self, default_value=_unset):
-        if default_value is _unset:
-            self._head = None
-        else:
-            self._head = (default_value, None)
-
-    def let(self, value):
-        return _LetBlock(self, value)
-
-    @property
-    def value(self):
-        if self._head is None:
-            raise UnsetVariableError("Dynamically scoped variable not set.")
-        result, tail = self._head
-        return result
-
-
-class _LetBlock(object):
-    """Context manager representing a dynamic scope."""
-
-    def __init__(self, variable, value):
-        self.variable = variable
-        self.value = value
-        self.state = None
-
-    def __enter__(self):
-        assert self.state is None
-        value = self.value
-        tail = self.variable._head
-        scope = (value, tail)
-        self.variable._head = scope
-        self.state = scope
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        variable = self.variable
-        assert self.state is variable._head
-        value, variable._head = variable._head
-        self.state = None
diff --git a/setup.py b/setup.py
index e300a605..ee9b0a23 100644
--- a/setup.py
+++ b/setup.py
@@ -4,4 +4,4 @@
 setup(name="tsfc",
       version=version,
-      packages=["gem", "tsfc", "tsfc.kernel_interface"])
+      packages=["tsfc", "tsfc.kernel_interface"])
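The gem modules deleted above are no longer shipped from this tree (setup.py now packages only tsfc and tsfc.kernel_interface). A minimal sketch of how downstream code could keep exercising the removed gem/utils.py helpers, assuming an externally installed gem distribution still provides gem.utils with the API shown in this patch (that import path and its availability are assumptions, not established by this diff):

    # Sketch only: assumes `gem.utils` is importable from a separate gem
    # package that keeps the helpers removed from this repository unchanged.
    from gem.utils import DynamicallyScoped, UnsetVariableError, groupby

    # groupby needs no pre-sorted input (unlike itertools.groupby); it keeps
    # keys in first-seen order and yields (key, list of members) pairs.
    pairs = list(groupby([3, 1, 4, 1, 5, 9, 2, 6], key=lambda n: n % 2))
    assert pairs == [(1, [3, 1, 1, 5, 9]), (0, [4, 2, 6])]

    # A DynamicallyScoped variable takes a new value for the duration of a
    # `let` block and reverts to the previous value when the block exits.
    mode = DynamicallyScoped("default")
    with mode.let("fast"):
        assert mode.value == "fast"
        with mode.let("safe"):      # let blocks nest
            assert mode.value == "safe"
        assert mode.value == "fast"
    assert mode.value == "default"

    # Reading a variable that was never given a value raises UnsetVariableError.
    unset = DynamicallyScoped()
    try:
        unset.value
    except UnsetVariableError:
        pass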