diff --git a/doc/tutorial.rst b/doc/tutorial.rst index 8e65e4591..2a041b700 100644 --- a/doc/tutorial.rst +++ b/doc/tutorial.rst @@ -1566,7 +1566,7 @@ information provided. Now we will count the operations: >>> op_map = lp.get_op_map(knl, subgroup_size=32) >>> print(op_map) - Op(np:dtype('float32'), add, subgroup, "stats_knl"): ... + Op(np:dtype('float32'), add, subgroup, "stats_knl", None): ... Each line of output will look roughly like:: @@ -1628,7 +1628,7 @@ together into keys containing only the specified fields: >>> op_map_dtype = op_map.group_by('dtype') >>> print(op_map_dtype) - Op(np:dtype('float32'), None, None): ... + Op(np:dtype('float32'), None, None, None): ... >>> f32op_count = op_map_dtype[lp.Op(dtype=np.float32) ... ].eval_with_dict(param_dict) @@ -1654,7 +1654,7 @@ we'll continue using the kernel from the previous example: >>> mem_map = lp.get_mem_access_map(knl, subgroup_size=32) >>> print(mem_map) - MemAccess(global, np:dtype('float32'), {}, {}, load, a, None, subgroup, 'stats_knl'): ... + MemAccess(global, np:dtype('float32'), {}, {}, load, a, None, subgroup, 'stats_knl', None): ... Each line of output will look roughly like:: @@ -1725,13 +1725,13 @@ using :func:`loopy.ToCountMap.to_bytes` and :func:`loopy.ToCountMap.group_by`: >>> bytes_map = mem_map.to_bytes() >>> print(bytes_map) - MemAccess(global, np:dtype('float32'), {}, {}, load, a, None, subgroup, 'stats_knl'): ... + MemAccess(global, np:dtype('float32'), {}, {}, load, a, None, subgroup, 'stats_knl', None): ... >>> global_ld_st_bytes = bytes_map.filter_by(mtype=['global'] ... ).group_by('direction') >>> print(global_ld_st_bytes) - MemAccess(None, None, None, None, load, None, None, None, None): ... - MemAccess(None, None, None, None, store, None, None, None, None): ... + MemAccess(None, None, None, None, load, None, None, None, None, None): ... + MemAccess(None, None, None, None, store, None, None, None, None, None): ... >>> loaded = global_ld_st_bytes[lp.MemAccess(direction='load') ... 
].eval_with_dict(param_dict) @@ -1768,12 +1768,12 @@ this time. ... outer_tag="l.1", inner_tag="l.0") >>> mem_map = lp.get_mem_access_map(knl_consec, subgroup_size=32) >>> print(mem_map) - MemAccess(global, np:dtype('float32'), {0: 1, 1: 128}, {}, load, a, None, workitem, 'stats_knl'): ... - MemAccess(global, np:dtype('float32'), {0: 1, 1: 128}, {}, load, b, None, workitem, 'stats_knl'): ... - MemAccess(global, np:dtype('float32'), {0: 1, 1: 128}, {}, store, c, None, workitem, 'stats_knl'): ... - MemAccess(global, np:dtype('float64'), {0: 1, 1: 128}, {}, load, g, None, workitem, 'stats_knl'): ... - MemAccess(global, np:dtype('float64'), {0: 1, 1: 128}, {}, load, h, None, workitem, 'stats_knl'): ... - MemAccess(global, np:dtype('float64'), {0: 1, 1: 128}, {}, store, e, None, workitem, 'stats_knl'): ... + MemAccess(global, np:dtype('float32'), {0: 1, 1: 128}, {}, load, a, None, workitem, 'stats_knl', None): ... + MemAccess(global, np:dtype('float32'), {0: 1, 1: 128}, {}, load, b, None, workitem, 'stats_knl', None): ... + MemAccess(global, np:dtype('float32'), {0: 1, 1: 128}, {}, store, c, None, workitem, 'stats_knl', None): ... + MemAccess(global, np:dtype('float64'), {0: 1, 1: 128}, {}, load, g, None, workitem, 'stats_knl', None): ... + MemAccess(global, np:dtype('float64'), {0: 1, 1: 128}, {}, load, h, None, workitem, 'stats_knl', None): ... + MemAccess(global, np:dtype('float64'), {0: 1, 1: 128}, {}, store, e, None, workitem, 'stats_knl', None): ... With this parallelization, consecutive work-items will access consecutive array @@ -1813,12 +1813,12 @@ we'll switch the inner and outer tags in our parallelization of the kernel: ... outer_tag="l.0", inner_tag="l.1") >>> mem_map = lp.get_mem_access_map(knl_nonconsec, subgroup_size=32) >>> print(mem_map) - MemAccess(global, np:dtype('float32'), {0: 128, 1: 1}, {}, load, a, None, workitem, 'stats_knl'): ... - MemAccess(global, np:dtype('float32'), {0: 128, 1: 1}, {}, load, b, None, workitem, 'stats_knl'): ... 
- MemAccess(global, np:dtype('float32'), {0: 128, 1: 1}, {}, store, c, None, workitem, 'stats_knl'): ... - MemAccess(global, np:dtype('float64'), {0: 128, 1: 1}, {}, load, g, None, workitem, 'stats_knl'): ... - MemAccess(global, np:dtype('float64'), {0: 128, 1: 1}, {}, load, h, None, workitem, 'stats_knl'): ... - MemAccess(global, np:dtype('float64'), {0: 128, 1: 1}, {}, store, e, None, workitem, 'stats_knl'): ... + MemAccess(global, np:dtype('float32'), {0: 128, 1: 1}, {}, load, a, None, workitem, 'stats_knl', None): ... + MemAccess(global, np:dtype('float32'), {0: 128, 1: 1}, {}, load, b, None, workitem, 'stats_knl', None): ... + MemAccess(global, np:dtype('float32'), {0: 128, 1: 1}, {}, store, c, None, workitem, 'stats_knl', None): ... + MemAccess(global, np:dtype('float64'), {0: 128, 1: 1}, {}, load, g, None, workitem, 'stats_knl', None): ... + MemAccess(global, np:dtype('float64'), {0: 128, 1: 1}, {}, load, h, None, workitem, 'stats_knl', None): ... + MemAccess(global, np:dtype('float64'), {0: 128, 1: 1}, {}, store, e, None, workitem, 'stats_knl', None): ... 
With this parallelization, consecutive work-items will access *nonconsecutive* @@ -1871,7 +1871,7 @@ kernel from the previous example: >>> sync_map = lp.get_synchronization_map(knl) >>> print(sync_map) - Sync(kernel_launch, stats_knl): [l, m, n] -> { 1 } + Sync(kernel_launch, stats_knl, None): [l, m, n] -> { 1 } We can evaluate this polynomial using :meth:`islpy.PwQPolynomial.eval_with_dict`: @@ -1931,8 +1931,8 @@ count the barriers using :func:`loopy.get_synchronization_map`: >>> sync_map = lp.get_synchronization_map(knl) >>> print(sync_map) - Sync(barrier_local, loopy_kernel): { 1000 } - Sync(kernel_launch, loopy_kernel): { 1 } + Sync(barrier_local, loopy_kernel, None): { 1000 } + Sync(kernel_launch, loopy_kernel, None): { 1 } Based on the kernel code printed above, we would expect each work-item to diff --git a/loopy/statistics.py b/loopy/statistics.py index baabb99aa..bddbee51e 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -1,3 +1,5 @@ +from __future__ import annotations + __copyright__ = """ Copyright (C) 2015 James Stevens Copyright (C) 2018 Kaushik Kulkarni @@ -25,20 +27,34 @@ THE SOFTWARE. 
""" +from typing import (FrozenSet, Generic, Iterable, Optional, TypeVar, Union, Tuple, + Dict, Callable, Mapping, Any) +from enum import Enum, auto as enum_auto +from dataclasses import dataclass, replace from functools import partial, cached_property -from islpy import dim_type +from immutables import Map + +from islpy import PwQPolynomial, dim_type import islpy as isl +import pymbolic.primitives as p from pymbolic.mapper import CombineMapper +from pytools.tag import Tag +from pytools import memoize_method import loopy as lp +from loopy.kernel.array import ArrayBase +from loopy.kernel.instruction import InstructionBase +from loopy.kernel import LoopKernel from loopy.kernel.data import ( - MultiAssignmentBase, TemporaryVariable, AddressSpace) + InameImplementationTag, MultiAssignmentBase, AddressSpace) from loopy.diagnostic import warn_with_kernel, LoopyError -from loopy.symbolic import CoefficientCollector -from pytools import ImmutableRecord, memoize_method -from loopy.kernel.function_interface import CallableKernel +from loopy.symbolic import (CoefficientCollector, Reduction, SubArrayRef, + TaggedExpression) +from loopy.kernel.function_interface import CallableKernel, InKernelCallable from loopy.translation_unit import TranslationUnit +from loopy.typing import Expression +from loopy.types import LoopyType __doc__ = """ @@ -48,8 +64,12 @@ .. autoclass:: ToCountMap .. autoclass:: ToCountPolynomialMap .. autoclass:: CountGranularity +.. autoclass:: OpType .. autoclass:: Op +.. autoclass:: AccessDirection .. autoclass:: MemAccess +.. autoclass:: SynchronizationKind +.. autoclass:: Sync .. autofunction:: get_op_map .. 
autofunction:: get_mem_access_map @@ -79,28 +99,29 @@ # - Test for the subkernel functionality need to be written -def get_kernel_parameter_space(kernel): +def get_kernel_parameter_space(kernel: LoopKernel) -> isl.Space: return isl.Space.create_from_names(kernel.isl_context, set=[], params=sorted(list(kernel.outer_params()))).params() -def get_kernel_zero_pwqpolynomial(kernel): +def get_kernel_zero_pwqpolynomial(kernel: LoopKernel) -> PwQPolynomial: space = get_kernel_parameter_space(kernel) space = space.insert_dims(dim_type.out, 0, 1) - return isl.PwQPolynomial.zero(space) + return PwQPolynomial.zero(space) # {{{ GuardedPwQPolynomial -def _get_param_tuple(obj): +def _get_param_tuple(obj) -> Tuple[str, ...]: return tuple( obj.get_dim_name(dim_type.param, i) for i in range(obj.dim(dim_type.param))) class GuardedPwQPolynomial: - def __init__(self, pwqpolynomial, valid_domain): - assert isinstance(pwqpolynomial, isl.PwQPolynomial) + def __init__(self, + pwqpolynomial: PwQPolynomial, valid_domain: isl.Set) -> None: + assert isinstance(pwqpolynomial, PwQPolynomial) self.pwqpolynomial = pwqpolynomial self.valid_domain = valid_domain @@ -166,7 +187,11 @@ def __repr__(self): # {{{ ToCountMap -class ToCountMap: +Countable = Union["Op", "MemAccess", "Sync"] +CountT = TypeVar("CountT") + + +class ToCountMap(Generic[CountT]): """A map from work descriptors like :class:`Op` and :class:`MemAccess` to any arithmetic type. 
@@ -190,7 +215,9 @@ class ToCountMap: """ - def __init__(self, count_map=None): + count_map: Dict[Countable, CountT] + + def __init__(self, count_map: Optional[Dict[Countable, CountT]] = None) -> None: if count_map is None: count_map = {} @@ -199,13 +226,13 @@ def __init__(self, count_map=None): def _zero(self): return 0 - def __add__(self, other): + def __add__(self, other: ToCountMap[CountT]) -> ToCountMap[CountT]: result = self.count_map.copy() for k, v in other.count_map.items(): result[k] = self.count_map.get(k, 0) + v return self.copy(count_map=result) - def __radd__(self, other): + def __radd__(self, other: Union[int, ToCountMap[CountT]]) -> ToCountMap[CountT]: if other != 0: raise ValueError("ToCountMap: Attempted to add ToCountMap " "to {} {}. ToCountMap may only be added to " @@ -214,7 +241,7 @@ def __radd__(self, other): return self - def __mul__(self, other): + def __mul__(self, other: GuardedPwQPolynomial) -> ToCountMap[CountT]: if isinstance(other, GuardedPwQPolynomial): return self.copy({ index: other*value @@ -226,22 +253,23 @@ def __mul__(self, other): __rmul__ = __mul__ - def __getitem__(self, index): + def __getitem__(self, index: Countable) -> CountT: return self.count_map[index] - def __repr__(self): + def __repr__(self) -> str: return repr(self.count_map) - def __str__(self): + def __str__(self) -> str: return "\n".join( f"{k}: {v}" for k, v in sorted(self.count_map.items(), key=lambda k: str(k))) - def __len__(self): + def __len__(self) -> int: return len(self.count_map) - def get(self, key, default=None): + def get(self, + key: Countable, default: Optional[CountT] = None) -> Optional[CountT]: return self.count_map.get(key, default) def items(self): @@ -253,18 +281,20 @@ def keys(self): def values(self): return self.count_map.values() - def copy(self, count_map=None): + def copy( + self, count_map: Optional[Dict[Countable, CountT]] = None + ) -> ToCountMap[CountT]: if count_map is None: count_map = self.count_map return 
type(self)(count_map=count_map) - def with_set_attributes(self, **kwargs): + def with_set_attributes(self, **kwargs) -> ToCountMap: return self.copy(count_map={ - key.copy(**kwargs): val + replace(key, **kwargs): val for key, val in self.count_map.items()}) - def filter_by(self, **kwargs): + def filter_by(self, **kwargs) -> ToCountMap: """Remove items without specified key fields. :arg kwargs: Keyword arguments matching fields in the keys of the @@ -309,7 +339,8 @@ class _Sentinel: return self.copy(count_map=new_count_map) - def filter_by_func(self, func): + def filter_by_func( + self, func: Callable[[Countable], bool]) -> ToCountMap[CountT]: """Keep items that pass a test. :arg func: A function that takes a map key a parameter and returns a @@ -342,7 +373,7 @@ def filter_func(key): return self.copy(count_map=new_count_map) - def group_by(self, *args): + def group_by(self, *args) -> ToCountMap[CountT]: """Group map items together, distinguishing by only the key fields passed in args. @@ -388,7 +419,7 @@ def group_by(self, *args): """ - new_count_map = {} + new_count_map: Dict[Countable, CountT] = {} # make sure all item keys have same type if self.count_map: @@ -409,7 +440,7 @@ def group_by(self, *args): return self.copy(count_map=new_count_map) - def to_bytes(self): + def to_bytes(self) -> ToCountMap[CountT]: """Convert counts to bytes using data type in map key. 
:return: A :class:`ToCountMap` mapping each original key to an @@ -443,11 +474,11 @@ def to_bytes(self): new_count_map = {} for key, val in self.count_map.items(): - new_count_map[key] = int(key.dtype.itemsize) * val + new_count_map[key] = int(key.dtype.itemsize) * val # type: ignore[union-attr] # noqa: E501 return self.copy(new_count_map) - def sum(self): + def sum(self) -> CountT: """:return: A sum of the values of the dictionary.""" total = self._zero() @@ -462,14 +493,18 @@ def sum(self): # {{{ ToCountPolynomialMap -class ToCountPolynomialMap(ToCountMap): +class ToCountPolynomialMap(ToCountMap[GuardedPwQPolynomial]): """Maps any type of key to a :class:`islpy.PwQPolynomial` or a :class:`~loopy.statistics.GuardedPwQPolynomial`. .. automethod:: eval_and_sum """ - def __init__(self, space, count_map=None): + def __init__( + self, + space: isl.Space, + count_map: Dict[Countable, GuardedPwQPolynomial] + ) -> None: if not isinstance(space, isl.Space): raise TypeError( "first argument to ToCountPolynomialMap must be " @@ -492,7 +527,7 @@ def __init__(self, space, count_map=None): super().__init__(count_map) - def _zero(self): + def _zero(self) -> isl.PwQPolynomial: space = self.space.insert_dims(dim_type.out, 0, 1) return isl.PwQPolynomial.zero(space) @@ -505,7 +540,7 @@ def copy(self, count_map=None, space=None): return type(self)(space, count_map) - def eval_and_sum(self, params=None): + def eval_and_sum(self, params: Optional[Mapping[str, int]] = None) -> int: """Add all counts and evaluate with provided parameter dict *params* :return: An :class:`int` containing the sum of all counts @@ -534,7 +569,10 @@ def eval_and_sum(self, params=None): # {{{ subst_into_to_count_map -def subst_into_guarded_pwqpolynomial(new_space, guarded_poly, subst_dict): +def subst_into_guarded_pwqpolynomial( + new_space: isl.Space, + guarded_poly: GuardedPwQPolynomial, + subst_dict: Mapping[str, PwQPolynomial]) -> GuardedPwQPolynomial: from loopy.isl_helpers import 
subst_into_pwqpolynomial, get_param_subst_domain poly = subst_into_pwqpolynomial( @@ -551,9 +589,12 @@ def subst_into_guarded_pwqpolynomial(new_space, guarded_poly, subst_dict): return GuardedPwQPolynomial(poly, valid_domain) -def subst_into_to_count_map(space, tcm, subst_dict): +def subst_into_to_count_map( + space: isl.Space, + tcm: ToCountPolynomialMap, + subst_dict: Mapping[str, PwQPolynomial]) -> ToCountPolynomialMap: from loopy.isl_helpers import subst_into_pwqpolynomial - new_count_map = {} + new_count_map: Dict[Countable, GuardedPwQPolynomial] = {} for key, value in tcm.count_map.items(): if isinstance(value, GuardedPwQPolynomial): new_count_map[key] = subst_into_guarded_pwqpolynomial( @@ -575,7 +616,7 @@ def subst_into_to_count_map(space, tcm, subst_dict): # {{{ CountGranularity -class CountGranularity: +class CountGranularity(Enum): """Strings specifying whether an operation should be counted once per *work-item*, *sub-group*, or *work-group*. @@ -596,17 +637,36 @@ class CountGranularity: """ - WORKITEM = "workitem" - SUBGROUP = "subgroup" - WORKGROUP = "workgroup" - ALL = [WORKITEM, SUBGROUP, WORKGROUP] + WORKITEM = 0 + SUBGROUP = 1 + WORKGROUP = 2 # }}} # {{{ Op descriptor -class Op(ImmutableRecord): +class OpType(Enum): + """ + .. attribute:: ADD + .. attribute:: MUL + .. attribute:: DIV + .. attribute:: POW + .. attribute:: SHIFT + .. attribute:: BITWISE + """ + ADD = enum_auto() + MUL = enum_auto() + DIV = enum_auto() + POW = enum_auto() + SHIFT = enum_auto() + BITWISE = enum_auto() + MAXMIN = enum_auto() + SPECIAL_FUNC = enum_auto() + + +@dataclass(frozen=True, eq=True) +class Op: """A descriptor for a type of arithmetic operation. .. attribute:: dtype @@ -614,10 +674,9 @@ class Op(ImmutableRecord): A :class:`loopy.types.LoopyType` or :class:`numpy.dtype` that specifies the data type operated on. - .. attribute:: name + .. 
attribute:: op_type - A :class:`str` that specifies the kind of arithmetic operation as - *add*, *mul*, *div*, *pow*, *shift*, *bw* (bitwise), etc. + A :class:`OpType`. .. attribute:: count_granularity @@ -630,46 +689,50 @@ class Op(ImmutableRecord): work-group executes on a single compute unit with all work-items within the work-group sharing local memory. A sub-group is an implementation-dependent grouping of work-items within a work-group, - analagous to an NVIDIA CUDA warp. + analogous to an NVIDIA CUDA warp. .. attribute:: kernel_name A :class:`str` representing the kernel name where the operation occurred. - """ - - def __init__(self, dtype=None, name=None, count_granularity=None, - kernel_name=None): - if count_granularity not in CountGranularity.ALL+[None]: - raise ValueError("Op.__init__: count_granularity '%s' is " - "not allowed. count_granularity options: %s" - % (count_granularity, CountGranularity.ALL+[None])) + .. attribute:: tags - if dtype is not None: - from loopy.types import to_loopy_type - dtype = to_loopy_type(dtype) + A :class:`frozenset` of tags to the operation. - super().__init__(dtype=dtype, name=name, - count_granularity=count_granularity, - kernel_name=kernel_name) + """ + dtype: Optional[LoopyType] = None + op_type: Optional[OpType] = None + count_granularity: Optional[CountGranularity] = None + kernel_name: Optional[str] = None + tags: FrozenSet[Tag] = frozenset() def __repr__(self): - # Record.__repr__ overridden for consistent ordering and conciseness if self.kernel_name is not None: - return (f"Op({self.dtype}, {self.name}, {self.count_granularity}," + return (f"Op({self.dtype}, {self.op_type}, {self.count_granularity}," - f' "{self.kernel_name}")') + f' "{self.kernel_name}", {self.tags})') else: - return f"Op({self.dtype}, {self.name}, {self.count_granularity})" + return f"Op({self.dtype}, {self.op_type}, " + \ + f"{self.count_granularity}, {self.tags})" # }}} # {{{ MemAccess descriptor -class MemAccess(ImmutableRecord): +class AccessDirection(Enum): + """ + .. attribute:: READ + .. 
attribute:: WRITE + """ + READ = 0 + WRITE = 1 + + +@dataclass(frozen=True, eq=True) +class MemAccess: """A descriptor for a type of memory access. - .. attribute:: mtype + .. attribute:: address_space A :class:`str` that specifies the memory type accessed as **global** or **local** @@ -697,10 +760,9 @@ class MemAccess(ImmutableRecord): specifies global strides for each global id in the memory access index. global ids not found will not be present in ``gid_strides.keys()``. - .. attribute:: direction + .. attribute:: read_write - A :class:`str` that specifies the direction of memory access as - **load** or **store**. + A :class:`AccessDirection` or *None*. .. attribute:: variable @@ -724,80 +786,108 @@ class MemAccess(ImmutableRecord): work-group executes on a single compute unit with all work-items within the work-group sharing local memory. A sub-group is an implementation-dependent grouping of work-items within a work-group, - analagous to an NVIDIA CUDA warp. + analogous to an NVIDIA CUDA warp. .. attribute:: kernel_name A :class:`str` representing the kernel name where the operation occurred. - """ - def __init__(self, mtype=None, dtype=None, lid_strides=None, gid_strides=None, - direction=None, variable=None, - *, variable_tags=None, - count_granularity=None, kernel_name=None): + .. attribute:: tags - if count_granularity not in CountGranularity.ALL+[None]: - raise ValueError("Op.__init__: count_granularity '%s' is " - "not allowed. count_granularity options: %s" - % (count_granularity, CountGranularity.ALL+[None])) + A :class:`frozenset` of tags to the operation. 
+ """ - if variable_tags is None: - variable_tags = frozenset() + address_space: Optional[AddressSpace] = None + lid_strides: Optional[Mapping[int, Expression]] = None + gid_strides: Optional[Mapping[int, Expression]] = None + dtype: Optional[LoopyType] = None + read_write: Optional[AccessDirection] = None + variable: Optional[str] = None - if dtype is not None: - from loopy.types import to_loopy_type - dtype = to_loopy_type(dtype) + variable_tags: FrozenSet[Tag] = frozenset() + count_granularity: Optional[CountGranularity] = None + kernel_name: Optional[str] = None + tags: FrozenSet[Tag] = frozenset() - ImmutableRecord.__init__(self, mtype=mtype, dtype=dtype, - lid_strides=lid_strides, - gid_strides=gid_strides, direction=direction, - variable=variable, variable_tags=variable_tags, - count_granularity=count_granularity, - kernel_name=kernel_name) + @property + def mtype(self) -> str: + from warnings import warn + warn("MemAccess.mtype is deprecated and will stop working in 2024. " + "Use MemAccess.address_space instead.", + DeprecationWarning, stacklevel=2) + + if self.address_space == AddressSpace.GLOBAL: + return "global" + elif self.address_space == AddressSpace.LOCAL: + return "local" + else: + raise ValueError(f"unexpected address_space: '{self.address_space}'") - def __hash__(self): - # dicts in gid_strides and lid_strides aren't natively hashable - return hash(repr(self)) + @property + def direction(self) -> str: + from warnings import warn + warn("MemAccess.direction is deprecated " + "and will stop working in 2024. 
" + "Use MemAccess.read_write instead.", + DeprecationWarning, stacklevel=2) + + if self.read_write == AccessDirection.READ: + return "load" + elif self.read_write == AccessDirection.WRITE: + return "store" + else: + raise ValueError(f"unexpected read_write: '{self.read_write}'") def __repr__(self): # Record.__repr__ overridden for consistent ordering and conciseness - return "MemAccess({}, {}, {}, {}, {}, {}, {}, {}, {})".format( - self.mtype, + return "MemAccess({}, {}, {}, {}, {}, {}, {}, {}, {}, {})".format( + self.address_space, self.dtype, None if self.lid_strides is None else dict( sorted(self.lid_strides.items())), None if self.gid_strides is None else dict( sorted(self.gid_strides.items())), - self.direction, + self.read_write, self.variable, "None" if not self.variable_tags else str(self.variable_tags), self.count_granularity, - repr(self.kernel_name)) + repr(self.kernel_name), + self.tags) # }}} # {{{ Sync descriptor -class Sync(ImmutableRecord): +class SynchronizationKind(Enum): + BARRIER_GLOBAL = 0 + BARRIER_LOCAL = 1 + KERNEL_LAUNCH = 2 + + +@dataclass(frozen=True, eq=True) +class Sync: """A descriptor for a type of synchronization. - .. attribute:: kind + .. attribute:: sync_kind - A string describing the synchronization kind, e.g. ``"barrier_global"`` or - ``"barrier_local"`` or ``"kernel_launch"``. + A :class:`SynchronizationKind` or *None*. .. attribute:: kernel_name A :class:`str` representing the kernel name where the operation occurred. - """ - def __init__(self, kind=None, kernel_name=None): - super().__init__(kind=kind, kernel_name=kernel_name) + .. attribute:: tags + + A :class:`frozenset` of tags attached to the synchronization. 
+ """ + sync_kind: Optional[SynchronizationKind] = None + kernel_name: Optional[str] = None + tags: FrozenSet[Tag] = frozenset() def __repr__(self): # Record.__repr__ overridden for consistent ordering and conciseness - return f"Sync({self.kind}, {self.kernel_name})" + return f"Sync({self.sync_kind}, {self.kernel_name}, {self.tags})" # }}} @@ -805,7 +895,7 @@ def __repr__(self): # {{{ CounterBase class CounterBase(CombineMapper): - def __init__(self, knl, callables_table, kernel_rec): + def __init__(self, knl: LoopKernel, callables_table, kernel_rec) -> None: self.knl = knl self.callables_table = callables_table self.kernel_rec = kernel_rec @@ -816,22 +906,29 @@ def __init__(self, knl, callables_table, kernel_rec): self.one = self.zero + 1 @cached_property - def param_space(self): + def param_space(self) -> isl.Space: return get_kernel_parameter_space(self.knl) - def new_poly_map(self, count_map): + def new_poly_map(self, count_map) -> ToCountPolynomialMap: return ToCountPolynomialMap(self.param_space, count_map) - def new_zero_poly_map(self): + def _new_zero_map(self) -> ToCountPolynomialMap: return self.new_poly_map({}) - def combine(self, values): - return sum(values) + def combine(self, values: Iterable[ToCountMap]) -> ToCountPolynomialMap: + return sum(values, self._new_zero_map()) + + def map_tagged_expression( + self, expr: TaggedExpression, tags: FrozenSet[Tag] + ) -> ToCountPolynomialMap: + return self.rec(expr.expr, expr.tags) - def map_constant(self, expr): - return self.new_zero_poly_map() + def map_constant( + self, expr: Expression, tags: FrozenSet[Tag] + ) -> ToCountPolynomialMap: + return self._new_zero_map() - def map_call(self, expr): + def map_call(self, expr: p.Call, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: from loopy.symbolic import ResolvedFunction assert isinstance(expr.function, ResolvedFunction) clbl = self.callables_table[expr.function.name] @@ -852,41 +949,42 @@ def map_call(self, expr): return subst_into_to_count_map( 
self.param_space, sub_result, subst_dict) \ - + self.rec(expr.parameters) + + self.rec(expr.parameters, tags) else: raise NotImplementedError() - def map_call_with_kwargs(self, expr): + def map_call_with_kwargs( + self, expr: p.CallWithKwargs, tags: FrozenSet[Tag] + ) -> ToCountPolynomialMap: # See https://github.com/inducer/loopy/pull/323 raise NotImplementedError - def map_sum(self, expr): - if expr.children: - return sum(self.rec(child) for child in expr.children) - else: - return self.new_zero_poly_map() - - map_product = map_sum - - def map_comparison(self, expr): - return self.rec(expr.left)+self.rec(expr.right) + def map_comparison( + self, expr: p.Comparison, tags: FrozenSet[Tag] + ) -> ToCountPolynomialMap: + return self.rec(expr.left, tags) + self.rec(expr.right, tags) - def map_if(self, expr): + def map_if( + self, expr: p.If, tags: FrozenSet[Tag] + ) -> ToCountPolynomialMap: warn_with_kernel(self.knl, "summing_if_branches", "%s counting sum of if-expression branches." % type(self).__name__) - return self.rec(expr.condition) + self.rec(expr.then) \ - + self.rec(expr.else_) + return self.rec(expr.condition, tags) + self.rec(expr.then, tags) \ + + self.rec(expr.else_, tags) - def map_if_positive(self, expr): + def map_if_positive( + self, expr: p.IfPositive, tags: FrozenSet[Tag]) -> ToCountMap: warn_with_kernel(self.knl, "summing_if_branches", "%s counting sum of if-expression branches." 
% type(self).__name__) - return self.rec(expr.criterion) + self.rec(expr.then) \ - + self.rec(expr.else_) + return self.rec(expr.criterion, tags) + self.rec(expr.then, tags) \ + + self.rec(expr.else_, tags) - def map_common_subexpression(self, expr): + def map_common_subexpression( + self, expr: p.CommonSubexpression, tags: FrozenSet[Tag] + ) -> ToCountPolynomialMap: raise RuntimeError("%s encountered %s--not supposed to happen" % (type(self).__name__, type(expr).__name__)) @@ -894,36 +992,40 @@ def map_common_subexpression(self, expr): map_derivative = map_common_subexpression map_slice = map_common_subexpression - def map_reduction(self, expr): + def map_reduction( + self, expr: Reduction, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: # preprocessing should have removed these raise RuntimeError("%s encountered %s--not supposed to happen" % (type(self).__name__, type(expr).__name__)) + def __call__( + self, expr, tags: Optional[FrozenSet[Tag]] = None + ) -> ToCountPolynomialMap: + if tags is None: + tags = frozenset() + return self.rec(expr, tags=tags) + # }}} # {{{ ExpressionOpCounter class ExpressionOpCounter(CounterBase): - def __init__(self, knl, callables_table, kernel_rec, - count_within_subscripts=True): - super().__init__( - knl, callables_table, kernel_rec) + def __init__(self, knl: LoopKernel, callables_table, kernel_rec, + count_within_subscripts: bool = True): + super().__init__(knl, callables_table, kernel_rec) self.count_within_subscripts = count_within_subscripts arithmetic_count_granularity = CountGranularity.SUBGROUP - def combine(self, values): - return sum(values) - - def map_constant(self, expr): - return self.new_zero_poly_map() + def map_constant(self, expr: Any, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: + return self._new_zero_map() map_tagged_variable = map_constant map_variable = map_constant map_nan = map_constant - def map_call(self, expr): + def map_call(self, expr: p.Call, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: from 
loopy.symbolic import ResolvedFunction assert isinstance(expr.function, ResolvedFunction) clbl = self.callables_table[expr.function.name] @@ -932,137 +1034,160 @@ def map_call(self, expr): if not isinstance(clbl, CallableKernel): return self.new_poly_map( {Op(dtype=self.type_inf(expr), - name="func:"+clbl.name, + op_type=OpType.SPECIAL_FUNC, + tags=tags, count_granularity=self.arithmetic_count_granularity, kernel_name=self.knl.name): self.one} - ) + self.rec(expr.parameters) + ) + self.rec(expr.parameters, tags) else: - return super().map_call(expr) + return super().map_call(expr, tags) - def map_subscript(self, expr): + def map_subscript( + self, expr: p.Subscript, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: if self.count_within_subscripts: - return self.rec(expr.index) + return self.rec(expr.index, tags) else: - return self.new_zero_poly_map() + return self._new_zero_map() - def map_sub_array_ref(self, expr): + def map_sub_array_ref( + self, expr: SubArrayRef, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: # generates an array view, considered free - return self.new_zero_poly_map() + return self._new_zero_map() - def map_sum(self, expr): + def map_sum(self, expr: p.Sum, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: assert expr.children return self.new_poly_map( {Op(dtype=self.type_inf(expr), - name="add", + op_type=OpType.ADD, + tags=tags, count_granularity=self.arithmetic_count_granularity, kernel_name=self.knl.name): self.zero + (len(expr.children)-1)} - ) + sum(self.rec(child) for child in expr.children) + ) + sum(self.rec(child, tags) for child in expr.children) - def map_product(self, expr): + def map_product( + self, expr: p.Product, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: from pymbolic.primitives import is_zero assert expr.children return sum(self.new_poly_map({Op(dtype=self.type_inf(expr), - name="mul", + op_type=OpType.MUL, + tags=tags, count_granularity=( self.arithmetic_count_granularity), kernel_name=self.knl.name): self.one}) - + 
self.rec(child) + + self.rec(child, tags) for child in expr.children if not is_zero(child + 1)) + \ self.new_poly_map({Op(dtype=self.type_inf(expr), - name="mul", + op_type=OpType.MUL, + tags=tags, count_granularity=( self.arithmetic_count_granularity), kernel_name=self.knl.name): -self.one}) - def map_quotient(self, expr, *args): + def map_quotient( + self, expr: p.QuotientBase, tags: FrozenSet[Tag] + ) -> ToCountPolynomialMap: return self.new_poly_map({Op(dtype=self.type_inf(expr), - name="div", + op_type=OpType.DIV, + tags=tags, count_granularity=self.arithmetic_count_granularity, kernel_name=self.knl.name): self.one}) \ - + self.rec(expr.numerator) \ - + self.rec(expr.denominator) + + self.rec(expr.numerator, tags) \ + + self.rec(expr.denominator, tags) map_floor_div = map_quotient map_remainder = map_quotient - def map_power(self, expr): + def map_power(self, expr: p.Power, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: return self.new_poly_map({Op(dtype=self.type_inf(expr), - name="pow", + op_type=OpType.POW, + tags=tags, count_granularity=self.arithmetic_count_granularity, kernel_name=self.knl.name): self.one}) \ - + self.rec(expr.base) \ - + self.rec(expr.exponent) + + self.rec(expr.base, tags) \ + + self.rec(expr.exponent, tags) - def map_left_shift(self, expr): + def map_left_shift( + self, expr: Union[p.LeftShift, p.RightShift], tags: FrozenSet[Tag] + ) -> ToCountPolynomialMap: return self.new_poly_map({Op(dtype=self.type_inf(expr), - name="shift", + op_type=OpType.SHIFT, + tags=tags, count_granularity=self.arithmetic_count_granularity, kernel_name=self.knl.name): self.one}) \ - + self.rec(expr.shiftee) \ - + self.rec(expr.shift) + + self.rec(expr.shiftee, tags) \ + + self.rec(expr.shift, tags) map_right_shift = map_left_shift - def map_bitwise_not(self, expr): + def map_bitwise_not( + self, expr: p.BitwiseNot, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: return self.new_poly_map({Op(dtype=self.type_inf(expr), - name="bw", + op_type=OpType.BITWISE, + 
tags=tags, count_granularity=self.arithmetic_count_granularity, kernel_name=self.knl.name): self.one}) \ - + self.rec(expr.child) + + self.rec(expr.child, tags) - def map_bitwise_or(self, expr): + def map_bitwise_or( + self, expr: Union[p.BitwiseOr, p.BitwiseAnd, p.BitwiseXor], + tags: FrozenSet[Tag]) -> ToCountPolynomialMap: return self.new_poly_map({Op(dtype=self.type_inf(expr), - name="bw", + op_type=OpType.BITWISE, + tags=tags, count_granularity=self.arithmetic_count_granularity, kernel_name=self.knl.name): self.zero + (len(expr.children)-1)}) \ - + sum(self.rec(child) for child in expr.children) + + sum(self.rec(child, tags) for child in expr.children) map_bitwise_xor = map_bitwise_or map_bitwise_and = map_bitwise_or - def map_if(self, expr): + def map_if(self, expr: p.If, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: warn_with_kernel(self.knl, "summing_if_branches_ops", "ExpressionOpCounter counting ops as sum of " "if-statement branches.") - return self.rec(expr.condition) + self.rec(expr.then) \ - + self.rec(expr.else_) + return self.rec(expr.condition, tags) + self.rec(expr.then, tags) \ + + self.rec(expr.else_, tags) - def map_if_positive(self, expr): + def map_if_positive( + self, expr: p.IfPositive, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: warn_with_kernel(self.knl, "summing_ifpos_branches_ops", "ExpressionOpCounter counting ops as sum of " "if_pos-statement branches.") - return self.rec(expr.criterion) + self.rec(expr.then) \ - + self.rec(expr.else_) + return self.rec(expr.criterion, tags) + self.rec(expr.then, tags) \ + + self.rec(expr.else_, tags) - def map_min(self, expr): + def map_min( + self, expr: Union[p. 
Min, p.Max], tags: FrozenSet[Tag] + ) -> ToCountPolynomialMap: return self.new_poly_map({Op(dtype=self.type_inf(expr), - name="maxmin", + op_type=OpType.MAXMIN, + tags=tags, count_granularity=self.arithmetic_count_granularity, kernel_name=self.knl.name): len(expr.children)-1}) \ - + sum(self.rec(child) for child in expr.children) + + sum(self.rec(child, tags) for child in expr.children) map_max = map_min - def map_common_subexpression(self, expr): + def map_common_subexpression(self, expr, tags): raise NotImplementedError("ExpressionOpCounter encountered " "common_subexpression, " "map_common_subexpression not implemented.") - def map_substitution(self, expr): + def map_substitution(self, expr, tags): raise NotImplementedError("ExpressionOpCounter encountered " "substitution, " "map_substitution not implemented.") - def map_derivative(self, expr): + def map_derivative(self, expr, tags): raise NotImplementedError("ExpressionOpCounter encountered " "derivative, " "map_derivative not implemented.") - def map_slice(self, expr): + def map_slice(self, expr, tags): raise NotImplementedError("ExpressionOpCounter encountered slice, " "map_slice not implemented.") @@ -1084,7 +1209,9 @@ def map_floor_div(self, expr): # {{{ _get_lid_and_gid_strides -def _get_lid_and_gid_strides(knl, array, index): +def _get_lid_and_gid_strides( + knl: LoopKernel, array: ArrayBase, index: Tuple[Expression, ...] 
+ ) -> Tuple[Mapping[int, Expression], Mapping[int, Expression]]: # find all local and global index tags and corresponding inames from loopy.symbolic import get_dependencies my_inames = get_dependencies(index) & knl.all_inames() @@ -1116,7 +1243,9 @@ def _get_lid_and_gid_strides(knl, array, index): from loopy.symbolic import simplify_using_aff from loopy.diagnostic import ExpressionNotAffineError - def get_iname_strides(tag_to_iname_dict): + def get_iname_strides( + tag_to_iname_dict: Mapping[InameImplementationTag, str] + ) -> Mapping[InameImplementationTag, Expression]: tag_to_stride_dict = {} if array.dim_tags is None: @@ -1173,178 +1302,102 @@ def get_iname_strides(tag_to_iname_dict): # {{{ MemAccessCounterBase -class MemAccessCounterBase(CounterBase): - def map_sub_array_ref(self, expr): +class MemAccessCounter(CounterBase): + def map_sub_array_ref( + self, expr: SubArrayRef, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: # generates an array view, considered free - return self.new_zero_poly_map() + return self._new_zero_map() - def map_call(self, expr): + def map_call(self, expr: p.Call, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: from loopy.symbolic import ResolvedFunction assert isinstance(expr.function, ResolvedFunction) clbl = self.callables_table[expr.function.name] from loopy.kernel.function_interface import CallableKernel if not isinstance(clbl, CallableKernel): - return self.rec(expr.parameters) + return self.rec(expr.parameters, tags) else: - return super().map_call(expr) - -# }}} + return super().map_call(expr, tags) + # local_mem_count_granularity = CountGranularity.SUBGROUP -# {{{ LocalMemAccessCounter + def count_var_access(self, + dtype: LoopyType, + name: str, + index: Optional[Tuple[Expression, ...]], + tags: FrozenSet[Tag] + ) -> ToCountPolynomialMap: + count_map = {} -class LocalMemAccessCounter(MemAccessCounterBase): - local_mem_count_granularity = CountGranularity.SUBGROUP + array = self.knl.get_var_descriptor(name) - def 
count_var_access(self, dtype, name, index): - count_map = {} - if name in self.knl.temporary_variables: - array = self.knl.temporary_variables[name] - if isinstance(array, TemporaryVariable) and ( - array.address_space == AddressSpace.LOCAL): - if index is None: - # no subscript - count_map[MemAccess( - mtype="local", - dtype=dtype, - count_granularity=self.local_mem_count_granularity, - kernel_name=self.knl.name)] = self.one - return self.new_poly_map(count_map) - - array = self.knl.temporary_variables[name] - - # could be tuple or scalar index - index_tuple = index - if not isinstance(index_tuple, tuple): - index_tuple = (index_tuple,) - - lid_strides, gid_strides = _get_lid_and_gid_strides( - self.knl, array, index_tuple) - - count_map[MemAccess( - mtype="local", + if index is None: + # no subscript + count_map[MemAccess( + address_space=AddressSpace.LOCAL, + tags=tags, dtype=dtype, - lid_strides=dict(sorted(lid_strides.items())), - gid_strides=dict(sorted(gid_strides.items())), - variable=name, count_granularity=self.local_mem_count_granularity, kernel_name=self.knl.name)] = self.one + return self.new_poly_map(count_map) + + # could be tuple or scalar index + index_tuple = index + if not isinstance(index_tuple, tuple): + index_tuple = (index_tuple,) + + lid_strides, gid_strides = _get_lid_and_gid_strides( + self.knl, array, index_tuple) + + count_map[MemAccess( + address_space=array.address_space, + dtype=dtype, + tags=tags, + lid_strides=lid_strides, + gid_strides=gid_strides, + variable=name, + count_granularity=self.local_mem_count_granularity, + kernel_name=self.knl.name)] = self.one return self.new_poly_map(count_map) - def map_variable(self, expr): + def map_variable( + self, expr: p.Variable, tags: FrozenSet[Tag] + ) -> ToCountPolynomialMap: return self.count_var_access( - self.type_inf(expr), expr.name, None) + self.type_inf(expr), expr.name, None, tags) map_tagged_variable = map_variable - def map_subscript(self, expr): + def map_subscript( + self, 
expr: p.Subscript, tags: FrozenSet[Tag]) -> ToCountPolynomialMap: return (self.count_var_access(self.type_inf(expr), expr.aggregate.name, - expr.index) - + self.rec(expr.index)) + expr.index, tags) + + self.rec(expr.index, tags)) # }}} -# {{{ GlobalMemAccessCounter - -class GlobalMemAccessCounter(MemAccessCounterBase): - def map_variable(self, expr): - name = expr.name - - if name in self.knl.arg_dict: - array = self.knl.arg_dict[name] - else: - # this is a temporary variable - # FIXME temporary variable could have global address space - return self.new_zero_poly_map() - - if not isinstance(array, lp.ArrayArg): - # this array is not in global memory - return self.new_zero_poly_map() - - return self.new_poly_map({MemAccess(mtype="global", - dtype=self.type_inf(expr), lid_strides={}, - gid_strides={}, variable=name, - count_granularity=CountGranularity.WORKITEM, - kernel_name=self.knl.name): self.one} - ) + self.rec(expr.index) - - def map_subscript(self, expr): - name = expr.aggregate.name - try: - var_tags = expr.aggregate.tags - except AttributeError: - var_tags = frozenset() - - is_global_temp = False - if name in self.knl.arg_dict: - array = self.knl.arg_dict[name] - elif name in self.knl.temporary_variables: - # This a temporary, but might have global address space - from loopy.kernel.data import AddressSpace - array = self.knl.temporary_variables[name] - if array.address_space != AddressSpace.GLOBAL: - # This temporary does not have global address space - return self.rec(expr.index) - # This temporary has global address space - is_global_temp = True - else: - # This temporary does not have global address space - return self.rec(expr.index) - - if (not is_global_temp) and not isinstance(array, lp.ArrayArg): - # This array is not in global memory - return self.rec(expr.index) - - index_tuple = expr.index # could be tuple or scalar index - if not isinstance(index_tuple, tuple): - index_tuple = (index_tuple,) - - lid_strides, gid_strides = 
_get_lid_and_gid_strides( - self.knl, array, index_tuple) - - global_access_count_granularity = CountGranularity.SUBGROUP - - # Account for broadcasts once per subgroup - count_granularity = CountGranularity.WORKITEM if ( - # if the stride in lid.0 is known - 0 in lid_strides - and - # it is nonzero - lid_strides[0] != 0 - ) else global_access_count_granularity - - return self.new_poly_map({MemAccess( - mtype="global", - dtype=self.type_inf(expr), - lid_strides=dict(sorted(lid_strides.items())), - gid_strides=dict(sorted(gid_strides.items())), - variable=name, - variable_tags=var_tags, - count_granularity=count_granularity, - kernel_name=self.knl.name, - ): self.one} - ) + self.rec(expr.index_tuple) - -# }}} +# {{{ AccessFootprintGatherer +FootprintsT = Dict[str, isl.Set] -# {{{ AccessFootprintGatherer class AccessFootprintGatherer(CombineMapper): - def __init__(self, kernel, domain, ignore_uncountable=False): + def __init__(self, + kernel: LoopKernel, + domain: isl.BasicSet, + ignore_uncountable: bool = False) -> None: self.kernel = kernel self.domain = domain self.ignore_uncountable = ignore_uncountable @staticmethod - def combine(values): + def combine(values: Iterable[FootprintsT]) -> FootprintsT: assert values - def merge_dicts(a, b): + def merge_dicts(a: FootprintsT, b: FootprintsT) -> FootprintsT: result = a.copy() for var_name, footprint in b.items(): @@ -1358,13 +1411,13 @@ def merge_dicts(a, b): from functools import reduce return reduce(merge_dicts, values) - def map_constant(self, expr): + def map_constant(self, expr: p.Any) -> FootprintsT: return {} - def map_variable(self, expr): + def map_variable(self, expr: p.Variable) -> FootprintsT: return {} - def map_subscript(self, expr): + def map_subscript(self, expr: p.Subscript) -> FootprintsT: subscript = expr.index if not isinstance(subscript, tuple): @@ -1401,13 +1454,15 @@ def map_subscript(self, expr): # {{{ count -def add_assumptions_guard(kernel, pwqpolynomial): +def add_assumptions_guard( + kernel: 
LoopKernel, pwqpolynomial: isl.PwQPolynomial + ) -> GuardedPwQPolynomial: return GuardedPwQPolynomial( pwqpolynomial, kernel.assumptions.align_params(pwqpolynomial.space)) -def count(kernel, set, space=None): +def count(kernel, set: isl.Set, space=None) -> GuardedPwQPolynomial: if isinstance(kernel, TranslationUnit): kernel_names = [i for i, clbl in kernel.callables_table.items() if isinstance(clbl, CallableKernel)] @@ -1513,7 +1568,9 @@ def count(kernel, set, space=None): return add_assumptions_guard(kernel, total_count) -def get_unused_hw_axes_factor(knl, callables_table, insn, disregard_local_axes): +def get_unused_hw_axes_factor( + knl: LoopKernel, callables_table, insn: InstructionBase, + disregard_local_axes: bool) -> GuardedPwQPolynomial: # FIXME: Multi-kernel support gsize, lsize = knl.get_grid_size_upper_bounds(callables_table) @@ -1552,7 +1609,8 @@ def mult_grid_factor(used_axes, sizes): return add_assumptions_guard(knl, result) -def count_inames_domain(knl, inames): +def count_inames_domain( + knl: LoopKernel, inames: FrozenSet[str]) -> GuardedPwQPolynomial: space = get_kernel_parameter_space(knl) if not inames: return add_assumptions_guard(knl, @@ -1563,8 +1621,12 @@ def count_inames_domain(knl, inames): return count(knl, domain, space=space) -def count_insn_runs(knl, callables_table, insn, count_redundant_work, - disregard_local_axes=False): +def count_insn_runs( + knl: LoopKernel, + callables_table: Mapping[str, InKernelCallable], + insn: InstructionBase, + count_redundant_work: bool, + disregard_local_axes: bool = False) -> GuardedPwQPolynomial: insn_inames = insn.within_inames @@ -1584,8 +1646,14 @@ def count_insn_runs(knl, callables_table, insn, count_redundant_work, return c -def _get_insn_count(knl, callables_table, insn_id, subgroup_size, - count_redundant_work, count_granularity=CountGranularity.WORKITEM): +def _get_insn_count( + knl: LoopKernel, + callables_table: Mapping[str, InKernelCallable], + insn_id: str, + subgroup_size: Optional[int], 
+ count_redundant_work: bool, + count_granularity: CountGranularity = CountGranularity.WORKITEM + ) -> GuardedPwQPolynomial: insn = knl.id_to_insn[insn_id] if count_granularity is None: @@ -1645,18 +1713,20 @@ def _get_insn_count(knl, callables_table, insn_id, subgroup_size, else: # this should not happen since this is enforced in Op/MemAccess - raise ValueError("get_insn_count: count_granularity '%s' is" - "not allowed. count_granularity options: %s" - % (count_granularity, CountGranularity.ALL+[None])) + raise ValueError("get_insn_count: count_granularity " + f"'{count_granularity}' is not allowed.") # }}} # {{{ get_op_map -def _get_op_map_for_single_kernel(knl, callables_table, - count_redundant_work, - count_within_subscripts, subgroup_size, within): +def _get_op_map_for_single_kernel( + knl: LoopKernel, + callables_table: Mapping[str, InKernelCallable], + count_redundant_work: bool, + count_within_subscripts: bool, + subgroup_size: int, within) -> ToCountPolynomialMap: subgroup_size = _process_subgroup_size(knl, subgroup_size) @@ -1668,7 +1738,7 @@ def _get_op_map_for_single_kernel(knl, callables_table, op_counter = ExpressionOpCounter(knl, callables_table, kernel_rec, count_within_subscripts) - op_map = op_counter.new_zero_poly_map() + op_map = op_counter._new_zero_map() from loopy.kernel.instruction import ( CallInstruction, CInstruction, Assignment, @@ -1693,9 +1763,12 @@ def _get_op_map_for_single_kernel(knl, callables_table, return op_map -def get_op_map(program, count_redundant_work=False, - count_within_subscripts=True, subgroup_size=None, - entrypoint=None, within=None): +def get_op_map( + t_unit: TranslationUnit, *, count_redundant_work: bool = False, + count_within_subscripts: bool = True, + subgroup_size: Optional[int] = None, + entrypoint: Optional[str] = None, + within: Any = None): """Count the number of operations in a loopy kernel. 
@@ -1713,7 +1786,7 @@ def get_op_map(program, count_redundant_work=False, :arg subgroup_size: (currently unused) An :class:`int`, :class:`str` ``"guess"``, or *None* that specifies the sub-group size. An OpenCL sub-group is an implementation-dependent grouping of work-items within - a work-group, analagous to an NVIDIA CUDA warp. subgroup_size is used, + a work-group, analogous to an NVIDIA CUDA warp. subgroup_size is used, e.g., when counting a :class:`MemAccess` whose count_granularity specifies that it should only be counted once per sub-group. If set to *None* an attempt to find the sub-group size using the device will be @@ -1755,25 +1828,28 @@ def get_op_map(program, count_redundant_work=False, """ if entrypoint is None: - if len(program.entrypoints) > 1: + if len(t_unit.entrypoints) > 1: raise LoopyError("Must provide entrypoint") - entrypoint = list(program.entrypoints)[0] + entrypoint = list(t_unit.entrypoints)[0] - assert entrypoint in program.entrypoints + assert entrypoint in t_unit.entrypoints from loopy.preprocess import preprocess_program, infer_unknown_types - program = preprocess_program(program) + t_unit = preprocess_program(t_unit) from loopy.match import parse_match within = parse_match(within) # Ordering restriction: preprocess might insert arguments to # make strides valid. Those also need to go through type inference. 
- program = infer_unknown_types(program, expect_completion=True) + t_unit = infer_unknown_types(t_unit, expect_completion=True) + + kernel = t_unit[entrypoint] + assert isinstance(kernel, LoopKernel) return _get_op_map_for_single_kernel( - program[entrypoint], program.callables_table, + kernel, t_unit.callables_table, count_redundant_work=count_redundant_work, count_within_subscripts=count_within_subscripts, subgroup_size=subgroup_size, @@ -1782,7 +1858,7 @@ def get_op_map(program, count_redundant_work=False, # }}} -# {{{ subgoup size finding +# {{{ subgroup size finding def _find_subgroup_size_for_knl(knl): from loopy.target.pyopencl import PyOpenCLTarget @@ -1840,8 +1916,11 @@ def _process_subgroup_size(knl, subgroup_size_requested): # {{{ get_mem_access_map -def _get_mem_access_map_for_single_kernel(knl, callables_table, - count_redundant_work, subgroup_size, within): +def _get_mem_access_map_for_single_kernel( + knl: LoopKernel, + callables_table: Mapping[str, InKernelCallable], + count_redundant_work: bool, subgroup_size: Optional[int], + within: Any) -> ToCountPolynomialMap: subgroup_size = _process_subgroup_size(knl, subgroup_size) @@ -1850,11 +1929,8 @@ def _get_mem_access_map_for_single_kernel(knl, callables_table, count_redundant_work=count_redundant_work, subgroup_size=subgroup_size) - access_counter_g = GlobalMemAccessCounter( - knl, callables_table, kernel_rec) - access_counter_l = LocalMemAccessCounter( - knl, callables_table, kernel_rec) - access_map = access_counter_g.new_zero_poly_map() + access_counter = MemAccessCounter(knl, callables_table, kernel_rec) + access_map = access_counter._new_zero_map() from loopy.kernel.instruction import ( CallInstruction, CInstruction, Assignment, @@ -1864,14 +1940,12 @@ def _get_mem_access_map_for_single_kernel(knl, callables_table, if within(knl, insn): if isinstance(insn, (CallInstruction, CInstruction, Assignment)): insn_access_map = ( - access_counter_g(insn.expression) - + access_counter_l(insn.expression) - 
).with_set_attributes(direction="load") + access_counter(insn.expression) + ).with_set_attributes(read_write=AccessDirection.READ) for assignee in insn.assignees: insn_access_map = insn_access_map + ( - access_counter_g(assignee) - + access_counter_l(assignee) - ).with_set_attributes(direction="store") + access_counter(assignee) + ).with_set_attributes(read_write=AccessDirection.WRITE) for key, val in insn_access_map.count_map.items(): count = _get_insn_count(knl, callables_table, insn.id, @@ -1889,9 +1963,11 @@ def _get_mem_access_map_for_single_kernel(knl, callables_table, return access_map -def get_mem_access_map(program, count_redundant_work=False, - subgroup_size=None, entrypoint=None, - within=None): +def get_mem_access_map( + t_unit: TranslationUnit, *, count_redundant_work: bool = False, + subgroup_size: Optional[int] = None, + entrypoint: Optional[str] = None, + within: Any = None) -> ToCountPolynomialMap: """Count the number of memory accesses in a loopy kernel. :arg knl: A :class:`loopy.LoopKernel` whose memory accesses are to be @@ -1906,7 +1982,7 @@ def get_mem_access_map(program, count_redundant_work=False, :arg subgroup_size: An :class:`int`, :class:`str` ``"guess"``, or *None* that specifies the sub-group size. An OpenCL sub-group is an implementation-dependent grouping of work-items within a work-group, - analagous to an NVIDIA CUDA warp. subgroup_size is used, e.g., when + analogous to an NVIDIA CUDA warp. subgroup_size is used, e.g., when counting a :class:`MemAccess` whose count_granularity specifies that it should only be counted once per sub-group. 
If set to *None* an attempt to find the sub-group size using the device will be made, if this fails @@ -1977,26 +2053,26 @@ def get_mem_access_map(program, count_redundant_work=False, """ if entrypoint is None: - if len(program.entrypoints) > 1: + if len(t_unit.entrypoints) > 1: raise LoopyError("Must provide entrypoint") - entrypoint = list(program.entrypoints)[0] + entrypoint = list(t_unit.entrypoints)[0] - assert entrypoint in program.entrypoints + assert entrypoint in t_unit.entrypoints from loopy.preprocess import preprocess_program, infer_unknown_types - program = preprocess_program(program) + t_unit = preprocess_program(t_unit) from loopy.match import parse_match within = parse_match(within) # Ordering restriction: preprocess might insert arguments to # make strides valid. Those also need to go through type inference. - program = infer_unknown_types(program, expect_completion=True) + t_unit = infer_unknown_types(t_unit, expect_completion=True) return _get_mem_access_map_for_single_kernel( - program[entrypoint], program.callables_table, + t_unit[entrypoint], t_unit.callables_table, count_redundant_work=count_redundant_work, subgroup_size=subgroup_size, within=within) @@ -2006,8 +2082,10 @@ def get_mem_access_map(program, count_redundant_work=False, # {{{ get_synchronization_map -def _get_synchronization_map_for_single_kernel(knl, callables_table, - subgroup_size=None): +def _get_synchronization_map_for_single_kernel( + knl: LoopKernel, + callables_table: Mapping[str, InKernelCallable], + subgroup_size: Optional[int] = None): knl = lp.get_one_linearized_kernel(knl, callables_table) @@ -2019,7 +2097,7 @@ def _get_synchronization_map_for_single_kernel(knl, callables_table, subgroup_size=subgroup_size) sync_counter = CounterBase(knl, callables_table, kernel_rec) - sync_map = sync_counter.new_zero_poly_map() + sync_map = sync_counter._new_zero_map() iname_list = [] @@ -2042,7 +2120,7 @@ def _get_synchronization_map_for_single_kernel(knl, callables_table, elif 
isinstance(sched_item, CallKernel): sync_map = sync_map + ToCountMap( - {Sync("kernel_launch", knl.name): + {Sync(SynchronizationKind.KERNEL_LAUNCH, knl.name): count_inames_domain(knl, frozenset(iname_list))}) elif isinstance(sched_item, ReturnFromKernel): @@ -2055,7 +2133,10 @@ def _get_synchronization_map_for_single_kernel(knl, callables_table, return sync_map -def get_synchronization_map(program, subgroup_size=None, entrypoint=None): +def get_synchronization_map( + t_unit: TranslationUnit, *, + subgroup_size: Optional[int] = None, + entrypoint: Optional[str] = None) -> ToCountPolynomialMap: """Count the number of synchronization events each work-item encounters in a loopy kernel. @@ -2064,7 +2145,7 @@ def get_synchronization_map(program, subgroup_size=None, entrypoint=None): :arg subgroup_size: (currently unused) An :class:`int`, :class:`str` ``"guess"``, or *None* that specifies the sub-group size. An OpenCL sub-group is an implementation-dependent grouping of work-items within - a work-group, analagous to an NVIDIA CUDA warp. subgroup_size is used, + a work-group, analogous to an NVIDIA CUDA warp. subgroup_size is used, e.g., when counting a :class:`MemAccess` whose count_granularity specifies that it should only be counted once per sub-group. If set to *None* an attempt to find the sub-group size using the device will be @@ -2092,21 +2173,21 @@ def get_synchronization_map(program, subgroup_size=None, entrypoint=None): """ if entrypoint is None: - if len(program.entrypoints) > 1: + if len(t_unit.entrypoints) > 1: raise LoopyError("Must provide entrypoint") - entrypoint = list(program.entrypoints)[0] + entrypoint = list(t_unit.entrypoints)[0] - assert entrypoint in program.entrypoints + assert entrypoint in t_unit.entrypoints from loopy.preprocess import preprocess_program, infer_unknown_types - program = preprocess_program(program) + t_unit = preprocess_program(t_unit) # Ordering restriction: preprocess might insert arguments to # make strides valid. 
Those also need to go through type inference. - program = infer_unknown_types(program, expect_completion=True) + t_unit = infer_unknown_types(t_unit, expect_completion=True) return _get_synchronization_map_for_single_kernel( - program[entrypoint], program.callables_table, + t_unit[entrypoint], t_unit.callables_table, subgroup_size=subgroup_size) # }}} @@ -2114,7 +2195,9 @@ def get_synchronization_map(program, subgroup_size=None, entrypoint=None): # {{{ gather_access_footprints -def _gather_access_footprints_for_single_kernel(kernel, ignore_uncountable): +def _gather_access_footprints_for_single_kernel( + kernel: LoopKernel, ignore_uncountable: bool + ) -> Tuple[FootprintsT, FootprintsT]: write_footprints = [] read_footprints = [] @@ -2136,10 +2219,14 @@ def _gather_access_footprints_for_single_kernel(kernel, ignore_uncountable): write_footprints.append(afg(insn.assignees)) read_footprints.append(afg(insn.expression)) - return write_footprints, read_footprints + return ( + AccessFootprintGatherer.combine(write_footprints), + AccessFootprintGatherer.combine(read_footprints)) -def gather_access_footprints(program, ignore_uncountable=False, entrypoint=None): +def gather_access_footprints( + t_unit: TranslationUnit, *, ignore_uncountable: bool = False, + entrypoint: Optional[str] = None) -> Mapping[MemAccess, isl.Set]: """Return a dictionary mapping ``(var_name, direction)`` to :class:`islpy.Set` instances capturing which indices of each the array *var_name* are read/written (where *direction* is either ``read`` or @@ -2151,48 +2238,48 @@ def gather_access_footprints(program, ignore_uncountable=False, entrypoint=None) """ if entrypoint is None: - if len(program.entrypoints) > 1: + if len(t_unit.entrypoints) > 1: raise LoopyError("Must provide entrypoint") - entrypoint = list(program.entrypoints)[0] + entrypoint = list(t_unit.entrypoints)[0] - assert entrypoint in program.entrypoints + assert entrypoint in t_unit.entrypoints - # FIMXE: works only for one callable kernel 
till now. + # FIXME: works only for one callable kernel till now. if len([in_knl_callable for in_knl_callable in - program.callables_table.values() if isinstance(in_knl_callable, + t_unit.callables_table.values() if isinstance(in_knl_callable, CallableKernel)]) != 1: - raise NotImplementedError("Currently only supported for program with " - "only one CallableKernel.") + raise NotImplementedError("Currently only supported for " + "translation unit with only one CallableKernel.") from loopy.preprocess import preprocess_program, infer_unknown_types - program = preprocess_program(program) + t_unit = preprocess_program(t_unit) # Ordering restriction: preprocess might insert arguments to # make strides valid. Those also need to go through type inference. - program = infer_unknown_types(program, expect_completion=True) - - write_footprints = [] - read_footprints = [] + t_unit = infer_unknown_types(t_unit, expect_completion=True) + kernel = t_unit[entrypoint] + assert isinstance(kernel, LoopKernel) write_footprints, read_footprints = _gather_access_footprints_for_single_kernel( - program[entrypoint], ignore_uncountable) - - write_footprints = AccessFootprintGatherer.combine(write_footprints) - read_footprints = AccessFootprintGatherer.combine(read_footprints) + kernel, ignore_uncountable) result = {} for vname, footprint in write_footprints.items(): - result[(vname, "write")] = footprint + result[MemAccess(variable=vname, read_write=AccessDirection.WRITE)] \ + = footprint for vname, footprint in read_footprints.items(): - result[(vname, "read")] = footprint + result[MemAccess(variable=vname, read_write=AccessDirection.READ)] \ + = footprint return result -def gather_access_footprint_bytes(program, ignore_uncountable=False): +def gather_access_footprint_bytes( + t_unit: TranslationUnit, *, ignore_uncountable: bool = False + ) -> ToCountPolynomialMap: """Return a dictionary mapping ``(var_name, direction)`` to :class:`islpy.PwQPolynomial` instances capturing the number of 
bytes are read/written (where *direction* is either ``read`` or ``write`` on array @@ -2203,30 +2290,25 @@ def gather_access_footprint_bytes(program, ignore_uncountable=False): nonlinear indices) """ - from loopy.preprocess import preprocess_program, infer_unknown_types - kernel = infer_unknown_types(program, expect_completion=True) + t_unit = lp.infer_unknown_types(t_unit, expect_completion=True) + t_unit = lp.preprocess_program(t_unit) - from loopy.kernel import KernelState - if kernel.state < KernelState.PREPROCESSED: - kernel = preprocess_program(program) + fp = gather_access_footprints(t_unit, ignore_uncountable=ignore_uncountable) - result = {} - fp = gather_access_footprints(kernel, - ignore_uncountable=ignore_uncountable) + # FIXME: Only supporting a single kernel for now + kernel = t_unit.default_entrypoint - for key, var_fp in fp.items(): - vname, direction = key - - var_descr = kernel.get_var_descriptor(vname) + result = {} + for ma, var_fp in fp.items(): + assert ma.variable + var_descr = kernel.get_var_descriptor(ma.variable) + assert var_descr.dtype bytes_transferred = ( int(var_descr.dtype.numpy_dtype.itemsize) * count(kernel, var_fp)) - if key in result: - result[key] += bytes_transferred - else: - result[key] = bytes_transferred + result[ma] = add_assumptions_guard(kernel, bytes_transferred) - return result + return ToCountPolynomialMap(get_kernel_parameter_space(kernel), result) # }}} diff --git a/loopy/symbolic.py b/loopy/symbolic.py index b6bd1d009..d61b4c980 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -89,6 +89,8 @@ .. autoclass:: TaggedVariable +.. autoclass:: TaggedExpression + .. autoclass:: Reduction .. 
autoclass:: LinearSubscript @@ -114,6 +116,10 @@ # {{{ mappers with support for loopy-specific primitives class IdentityMapperMixin: + def map_tagged_expression(self, expr, *args, **kwargs): + new_expr = self.rec(expr.expr, *args, **kwargs) + return TaggedExpression(expr.tags, new_expr) + def map_literal(self, expr, *args, **kwargs): return expr @@ -207,6 +213,12 @@ def map_common_subexpression_uncached(self, expr): class WalkMapperMixin: + def map_tagged_expression(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): + return + + self.rec(expr.expr, *args, **kwargs) + def map_literal(self, expr, *args, **kwargs): self.visit(expr, *args, **kwargs) @@ -273,6 +285,9 @@ class CallbackMapper(IdentityMapperMixin, CallbackMapperBase): class CombineMapper(CombineMapperBase): + def map_tagged_expression(self, expr, *args, **kwargs): + return self.rec(expr.expr, *args, **kwargs) + def map_reduction(self, expr, *args, **kwargs): return self.rec(expr.expr, *args, **kwargs) @@ -298,6 +313,10 @@ class ConstantFoldingMapper(ConstantFoldingMapperBase, class StringifyMapper(StringifyMapperBase): + def map_tagged_expression(self, expr, *args): + from pymbolic.mapper.stringifier import PREC_NONE + return f"TaggedExpression({expr.tags}, {self.rec(expr.expr, PREC_NONE)}" + def map_literal(self, expr, *args): return expr.s @@ -440,6 +459,10 @@ def map_tagged_variable(self, expr, *args, **kwargs): def map_loopy_function_identifier(self, expr, *args, **kwargs): return set() + def map_tagged_expression(self, expr, *args, **kwargs): + deps = self.rec(expr.expr, *args, **kwargs) + return deps + def map_sub_array_ref(self, expr, *args, **kwargs): deps = self.rec(expr.subscript, *args, **kwargs) return deps - set(expr.swept_inames) @@ -681,7 +704,7 @@ class TaggedVariable(LoopyExpressionBase, p.Variable, Taggable): A :class:`frozenset` of subclasses of :class:`pytools.tag.Tag` used to provide metadata on this object. 
Legacy string tags are converted to :class:`~loopy.LegacyStringInstructionTag` or, if they used to carry - a functional meaning, the tag carrying that same fucntional meaning + a functional meaning, the tag carrying that same functional meaning (e.g. :class:`~loopy.UseStreamingStoreTag`). Inherits from :class:`pymbolic.primitives.Variable` @@ -712,6 +735,40 @@ def copy(self, *, name=None, tags=None): mapper_method = intern("map_tagged_variable") +class TaggedExpression(LoopyExpressionBase): + """ + Represents a frozenset of tags attached to an :attr:`expr`. + + .. attribute:: tags + + A :class:`frozenset` of subclasses of :class:`pytools.tag.Tag` used to + provide metadata on this expression. + + .. attribute:: expr + + An expression to which :attr:`tags` are attached. + """ + + init_arg_names = ("tags", "expr") + + def __init__(self, tags, expr): + self.tags = tags + self.expr = expr + + def __getinitargs__(self): + return (self.tags, self.expr) + + def get_hash(self): + return hash((self.__class__, self.tags, self.expr)) + + def is_equal(self, other): + return (other.__class__ == self.__class__ + and other.tags == self.tags + and other.expr == self.expr) + + mapper_method = intern("map_tagged_expression") + + class Reduction(LoopyExpressionBase): """ Represents a reduction operation on :attr:`expr` across :attr:`inames`. 
diff --git a/loopy/translation_unit.py b/loopy/translation_unit.py index 0c009cdd0..7e01e7028 100644 --- a/loopy/translation_unit.py +++ b/loopy/translation_unit.py @@ -21,6 +21,7 @@ """ import collections +from typing import Callable, FrozenSet, Optional, Union from pytools import ImmutableRecord from pymbolic.primitives import Variable @@ -29,11 +30,12 @@ from loopy.symbolic import (RuleAwareIdentityMapper, ResolvedFunction, SubstitutionRuleMappingContext) from loopy.kernel.function_interface import ( - CallableKernel, ScalarCallable) + CallableKernel, InKernelCallable, ScalarCallable) from loopy.diagnostic import LoopyError from loopy.library.reduction import ReductionOpFunction from loopy.kernel import LoopKernel +from loopy.target import TargetBase from loopy.tools import update_persistent_hash from pymbolic.primitives import Call from pyrsistent import pmap, PMap @@ -176,6 +178,12 @@ class TranslationUnit(ImmutableRecord): :meth:`~TranslationUnit.copy`. """ + entrypoints: FrozenSet[str] + callables_table: PMap[str, InKernelCallable] + target: TargetBase + func_id_to_in_knl_callable_mappers: Optional[FrozenSet[ + Callable[[TargetBase, str], InKernelCallable]]] + def __init__(self, entrypoints=frozenset(), callables_table=None, @@ -288,7 +296,7 @@ def with_kernel(self, kernel): new_callables = self.callables_table.set(kernel.name, clbl) return self.copy(callables_table=new_callables) - def __getitem__(self, name): + def __getitem__(self, name) -> Union[InKernelCallable, LoopKernel]: """ For the callable named *name*, return a :class:`loopy.LoopKernel` if it's a :class:`~loopy.kernel.function_interface.CallableKernel` @@ -301,10 +309,12 @@ def __getitem__(self, name): return result @property - def default_entrypoint(self): + def default_entrypoint(self) -> LoopKernel: if len(self.entrypoints) == 1: entrypoint, = self.entrypoints - return self[entrypoint] + ep_kernel = self[entrypoint] + assert isinstance(ep_kernel, LoopKernel) + return ep_kernel else: raise 
ValueError("TranslationUnit has multiple possible entrypoints." " The default entrypoint kernel is not uniquely" @@ -353,7 +363,7 @@ def __call__(self, *args, **kwargs): return pex(*args, **kwargs) - def __str__(self): + def __str__(self) -> str: # FIXME: do a topological sort by the call graph def strify_callable(clbl): diff --git a/test/test_statistics.py b/test/test_statistics.py index 8b384e3e4..38f39d750 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -1531,6 +1531,78 @@ def test_no_loop_ops(): assert f64_mul == 1 +from pytools.tag import Tag + + +class MyCostTag1(Tag): + pass + + +class MyCostTag2(Tag): + pass + + +class MyCostTagSum(Tag): + pass + + +def test_op_taggedexpression(): + from loopy.symbolic import TaggedExpression + from pymbolic.primitives import Subscript, Variable, Sum + + n = 500 + + knl = lp.make_kernel( + "{[i]: 0<=i