NOTE(review): line breaks, blank lines and indentation below were reconstructed
from a whitespace-collapsed copy of this patch; confirm context-line whitespace
against the tree before applying.

diff --git a/pyop2/local_kernel.py b/pyop2/local_kernel.py
index 4807463b8..a6c04b9c4 100644
--- a/pyop2/local_kernel.py
+++ b/pyop2/local_kernel.py
@@ -4,12 +4,12 @@
 from typing import Union
 
 import coffee
+from coffee.visitors import EstimateFlops
 import loopy as lp
 from loopy.tools import LoopyKeyBuilder
 import numpy as np
 
 from pyop2 import version
-from pyop2.configuration import configuration
 from pyop2.datatypes import ScalarType
 from pyop2.exceptions import NameTypeError
 from pyop2.types import Access
@@ -152,28 +152,10 @@ def arguments(self):
                      for acc, dtype in zip(self.accesses, self.dtypes))
 
-    @cached_property
+    # builtin property propagates __isabstractmethod__ from the wrapped function
+    @property
+    @abc.abstractmethod
     def num_flops(self):
         """Compute the numbers of FLOPs if not already known."""
-        if self.flop_count is not None:
-            return self.flop_count
-
-        if not configuration["compute_kernel_flops"]:
-            return 0
-
-        if isinstance(self.code, coffee.base.Node):
-            v = coffee.visitors.EstimateFlops()
-            return v.visit(self.code)
-        elif isinstance(self.code, lp.TranslationUnit):
-            op_map = lp.get_op_map(
-                self.code.copy(options=lp.Options(ignore_boostable_into=True),
-                               silenced_warnings=['insn_count_subgroups_upper_bound',
-                                                  'get_x_map_guessing_subgroup_size',
-                                                  'summing_if_branches_ops']),
-                subgroup_size='guess')
-            return op_map.filter_by(name=['add', 'sub', 'mul', 'div'],
-                                    dtype=[ScalarType]).eval_and_sum({})
-        else:
-            return 0
 
     def __eq__(self, other):
         if not isinstance(other, LocalKernel):
@@ -214,6 +196,12 @@ def dtypes(self):
     def dtypes(self, dtypes):
         self._dtypes = dtypes
 
+    @cached_property
+    def num_flops(self):
+        """Return the user-supplied FLOP count, or 0 if none was given,
+        because there is no way to measure or estimate the FLOPs for string kernels."""
+        return self.flop_count if self.flop_count is not None else 0
+
 
 class CoffeeLocalKernel(LocalKernel):
     """:class:`LocalKernel` class where `code` has type :class:`coffee.base.Node`."""
@@ -231,13 +219,19 @@ def dtypes(self):
     def dtypes(self, dtypes):
         self._dtypes = dtypes
 
+    @cached_property
+    def num_flops(self):
+        """Compute the numbers of FLOPs if not already known
+        using COFFEE's FLOP estimation algorithm."""
+        return self.flop_count if self.flop_count is not None else EstimateFlops().visit(self.code)
+
 
 class LoopyLocalKernel(LocalKernel):
-    """:class:`LocalKernel` class where `code` has type :class:`loopy.LoopKernel`
-    or :class:`loopy.TranslationUnit`.
+    """:class:`LocalKernel` class where `code` has type
+    :class:`loopy.TranslationUnit`.
     """
 
-    @validate_type(("code", (lp.LoopKernel, lp.TranslationUnit), TypeError))
+    @validate_type(("code", lp.TranslationUnit, TypeError))
     def __init__(self, code, *args, **kwargs):
         super().__init__(code, *args, **kwargs)
 
@@ -250,3 +244,29 @@ def _loopy_arguments(self):
         """Return the loopy arguments associated with the kernel."""
         return tuple(a for a in self.code.callables_table[self.name].subkernel.args
                      if isinstance(a, lp.ArrayArg))
+
+    @cached_property
+    def num_flops(self):
+        """Compute the numbers of FLOPs if not already known
+        using Loo.py's FLOP counting algorithm."""
+        if self.flop_count is not None:
+            return self.flop_count
+        if not isinstance(self.code, lp.TranslationUnit):
+            raise TypeError("LoopyLocalKernel code should be a translation unit.")
+        # in order to silence the warnings we need to access
+        # the callable kernels in the translation unit
+        prog = self.code.with_entrypoints(self.name)
+        knl = prog.default_entrypoint
+        warnings = list(knl.silenced_warnings)
+        warnings.extend(['insn_count_subgroups_upper_bound',
+                         'get_x_map_guessing_subgroup_size',
+                         'summing_if_branches_ops'])
+        knl = knl.copy(silenced_warnings=warnings)
+        # for extrusion utils the layer arg must be fixed
+        # because usually it would be a value which is passed in from the global kernel
+        # theoretically this changes the result but not the FLOP count
+        knl = lp.fix_parameters(knl, layer=1)
+        prog = prog.with_kernel(knl)
+        op_map = lp.get_op_map(prog, subgroup_size=1)
+        return op_map.filter_by(name=['add', 'sub', 'mul', 'div'],
+                                dtype=[ScalarType]).eval_and_sum({})
diff --git a/pyop2/parloop.py b/pyop2/parloop.py
index 8384268cf..3990ad242 100644
--- a/pyop2/parloop.py
+++ b/pyop2/parloop.py
@@ -188,7 +188,8 @@ def _compute(self, part):
         :arg part: The :class:`SetPartition` to compute over.
         """
         with self._compute_event():
-            PETSc.Log.logFlops(part.size*self.num_flops)
+            if configuration["compute_kernel_flops"]:
+                PETSc.Log.logFlops(part.size*self.num_flops)
             self.global_kernel(self.comm, part.offset, part.offset+part.size, *self.arglist)
 
     @cached_property