Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
schedule_transforms.h
Go to the documentation of this file.
1 
16 #pragma once
17 
18 #include <iostream>
19 #include <memory>
20 #include <string>
21 #include <unordered_set>
22 #include <vector>
23 
28 #include "tc/external/isl.h"
29 
30 namespace tc {
31 namespace polyhedral {
33 // Transformation functions, out-of-class
35 // Starting from the "start" ScheduleTree, iteratively traverse the subtree
36 // using the "next" function and collect every node visited along the way.
37 // The traversal stops as soon as "next" returns nullptr.
38 // The returned vector begins with "start". A const overload is also provided.
39 std::vector<detail::ScheduleTree*> collectScheduleTreesPath(
40  std::function<detail::ScheduleTree*(detail::ScheduleTree*)> next,
41  detail::ScheduleTree* start);
42 std::vector<const detail::ScheduleTree*> collectScheduleTreesPath(
43  std::function<const detail::ScheduleTree*(const detail::ScheduleTree*)>
44  next,
45  const detail::ScheduleTree* start);
46 
47 // Join two perfectly nested bands into a single band.
48 // This is a structural transformation only; it is not necessarily correct
49 // semantically. In particular, the caller is responsible for setting the
50 // permutability of the band (presumably via "permutable", which defaults to
51 // false), since determining it generally requires dependence analysis.
52 // The coincident fields corresponding to members of the inner band are cleared.
53 detail::ScheduleTree* joinBands(
54  detail::ScheduleTree* tree,
55  bool permutable = false);
56 
57 // Iteratively join perfectly nested bands into a single band.
58 // This is a structural transformation only; it is not necessarily correct
59 // semantically. In particular, the caller is responsible for setting the
60 // permutability of the band (presumably via "permutable", which defaults to
61 // false), since determining it generally requires dependence analysis.
62 // The coincident fields corresponding to members of inner bands are cleared.
63 detail::ScheduleTree* joinBandsIterative(
64  detail::ScheduleTree* tree,
65  bool permutable = false);
66 
67 // Split tree rooted under relativeRoot into two nested trees, one with the
68 // first "pos" dimensions and one with the remaining dimensions.
69 // The schedules of the two bands live in anonymous spaces.
70 // This updates the current ScheduleTree and returns it so that calls can be
71 // chained.
72 detail::ScheduleTree* bandSplit(
73  detail::ScheduleTree* relativeRoot,
74  detail::ScheduleTree* tree,
75  size_t pos);
76 // Split the band rooted under relativeRoot into at most three nested bands
77 // such that the band member at position "pos" is isolated
78 // into a single-member band.
79 // The schedules of the split bands live in anonymous spaces.
80 // Updates the current ScheduleTree and returns
81 // a pointer to the band containing the isolated member.
82 detail::ScheduleTree* bandSplitOut(
83  detail::ScheduleTree* relativeRoot,
84  detail::ScheduleTree* tree,
85  size_t pos);
86 
87 // The semantics of this function are somewhat richer than the ISL C semantics.
88 // Since tiling is implemented as a simple band.mupa_ transformation, we can
89 // just complete it with 0 on the unspecified dimensions.
90 // This has the effect of pushing the non-tiled outer loop inside the tile,
91 // i.e. for i, j, k -> for i, j, ii, jj, k
92 //
93 // Conversely, if you want to keep the non-tiled outer loop outside the
94 // tile, you can just specify a tile size of 1 which, similarly to the current
95 // ISL behavior, will make it so,
96 // i.e. for i, j, k -> for i, j, k, ii, jj, kk where range(kk)=[0, 1]
97 //
98 // Innermost sizes in excess of band->nMember() are dropped automatically.
99 //
100 // Modifies tree in place and returns it for call chaining purposes.
101 //
102 // TODO: Support imperfectly nested tiling
103 detail::ScheduleTree* bandTile(
104  detail::ScheduleTree* tree,
105  const std::vector<size_t>& tileSizes,
106  TileOptions tileOptions);
107 
108 // Change the partial schedule of the band in place by multiplying it with the
109 // given scales. The size of the "scales" vector must correspond to the number
110 // of band members.
111 // Innermost sizes in excess of band->nMember() are dropped automatically
112 // (NOTE(review): this seems at odds with the size requirement above; confirm).
113 detail::ScheduleTree* bandScale(
114  detail::ScheduleTree* tree,
115  const std::vector<size_t>& scales);
116 
117 // Map the "pos"-th schedule dimension of the band node identified by "tree" to
118 // a _new_ parameter identified by "id" and limited by 0 <= id < extent. The
119 // parameter must not be present in the space of the partial schedule of "tree"
120 // and extent must be non-zero. The mapping corresponds to inserting a filter
121 // node with condition 'dim % extent = id' where dim is the "pos"-th
122 // schedule dimension.
123 //
124 // Returns a pointer to the updated band (below the inserted filter)
125 // for call chaining purposes.
126 template <typename MappingIdType>
127 detail::ScheduleTree* mapToParameterWithExtent(
128  detail::ScheduleTree* root,
129  detail::ScheduleTree* tree,
130  int pos,
131  MappingIdType id,
132  size_t extent);
133 
134 // In a tree starting at a (relative) "root", insert a band node with the
135 // given partial schedule "mupa" above the node identified by "tree".
136 //
137 // The tree is modified in place.
138 // Returns a non-owning pointer to the inserted band node
139 // for call chaining purposes.
140 detail::ScheduleTree* insertBandAbove(
141  detail::ScheduleTree* root,
142  detail::ScheduleTree* tree,
143  isl::multi_union_pw_aff mupa);
144 
145 // Insert a band node with the given partial schedule "mupa" below node "tree",
146 // which is assumed to have at most one child.
147 //
148 // The tree is modified in place.
149 // Returns a non-owning pointer to the inserted band node
150 // for call chaining purposes.
151 detail::ScheduleTree* insertBandBelow(
152  detail::ScheduleTree* tree,
153  isl::multi_union_pw_aff mupa);
154 
155 // Update the top-level context node by intersecting it with "context". The
156 // top-level context node must be located directly under the root of the tree.
157 // If there is no such node, one with a universe context is inserted first.
158 void updateTopLevelContext(detail::ScheduleTree* root, isl::set context);
159 
160 // In a tree starting at a (relative) "root", insert a sequence node whose
161 // only child is the node identified by "tree".
162 //
163 // The tree is modified in place.
164 // Returns a non-owning pointer to the inserted sequence node
165 // for call chaining purposes.
166 detail::ScheduleTree* insertSequenceAbove(
167  detail::ScheduleTree* root,
168  detail::ScheduleTree* tree);
169 
170 // In a tree starting at a (relative) "root", insert an extension node with the
171 // given "extension" above the node identified by "tree".
172 //
173 // The tree is modified in place.
174 // Returns a non-owning pointer to the inserted extension node
175 // for call chaining purposes.
176 detail::ScheduleTree* insertExtensionAbove(
177  detail::ScheduleTree* root,
178  detail::ScheduleTree* tree,
179  isl::union_map extension);
180 
181 // In a tree starting at a (relative) "root", insert a mapping filter node
182 // with the given "filter" above the node identified by "tree".
183 // NOTE(review): "mappingIds" is presumably the set of mapping ids of the new node; confirm.
184 // The tree is modified in place.
185 // Returns a non-owning pointer to the inserted filter node
186 // for call chaining purposes.
187 template <typename MappingIdType>
188 inline detail::ScheduleTree* insertMappingFilterAbove(
189  detail::ScheduleTree* root,
190  detail::ScheduleTree* tree,
191  isl::union_set filter,
192  const std::unordered_set<MappingIdType, typename MappingIdType::Hash>&
193  mappingIds);
194 
195 // Insert a mapping filter node below node "tree", which is assumed to have at
196 // most one child. The underlying isl::union_set filter is constructed from
197 // the arguments.
198 //
199 // The tree is modified in place; nothing is returned.
200 template <typename MappingIdType>
201 inline void insertMappingFilterBelow(
202  detail::ScheduleTree* tree,
203  isl::union_set filter,
204  const std::unordered_set<MappingIdType, typename MappingIdType::Hash>&
205  mappingIds);
206 
207 // Given a sequence node "seqNode" in the schedule tree, insert
208 // a zero-dimensional extension statement with the given identifier "id"
209 // before the child at position "pos".
210 // If "pos" is equal to the number of children, then
211 // the statement is added after the last child.
212 void insertExtensionLabelAt(
213  detail::ScheduleTree* root,
214  detail::ScheduleTree* seqNode,
215  size_t pos,
216  isl::id id);
217 
218 // Insert a zero-dimensional extension statement with the given identifier "id"
219 // before node "tree".
220 // If "tree" is a sequence node or a grandchild of a sequence node,
221 // then the new statement is inserted at the appropriate position
222 // of that sequence node.
223 // Otherwise, a new sequence node is inserted.
224 void insertExtensionLabelBefore(
225  detail::ScheduleTree* root,
226  detail::ScheduleTree* tree,
227  isl::id id);
228 
229 // Insert a zero-dimensional extension statement with the given identifier "id"
230 // after node "tree".
231 // If "tree" is a sequence node or a grandchild of a sequence node,
232 // then the new statement is inserted at the appropriate position
233 // of that sequence node.
234 // Otherwise, a new sequence node is inserted.
235 void insertExtensionLabelAfter(
236  detail::ScheduleTree* root,
237  detail::ScheduleTree* tree,
238  isl::id id);
239 
240 // Insert a sequence to ensure that the active domain elements
241 // in the given "filter" are executed before the other active domain elements.
242 void orderBefore(
243  detail::ScheduleTree* root,
244  detail::ScheduleTree* tree,
245  isl::union_set filter);
246 // Insert a sequence to ensure that the active domain elements
247 // in the given "filter" are executed after the other active domain elements.
248 void orderAfter(
249  detail::ScheduleTree* root,
250  detail::ScheduleTree* tree,
251  isl::union_set filter);
252 
253 // Given a "schedule" defined by the ancestors of the given "node",
254 // extend it to a schedule that also covers the node itself and return it.
255 isl::union_map extendSchedule(
256  const detail::ScheduleTree* node,
257  isl::union_map schedule);
258 
259 // Return the partial schedule defined by the ancestors of the given "node"
260 // and by the node itself.
261 isl::union_map partialSchedule(
262  const detail::ScheduleTree* root,
263  const detail::ScheduleTree* node);
264 
265 // Return the schedule defined by the ancestors of the given "node" only.
266 isl::union_map prefixSchedule(
267  const detail::ScheduleTree* root,
268  const detail::ScheduleTree* node);
269 
270 // Return the concatenation of the partial schedules of all outer band nodes.
271 // If there are no outer band nodes, then return a zero-dimensional
272 // function on the universe domain of the schedule tree.
273 // Note that, unlike isl_schedule_node_get_prefix_schedule_multi_union_pw_aff,
274 // this function does not take into account any intermediate filter nodes.
275 isl::multi_union_pw_aff prefixScheduleMupa(
276  const detail::ScheduleTree* root,
277  const detail::ScheduleTree* tree);
278 
279 // Return the set of domain points active at the given "node". A domain
280 // point is active if it was not filtered away on the path from the
281 // "root" to the node. The root must be a domain element; otherwise no
282 // elements would be considered active.
283 isl::union_set activeDomainPoints(
284  const detail::ScheduleTree* root,
285  const detail::ScheduleTree* node);
286 
287 // Return the set of statement identifiers whose domains have at least one
288 // active point at the given "node", i.e. the statements that were not
289 // filtered away on the path from "root" to "node".
290 std::unordered_set<isl::id, isl::IslIdIslHash> activeStatements(
291  const detail::ScheduleTree* root,
292  const detail::ScheduleTree* node);
293 
295 // Experimental
297 // Mapping filters are introduced one mapping dimension at a time.
298 // This function merges consecutive mapping filters.
299 detail::ScheduleTree* mergeConsecutiveMappingFilters(
300  detail::ScheduleTree* root,
301  detail::ScheduleTree* node);
302 
303 } // namespace polyhedral
304 } // namespace tc
305