-
Notifications
You must be signed in to change notification settings - Fork 2
/
parallelization.py
106 lines (82 loc) · 3.37 KB
/
parallelization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from __future__ import annotations
import copy
from typing import Dict, List
from tiralib.tiramisu.tiramisu_iterator_node import IteratorIdentifier
from tiralib.tiramisu.tiramisu_tree import TiramisuTree
from tiralib.tiramisu.tiramisu_actions.tiramisu_action import (
TiramisuAction,
TiramisuActionType,
)
class Parallelization(TiramisuAction):
    """
    Parallelization optimization command.

    Tags a single loop level as parallel. The loop is identified by one
    ``(computation_name, iterator_level)`` pair; the computations it
    affects are either given explicitly or derived from the tree.
    """

    def __init__(
        self,
        params: List[IteratorIdentifier],
        comps: List[str] | None = None,
    ):
        """
        Parameters:
        ----------
        `params`: `List[IteratorIdentifier]`
            Must contain exactly one element: the loop to parallelize,
            given as a ``(computation_name, iterator_level)`` tuple.
        `comps`: `List[str] | None`
            Names of the computations affected by the loop. When ``None``
            they are resolved in `initialize_action_for_tree`.
        """
        # Parallelization only takes one parameter: the loop to
        # parallelize, specified by a tuple (computation_name, iterator_level)
        assert len(params) == 1
        self.params = params
        self.comps = comps
        self.iterator_id = self.params[0]
        super().__init__(
            type=TiramisuActionType.PARALLELIZATION,
            params=params,
            comps=comps,
        )

    def initialize_action_for_tree(self, tiramisu_tree: TiramisuTree):
        """
        Bind the action to `tiramisu_tree` and build its string forms.

        Stores a deep copy of the tree in `self.tree`. If no computations
        were given, uses every computation in the subtree of the target
        iterator, sorted by `computations_absolute_order`. Computations
        supplied by the caller are used as-is (not re-sorted).

        Parameters:
        ----------
        `tiramisu_tree`: `TiramisuTree`
            The Tiramisu tree of the program.
        """
        # we save a copy of the tree to be able to restore it later
        self.tree = copy.deepcopy(tiramisu_tree)

        if self.comps is None:
            iterator = tiramisu_tree.get_iterator_of_computation(
                self.iterator_id[0], self.iterator_id[1]
            )
            self.comps = tiramisu_tree.get_iterator_subtree_computations(
                iterator.name
            )
            # order the computations by their absolute order; comps[0] is
            # later used as the computation that carries the parallel tag
            self.comps.sort(
                key=lambda comp: tiramisu_tree.computations_absolute_order[comp]
            )

        self.set_string_representations(tiramisu_tree)

    def set_string_representations(self, tiramisu_tree: TiramisuTree):
        """
        Build the strings describing this action.

        Sets:
        - `tiramisu_optim_str`: a ``tag_parallel_level(<level>)`` call on
          the first computation of `self.comps`;
        - `str_representation`: a short ``P(L<level>,comps=[...])`` label;
        - `legality_check_string`: code that ANDs
          ``loop_parallelization_is_legal`` over all of `self.comps` into
          ``is_legal`` and then appends `tiramisu_optim_str`.

        `tiramisu_tree` is not read by this implementation.
        """
        assert self.iterator_id is not None
        assert self.comps is not None

        level = self.iterator_id[1]
        first_comp = self.comps[0]

        self.tiramisu_optim_str = f"{first_comp}.tag_parallel_level({level});\n"
        self.str_representation = f"P(L{level},comps={self.comps})"

        # brace-enclosed list of `&comp` references for the generated code
        comps_refs = ", ".join(f"&{comp}" for comp in self.comps)
        self.legality_check_string = (
            "prepare_schedules_for_legality_checks(true);\n"
            f" is_legal &= loop_parallelization_is_legal({level}, {{{comps_refs}}});\n"
            f" {self.tiramisu_optim_str}"
        )

    @classmethod
    def _get_candidates_of_node(
        cls, node_name: str, program_tree: TiramisuTree
    ) -> List[List[IteratorIdentifier]]:
        """
        Collect candidate groups from the subtree rooted at `node_name`.

        Every iterator in the subtree (including `node_name`) that has
        child iterators contributes one group: the ids of its children.
        Groups are produced in pre-order (parent before descendants).
        """
        candidates: List[List[IteratorIdentifier]] = []
        node = program_tree.iterators[node_name]

        if node.child_iterators:
            candidates.append(
                [program_tree.iterators[child].id for child in node.child_iterators]
            )

        for child in node.child_iterators:
            candidates += cls._get_candidates_of_node(child, program_tree)

        return candidates

    @classmethod
    def get_candidates(
        cls, program_tree: TiramisuTree
    ) -> Dict[IteratorIdentifier, List[List[IteratorIdentifier]]]:
        """Get the list of candidates for parallelization.

        Parameters:
        ----------
        `program_tree`: `TiramisuTree`
            The Tiramisu tree of the program.

        Returns:
        -------
        `Dict[IteratorIdentifier, List[List[IteratorIdentifier]]]`
            Maps the id of each root iterator to its candidate groups.
            The first group is ``[root_id]``; every following group holds
            the ids of the children of one iterator in that root's
            subtree, in pre-order.
        """
        candidates: Dict[IteratorIdentifier, List[List[IteratorIdentifier]]] = {}
        for root in program_tree.roots:
            root_id = program_tree.iterators[root].id
            candidates[root_id] = [[root_id]] + cls._get_candidates_of_node(
                root, program_tree
            )
        return candidates