Skip to content
63 changes: 41 additions & 22 deletions pyop2/local_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from typing import Union

import coffee
from coffee.visitors import EstimateFlops
import loopy as lp
from loopy.tools import LoopyKeyBuilder
import numpy as np

from pyop2 import version
from pyop2.configuration import configuration
from pyop2.datatypes import ScalarType
from pyop2.exceptions import NameTypeError
from pyop2.types import Access
Expand Down Expand Up @@ -152,28 +152,9 @@ def arguments(self):
for acc, dtype in zip(self.accesses, self.dtypes))

@cached_property
@abc.abstractmethod
def num_flops(self):
"""Compute the numbers of FLOPs if not already known."""
# NOTE(review): this span comes from a diff view in which removed and added
# lines are shown together without +/- markers; the abstract decorator and
# the body below may belong to different revisions -- confirm against the
# real file before relying on either.
# A FLOP count stored on the kernel takes precedence over any estimate.
if self.flop_count is not None:
return self.flop_count

# Estimation is skipped (and 0 reported) unless the
# "compute_kernel_flops" configuration entry is truthy.
if not configuration["compute_kernel_flops"]:
return 0

# COFFEE AST: let COFFEE's EstimateFlops visitor walk the code.
if isinstance(self.code, coffee.base.Node):
v = coffee.visitors.EstimateFlops()
return v.visit(self.code)
# loopy translation unit: build an operation map from a copy of the code
# that has the three listed warnings silenced, then sum only the
# add/sub/mul/div operations whose dtype is ScalarType.
elif isinstance(self.code, lp.TranslationUnit):
op_map = lp.get_op_map(
self.code.copy(options=lp.Options(ignore_boostable_into=True),
silenced_warnings=['insn_count_subgroups_upper_bound',
'get_x_map_guessing_subgroup_size',
'summing_if_branches_ops']),
subgroup_size='guess')
return op_map.filter_by(name=['add', 'sub', 'mul', 'div'],
dtype=[ScalarType]).eval_and_sum({})
# Any other kind of code object is reported as 0 FLOPs.
else:
return 0

def __eq__(self, other):
if not isinstance(other, LocalKernel):
Expand Down Expand Up @@ -214,6 +195,12 @@ def dtypes(self):
def dtypes(self, dtypes):
self._dtypes = dtypes

@cached_property
def num_flops(self):
    """Return the stored FLOP count, falling back to 0 when none is set.

    The FLOPs of a string kernel cannot be measured or estimated, so a
    missing ``flop_count`` is reported as zero.
    """
    if self.flop_count is None:
        return 0
    return self.flop_count


class CoffeeLocalKernel(LocalKernel):
""":class:`LocalKernel` class where `code` has type :class:`coffee.base.Node`."""
Expand All @@ -231,13 +218,19 @@ def dtypes(self):
def dtypes(self, dtypes):
self._dtypes = dtypes

@cached_property
def num_flops(self):
    """Return the number of FLOPs performed by this kernel.

    A stored ``flop_count`` is returned as is; otherwise the count is
    obtained by running COFFEE's ``EstimateFlops`` visitor over the code.
    """
    if self.flop_count is not None:
        return self.flop_count
    estimator = EstimateFlops()
    return estimator.visit(self.code)


class LoopyLocalKernel(LocalKernel):
""":class:`LocalKernel` class where `code` has type :class:`loopy.LoopKernel`
or :class:`loopy.TranslationUnit`.
"""

@validate_type(("code", (lp.LoopKernel, lp.TranslationUnit), TypeError))
# NOTE(review): two validators are shown because this text comes from a diff
# view without +/- markers; presumably only one of them exists in the real
# file -- confirm. In the decorator below, `(lp.TranslationUnit)` is a
# parenthesised expression, not a one-element tuple.
@validate_type(("code", (lp.TranslationUnit), TypeError))
def __init__(self, code, *args, **kwargs):
# All arguments are forwarded unchanged to the parent initialiser.
super().__init__(code, *args, **kwargs)

Expand All @@ -250,3 +243,29 @@ def _loopy_arguments(self):
"""Return the loopy arguments associated with the kernel."""
return tuple(a for a in self.code.callables_table[self.name].subkernel.args
if isinstance(a, lp.ArrayArg))

@cached_property
def num_flops(self):
    """Return the number of FLOPs performed by this kernel.

    A stored ``flop_count`` is returned as is. Otherwise the operations of
    the kernel are counted with loopy and the add/sub/mul/div operations on
    ``ScalarType`` are summed.
    """
    if self.flop_count is not None:
        return self.flop_count

    assert isinstance(self.code, lp.TranslationUnit), "LocalLoopyKernels code should be a translation unit."

    # The warnings are silenced on the callable kernel, so it has to be
    # pulled out of the translation unit first.
    translation_unit = self.code.with_entrypoints(self.name)
    entry_kernel = translation_unit.default_entrypoint
    silenced = [
        *entry_kernel.silenced_warnings,
        'insn_count_subgroups_upper_bound',
        'get_x_map_guessing_subgroup_size',
        'summing_if_branches_ops',
    ]
    entry_kernel = entry_kernel.copy(silenced_warnings=silenced)

    # Extrusion utilities take a ``layer`` argument that is normally passed
    # in from the global kernel; it is pinned here so counting can proceed.
    # This may alter the computed values but not the FLOP count.
    entry_kernel = lp.fix_parameters(entry_kernel, layer=1)
    translation_unit = translation_unit.with_kernel(entry_kernel)

    counted_ops = lp.get_op_map(translation_unit, subgroup_size=1)
    arithmetic_ops = counted_ops.filter_by(name=['add', 'sub', 'mul', 'div'],
                                           dtype=[ScalarType])
    return arithmetic_ops.eval_and_sum({})
3 changes: 2 additions & 1 deletion pyop2/parloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ def _compute(self, part):
:arg part: The :class:`SetPartition` to compute over.
"""
with self._compute_event():
PETSc.Log.logFlops(part.size*self.num_flops)
if configuration["compute_kernel_flops"]:
PETSc.Log.logFlops(part.size*self.num_flops)
self.global_kernel(self.comm, part.offset, part.offset+part.size, *self.arglist)

@cached_property
Expand Down