diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index af9823856..892fa2c20 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -37,7 +37,7 @@ AdvancedIndexInNoncontiguousAxes, NormalizedSlice, ShapeType, AbstractResultWithNamedArrays) -from pytato.scalar_expr import ScalarExpression, INT_CLASSES, IntegralT +from pytato.scalar_expr import ScalarExpression, INT_CLASSES from pytato.diagnostic import CannotBeLoweredToIndexLambda from pytato.tags import AssumeNonNegative from pytato.transform import Mapper @@ -51,53 +51,58 @@ def _get_reshaped_indices(expr: Reshape) -> Tuple[ScalarExpression, ...]: assert expr.size == 1 return () - if expr.order not in ["C", "F"]: + if expr.order.upper() not in ["C", "F"]: raise NotImplementedError("Order expected to be 'C' or 'F'", - f" found {expr.order}") - - if expr.order == "C": - newstrides: List[IntegralT] = [1] # reshaped array strides - for new_axis_len in reversed(expr.shape[1:]): - assert isinstance(new_axis_len, INT_CLASSES) - newstrides.insert(0, newstrides[0]*new_axis_len) - - flattened_idx = sum(prim.Variable(f"_{i}")*stride - for i, stride in enumerate(newstrides)) - - oldstrides: List[IntegralT] = [1] # input array strides - for axis_len in reversed(expr.array.shape[1:]): - assert isinstance(axis_len, INT_CLASSES) - oldstrides.insert(0, oldstrides[0]*axis_len) - - assert isinstance(expr.array.shape[-1], INT_CLASSES) - oldsizetills = [expr.array.shape[-1]] # input array size - # till for axes idx - for old_axis_len in reversed(expr.array.shape[:-1]): - assert isinstance(old_axis_len, INT_CLASSES) - oldsizetills.insert(0, oldsizetills[0]*old_axis_len) - - else: - newstrides: List[IntegralT] = [1] # reshaped array strides - for new_axis_len in expr.shape[:-1]: - assert isinstance(new_axis_len, INT_CLASSES) - newstrides.append(newstrides[-1]*new_axis_len) - - flattened_idx = sum(prim.Variable(f"_{i}")*stride - for i, stride in enumerate(newstrides)) - - oldstrides: List[IntegralT] = [1] # input array strides - for axis_len in expr.array.shape[:-1]: - assert isinstance(axis_len, INT_CLASSES) - oldstrides.append(oldstrides[-1]*axis_len) - - assert isinstance(expr.array.shape[0], INT_CLASSES) - oldsizetills = [expr.array.shape[0]] # input array size till for axes idx - for old_axis_len in expr.array.shape[1:]: - assert isinstance(old_axis_len, INT_CLASSES) - oldsizetills.append(oldsizetills[-1]*old_axis_len) - - return tuple(((flattened_idx % sizetill) // stride) - for stride, sizetill in zip(oldstrides, oldsizetills)) + f"(case insensitive) found {expr.order}") + + order = expr.order + oldshape = expr.array.shape + newshape = expr.shape + + # {{{ compute strides + + oldstrides = [1] + oldstride_axes = (reversed(oldshape[1:]) if order == "C" else oldshape[:-1]) + + for ax_len in oldstride_axes: + assert isinstance(ax_len, INT_CLASSES) + oldstrides.append(oldstrides[-1]*ax_len) + + newstrides = [1] + newstride_axes = (reversed(newshape[1:]) if order == "C" else newshape[:-1]) + + for ax_len in newstride_axes: + assert isinstance(ax_len, INT_CLASSES) + newstrides.append(newstrides[-1]*ax_len) + + # }}} + + # {{{ compute size tills + + oldsizetills = [oldshape[-1] if order == "C" else oldshape[0]] + oldsizetill_ax = (oldshape[:-1][::-1] if order == "C" else oldshape[:-1]) + for ax_len in oldsizetill_ax: + oldsizetills.append(oldsizetills[-1]*ax_len) + + # }}} + + # {{{ if order is C, then computed info is backwards + + if order == "C": + oldstrides = oldstrides[::-1] + newstrides = newstrides[::-1] + oldsizetills = oldsizetills[::-1] + + # }}} + + flattened_idx = sum(prim.Variable(f"_{i}")*stride + for i, stride in enumerate(newstrides)) + + ret = tuple( + (flattened_idx % sizetill) // stride + for stride, sizetill in zip(oldstrides, oldsizetills)) + + return ret class ToIndexLambdaMixin: diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 780d96fd4..638c4d08b 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -35,7 +35,7 @@ """ -from typing import (TYPE_CHECKING, Type, Set, Tuple, List, Dict, FrozenSet, +from typing import (TYPE_CHECKING, Type, Set, Tuple, List, Dict, Mapping, Iterable, Any, TypeVar, cast) from bidict import bidict from pytato.scalar_expr import SCALAR_CLASSES @@ -58,7 +58,7 @@ from pytato.diagnostic import UnknownIndexLambdaExpr from pytools import UniqueNameGenerator -from pytools.tag import Tag +from pytools.tag import Tag, UniqueTag import logging logger = logging.getLogger(__name__) @@ -70,6 +70,22 @@ GraphNodeT = TypeVar("GraphNodeT") +# {{{ IgnoredForPropagationTag + +class AxisIgnoredForPropagationTag(UniqueTag): + """ + Used to influence equality constraints when determining which axes tags + are allowed to propagate along. + + The intended use case for this is to prevent the axes of a matrix used to, + for example, differentiate a tensor of DOF data from picking up on the + unique tags attached to the axes of the tensor. + """ + pass + +# }}} + + # {{{ AxesTagsEquationCollector class AxesTagsEquationCollector(Mapper): @@ -167,6 +183,8 @@ def record_equations_from_axes_tags(self, ary: Array) -> None: Records equations for *ary*\'s axis tags of type :attr:`tag_t`. """ for iaxis, axis in enumerate(ary.axes): + if axis.tags_of_type(AxisIgnoredForPropagationTag): + continue lhs_var = self.get_var_for_axis(ary, iaxis) for tag in axis.tags_of_type(self.tag_t): rhs_var = self.get_var_for_tag(tag) @@ -492,11 +510,12 @@ def map_einsum(self, expr: Einsum) -> None: descr_to_var[EinsumElementwiseAxis(iaxis)] = self.get_var_for_axis(expr, iaxis) - for access_descrs, arg in zip(expr.access_descriptors, - expr.args): + for access_descrs, arg in zip(expr.access_descriptors, expr.args): for iarg_axis, descr in enumerate(access_descrs): - in_tag_var = self.get_var_for_axis(arg, iarg_axis) + if arg.axes[iarg_axis].tags_of_type(AxisIgnoredForPropagationTag): + continue + in_tag_var = self.get_var_for_axis(arg, iarg_axis) if descr in descr_to_var: self.record_equation(descr_to_var[descr], in_tag_var) else: @@ -556,38 +575,7 @@ def map_named_call_result(self, expr: NamedCallResult) -> Array: # }}} -def _get_propagation_graph_from_constraints( - equations: List[Tuple[str, str]]) -> Mapping[str, FrozenSet[str]]: - from immutabledict import immutabledict - propagation_graph: Dict[str, Set[str]] = {} - for lhs, rhs in equations: - assert lhs != rhs - propagation_graph.setdefault(lhs, set()).add(rhs) - propagation_graph.setdefault(rhs, set()).add(lhs) - - return immutabledict({k: frozenset(v) - for k, v in propagation_graph.items()}) - - -def get_reachable_nodes(undirected_graph: Mapping[GraphNodeT, Iterable[GraphNodeT]], - source_node: GraphNodeT) -> FrozenSet[GraphNodeT]: - """ - Returns a :class:`frozenset` of all nodes in *undirected_graph* that are - reachable from *source_node*. - """ - nodes_visited: Set[GraphNodeT] = set() - nodes_to_visit = {source_node} - while nodes_to_visit: - current_node = nodes_to_visit.pop() - nodes_visited.add(current_node) - - neighbors = undirected_graph[current_node] - nodes_to_visit.update({node - for node in neighbors - if node not in nodes_visited}) - - return frozenset(nodes_visited) - +# {{{ AxisTagAttacher class AxisTagAttacher(CopyMapper): """ @@ -614,12 +602,8 @@ def rec(self, expr: ArrayOrNames) -> Any: assert expr_copy.ndim == expr.ndim for iaxis in range(expr.ndim): - axis_tags = self.axis_to_tags.get((expr, iaxis), []) - if len(axis_tags) == 0: - print(f"failed to infer axis {iaxis} of array of type {type(expr)}.") - print(f"{expr.non_equality_tags=}") expr_copy = expr_copy.with_tagged_axis( - iaxis, axis_tags) + iaxis, self.axis_to_tags.get((expr, iaxis), [])) # {{{ tag reduction descrs @@ -663,6 +647,8 @@ def __call__(self, expr: ArrayOrNames) -> ArrayOrNames: # type: ignore[override assert isinstance(result, (Array, AbstractResultWithNamedArrays)) return result +# }}} + def unify_axes_tags( expr: ArrayOrNames, @@ -697,19 +683,25 @@ def unify_axes_tags( # Defn. A Propagation graph is a graph where nodes denote variables and an # edge between 2 nodes denotes an equality criterion. - propagation_graph = _get_propagation_graph_from_constraints( - equations_collector.equations) + from pytools.graph import ( + get_propagation_graph_from_constraints, + get_reachable_nodes + ) known_tag_vars = frozenset(equations_collector.known_tag_to_var.values()) axis_to_solved_tags: Dict[Tuple[Array, int], Set[Tag]] = {} + propagation_graph = get_propagation_graph_from_constraints( + equations_collector.equations, + ) + for tag, var in equations_collector.known_tag_to_var.items(): - for reachable_var in (get_reachable_nodes(propagation_graph, var) - - known_tag_vars): - axis_to_solved_tags.setdefault( - equations_collector.axis_to_var.inverse[reachable_var], - set() - ).add(tag) + reachable_nodes = get_reachable_nodes(propagation_graph, var) + for reachable_var in (reachable_nodes - known_tag_vars): + axis_to_solved_tags.setdefault( + equations_collector.axis_to_var.inverse[reachable_var], + set() + ).add(tag) return AxisTagAttacher(axis_to_solved_tags, tag_corresponding_redn_descr=unify_redn_descrs,