Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New branch for TP axis tag testing #5

Draft
wants to merge 1 commit into
base: production
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 52 additions & 47 deletions pytato/transform/lower_to_index_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
AdvancedIndexInNoncontiguousAxes,
NormalizedSlice, ShapeType,
AbstractResultWithNamedArrays)
from pytato.scalar_expr import ScalarExpression, INT_CLASSES, IntegralT
from pytato.scalar_expr import ScalarExpression, INT_CLASSES
from pytato.diagnostic import CannotBeLoweredToIndexLambda
from pytato.tags import AssumeNonNegative
from pytato.transform import Mapper
Expand All @@ -51,53 +51,58 @@ def _get_reshaped_indices(expr: Reshape) -> Tuple[ScalarExpression, ...]:
assert expr.size == 1
return ()

if expr.order not in ["C", "F"]:
if expr.order.upper() not in ["C", "F"]:
raise NotImplementedError("Order expected to be 'C' or 'F'",
f" found {expr.order}")

if expr.order == "C":
newstrides: List[IntegralT] = [1] # reshaped array strides
for new_axis_len in reversed(expr.shape[1:]):
assert isinstance(new_axis_len, INT_CLASSES)
newstrides.insert(0, newstrides[0]*new_axis_len)

flattened_idx = sum(prim.Variable(f"_{i}")*stride
for i, stride in enumerate(newstrides))

oldstrides: List[IntegralT] = [1] # input array strides
for axis_len in reversed(expr.array.shape[1:]):
assert isinstance(axis_len, INT_CLASSES)
oldstrides.insert(0, oldstrides[0]*axis_len)

assert isinstance(expr.array.shape[-1], INT_CLASSES)
oldsizetills = [expr.array.shape[-1]] # input array size
# till for axes idx
for old_axis_len in reversed(expr.array.shape[:-1]):
assert isinstance(old_axis_len, INT_CLASSES)
oldsizetills.insert(0, oldsizetills[0]*old_axis_len)

else:
newstrides: List[IntegralT] = [1] # reshaped array strides
for new_axis_len in expr.shape[:-1]:
assert isinstance(new_axis_len, INT_CLASSES)
newstrides.append(newstrides[-1]*new_axis_len)

flattened_idx = sum(prim.Variable(f"_{i}")*stride
for i, stride in enumerate(newstrides))

oldstrides: List[IntegralT] = [1] # input array strides
for axis_len in expr.array.shape[:-1]:
assert isinstance(axis_len, INT_CLASSES)
oldstrides.append(oldstrides[-1]*axis_len)

assert isinstance(expr.array.shape[0], INT_CLASSES)
oldsizetills = [expr.array.shape[0]] # input array size till for axes idx
for old_axis_len in expr.array.shape[1:]:
assert isinstance(old_axis_len, INT_CLASSES)
oldsizetills.append(oldsizetills[-1]*old_axis_len)

return tuple(((flattened_idx % sizetill) // stride)
for stride, sizetill in zip(oldstrides, oldsizetills))
f"(case insensitive) found {expr.order}")

order = expr.order
oldshape = expr.array.shape
newshape = expr.shape

# {{{ compute strides

oldstrides = [1]
oldstride_axes = (reversed(oldshape[1:]) if order == "C" else oldshape[:-1])

for ax_len in oldstride_axes:
assert isinstance(ax_len, INT_CLASSES)
oldstrides.append(oldstrides[-1]*ax_len)

newstrides = [1]
newstride_axes = (reversed(newshape[1:]) if order == "C" else newshape[:-1])

for ax_len in newstride_axes:
assert isinstance(ax_len, INT_CLASSES)
newstrides.append(newstrides[-1]*ax_len)

# }}}

# {{{ compute size tills

oldsizetills = [oldshape[-1] if order == "C" else oldshape[0]]
oldsizetill_ax = (oldshape[:-1][::-1] if order == "C" else oldshape[:-1])
for ax_len in oldsizetill_ax:
oldsizetills.append(oldsizetills[-1]*ax_len)

# }}}

# {{{ if order is C, then computed info is backwards

if order == "C":
oldstrides = oldstrides[::-1]
newstrides = newstrides[::-1]
oldsizetills = oldsizetills[::-1]

# }}}

flattened_idx = sum(prim.Variable(f"_{i}")*stride
for i, stride in enumerate(newstrides))

ret = tuple(
(flattened_idx % sizetill) // stride
for stride, sizetill in zip(oldstrides, oldsizetills))

return ret


class ToIndexLambdaMixin:
Expand Down
92 changes: 42 additions & 50 deletions pytato/transform/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"""


from typing import (TYPE_CHECKING, Type, Set, Tuple, List, Dict, FrozenSet,
from typing import (TYPE_CHECKING, Type, Set, Tuple, List, Dict,
Mapping, Iterable, Any, TypeVar, cast)
from bidict import bidict
from pytato.scalar_expr import SCALAR_CLASSES
Expand All @@ -58,7 +58,7 @@
from pytato.diagnostic import UnknownIndexLambdaExpr

from pytools import UniqueNameGenerator
from pytools.tag import Tag
from pytools.tag import Tag, UniqueTag
import logging
logger = logging.getLogger(__name__)

Expand All @@ -70,6 +70,22 @@
GraphNodeT = TypeVar("GraphNodeT")


# {{{ IgnoredForPropagationTag

class AxisIgnoredForPropagationTag(UniqueTag):
"""
Used to influence equality constraints when determining which axes tags
are allowed to propagate along.

The intended use case for this is to prevent the axes of a matrix used to,
for example, differentiate a tensor of DOF data from picking up on the
unique tags attached to the axes of the tensor.
"""
pass

# }}}


# {{{ AxesTagsEquationCollector

class AxesTagsEquationCollector(Mapper):
Expand Down Expand Up @@ -167,6 +183,8 @@ def record_equations_from_axes_tags(self, ary: Array) -> None:
Records equations for *ary*\'s axis tags of type :attr:`tag_t`.
"""
for iaxis, axis in enumerate(ary.axes):
if axis.tags_of_type(AxisIgnoredForPropagationTag):
continue
lhs_var = self.get_var_for_axis(ary, iaxis)
for tag in axis.tags_of_type(self.tag_t):
rhs_var = self.get_var_for_tag(tag)
Expand Down Expand Up @@ -492,11 +510,12 @@ def map_einsum(self, expr: Einsum) -> None:
descr_to_var[EinsumElementwiseAxis(iaxis)] = self.get_var_for_axis(expr,
iaxis)

for access_descrs, arg in zip(expr.access_descriptors,
expr.args):
for access_descrs, arg in zip(expr.access_descriptors, expr.args):
for iarg_axis, descr in enumerate(access_descrs):
in_tag_var = self.get_var_for_axis(arg, iarg_axis)
if arg.axes[iarg_axis].tags_of_type(AxisIgnoredForPropagationTag):
continue

in_tag_var = self.get_var_for_axis(arg, iarg_axis)
if descr in descr_to_var:
self.record_equation(descr_to_var[descr], in_tag_var)
else:
Expand Down Expand Up @@ -556,38 +575,7 @@ def map_named_call_result(self, expr: NamedCallResult) -> Array:
# }}}


def _get_propagation_graph_from_constraints(
equations: List[Tuple[str, str]]) -> Mapping[str, FrozenSet[str]]:
from immutabledict import immutabledict
propagation_graph: Dict[str, Set[str]] = {}
for lhs, rhs in equations:
assert lhs != rhs
propagation_graph.setdefault(lhs, set()).add(rhs)
propagation_graph.setdefault(rhs, set()).add(lhs)

return immutabledict({k: frozenset(v)
for k, v in propagation_graph.items()})


def get_reachable_nodes(undirected_graph: Mapping[GraphNodeT, Iterable[GraphNodeT]],
source_node: GraphNodeT) -> FrozenSet[GraphNodeT]:
"""
Returns a :class:`frozenset` of all nodes in *undirected_graph* that are
reachable from *source_node*.
"""
nodes_visited: Set[GraphNodeT] = set()
nodes_to_visit = {source_node}
while nodes_to_visit:
current_node = nodes_to_visit.pop()
nodes_visited.add(current_node)

neighbors = undirected_graph[current_node]
nodes_to_visit.update({node
for node in neighbors
if node not in nodes_visited})

return frozenset(nodes_visited)

# {{{ AxisTagAttacher

class AxisTagAttacher(CopyMapper):
"""
Expand All @@ -614,12 +602,8 @@ def rec(self, expr: ArrayOrNames) -> Any:
assert expr_copy.ndim == expr.ndim

for iaxis in range(expr.ndim):
axis_tags = self.axis_to_tags.get((expr, iaxis), [])
if len(axis_tags) == 0:
print(f"failed to infer axis {iaxis} of array of type {type(expr)}.")
print(f"{expr.non_equality_tags=}")
expr_copy = expr_copy.with_tagged_axis(
iaxis, axis_tags)
iaxis, self.axis_to_tags.get((expr, iaxis), []))

# {{{ tag reduction descrs

Expand Down Expand Up @@ -663,6 +647,8 @@ def __call__(self, expr: ArrayOrNames) -> ArrayOrNames: # type: ignore[override
assert isinstance(result, (Array, AbstractResultWithNamedArrays))
return result

# }}}


def unify_axes_tags(
expr: ArrayOrNames,
Expand Down Expand Up @@ -697,19 +683,25 @@ def unify_axes_tags(
# Defn. A Propagation graph is a graph where nodes denote variables and an
# edge between 2 nodes denotes an equality criterion.

propagation_graph = _get_propagation_graph_from_constraints(
equations_collector.equations)
from pytools.graph import (
get_propagation_graph_from_constraints,
get_reachable_nodes
)

known_tag_vars = frozenset(equations_collector.known_tag_to_var.values())
axis_to_solved_tags: Dict[Tuple[Array, int], Set[Tag]] = {}

propagation_graph = get_propagation_graph_from_constraints(
equations_collector.equations,
)

for tag, var in equations_collector.known_tag_to_var.items():
for reachable_var in (get_reachable_nodes(propagation_graph, var)
- known_tag_vars):
axis_to_solved_tags.setdefault(
equations_collector.axis_to_var.inverse[reachable_var],
set()
).add(tag)
reachable_nodes = get_reachable_nodes(propagation_graph, var)
for reachable_var in (reachable_nodes - known_tag_vars):
axis_to_solved_tags.setdefault(
equations_collector.axis_to_var.inverse[reachable_var],
set()
).add(tag)

return AxisTagAttacher(axis_to_solved_tags,
tag_corresponding_redn_descr=unify_redn_descrs,
Expand Down
Loading