Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
scop.h
Go to the documentation of this file.
1 
#pragma once

#include <cstdlib>
#include <functional>
#include <memory>
#include <ostream>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <dlpack/dlpack.h>

#include "tc/core/constants.h"
#include "tc/core/halide2isl.h"
#include "tc/core/tc2halide.h"
#include "tc/external/isl.h"
33 
34 namespace tc {
35 namespace polyhedral {
36 
37 // Reduction dims must be properly ordered
38 using ReductionDimSet = std::set<std::string>;
39 class TensorReferenceGroup;
40 
41 class MappedScop;
42 
43 struct Scop {
44  private:
45  Scop() {}
46 
47  public:
48  // Should be reserved for internal use and unit testing.
49  static std::unique_ptr<Scop> makeScop(
50  isl::ctx ctx,
51  const tc2halide::HalideComponents& components);
52 
53  // Preferred points of entry, given a TC string or a treeRef,
54  // Halide IR is constructed and made a member by setting halideComponents.
55  // These operations are grouped and scheduled in a halide::Stmt which becomes
56  // the unit from which the scop is constructed.
57  static std::unique_ptr<Scop> makeScop(isl::ctx ctx, const std::string& tc);
58 
59  static std::unique_ptr<Scop> makeScop(
60  isl::ctx ctx,
61  const lang::TreeRef& treeRef);
62 
63  // Clone a Scop
64  static std::unique_ptr<Scop> makeScop(const Scop& scop) {
65  auto res = std::unique_ptr<Scop>(new Scop());
66  res->globalParameterContext = scop.globalParameterContext;
67  res->halide = scop.halide;
68  res->reads = scop.reads;
69  res->writes = scop.writes;
70  res->scheduleTreeUPtr =
71  detail::ScheduleTree::makeScheduleTree(*scop.scheduleTreeUPtr);
72  res->treeSyncUpdateMap = scop.treeSyncUpdateMap;
73  res->defaultReductionInitMap = scop.defaultReductionInitMap;
74  res->groupCounts_ = scop.groupCounts_;
75  res->promotedDecls_ = scop.promotedDecls_;
76  res->activePromotions_ = scop.activePromotions_;
77  return res;
78  }
79 
80  // Intersect globalParameterContext with extraGlobalParameterContext.
81  inline void intersectContext(isl::set extraGlobalParameterContext) {
82  auto context = globalParameterContext & extraGlobalParameterContext;
83  globalParameterContext = context;
84  }
85 
86  // Specialize a Scop with extra globalParameterContext information
87  // If you want to intersect the support domain with the
88  // extraGlobalParameterContext then you need to do it explicitly.
89  // Otherwise ambiguities will ensue.
90  // TODO: this is still subject to interpretation but intersecting seems a
91  // bit final here so probably we're right not doing it.
92  static std::unique_ptr<Scop> makeSpecializedScop(
93  const Scop& scop,
94  isl::set extraGlobalParameterContext) {
95  CHECK(extraGlobalParameterContext.is_subset(scop.globalParameterContext))
96  << "expected extra context " << extraGlobalParameterContext
97  << " to be more specialized than " << scop.globalParameterContext;
98  auto res = makeScop(scop);
99  res->intersectContext(extraGlobalParameterContext);
100  // **WARNING** if called before scheduling, this could result in a
101  // (partially) specialized schedule, i.e. force
102  // strategy.proto.fix_parameters_before_scheduling to true.
103  // If you want to intersect the support domain with the
104  // extraGlobalParameterContext then you need to do it explicitly.
105  // Note that the access relations must be intersect with the context as
106  // well to obtain consistent dependences.
107  // TODO: this is still subject to interpretation but intersecting seems
108  // final here so probably we're right not doing it.
109  // res->domain() =
110  // res->domain().intersect_params(res->globalParameterContext);
111  return res;
112  }
113 
114  // Specialize the Scop with respect to its globalParameterContext.
115  void specializeToContext() {
116  domain() = domain().intersect_params(globalParameterContext);
117  reads = reads.intersect_params(globalParameterContext);
118  writes = writes.intersect_params(globalParameterContext);
119  }
120 
121  // Returns a set that specializes (all) the scop's parameter space to the
122  // integer values passed to the function.
123  // WARNING: this version relies on parameter ordering, be sure you know what
124  // you are doing.
125  template <typename T>
126  isl::set makeContext(const std::vector<T>& sizes = std::vector<T>()) const {
127  auto s = domain().get_space().params();
128  return makeSpecializationSet(s, sizes);
129  }
130 
131  // Returns a set that specializes the (positional) scop's subset of
132  // parameter space to the integer values passed to the function.
133  template <typename T>
134  isl::set makeContext(
135  const std::unordered_map<int, T>& sizes =
136  std::unordered_map<int, T>()) const {
137  auto s = domain().get_space().params();
138  return makeSpecializationSet(s, sizes);
139  }
140 
141  // Returns a set that specializes the named scop's subset of
142  // parameter space to the integer values passed to the function.
143  template <typename T>
144  isl::set makeContext(
145  const std::unordered_map<std::string, T>& sizes =
146  std::unordered_map<std::string, T>()) const {
147  auto s = domain().get_space().params();
148  return makeSpecializationSet(s, sizes);
149  }
150 
151  // Compute the values of parameters based on the effective sizes of the
152  // tensors provided as arguments and their parametric expressions stored in
153  // halide ImageParams. We only know input sizes, output sizes are inferred.
154  // Result is an isl set directly usable as context.
155  isl::set makeContextFromInputs(
156  const std::vector<const DLTensor*>& inputs) const;
157 
158  // Fix the values of the specified parameters in the context
159  // to the corresponding specified values.
160  template <typename T>
161  void fixParameters(const std::unordered_map<std::string, T>& sizes) {
162  intersectContext(makeContext(sizes));
163  }
164 
165  // Given the context set, return the list of parameter values in the same
166  // order as codegen places them in the function signature, i.e. following the
167  // order of scop.params.
168  std::vector<long> getParameterValues(isl::set context) const;
169 
170  isl::id nextGroupIdForTensor(isl::id tensorId) {
171  auto ctx = domain().get_ctx();
172  std::stringstream ss;
173  ss << "_" << tensorId.get_name() << "_" << groupCounts_[tensorId]++;
174  return isl::id(ctx, ss.str());
175  }
176 
177  // Assuming redPoint is a reduction candidate node with
178  // the given reduction update statement identifier,
179  // add an extension node for a reduction init and
180  // a reduction update statement and insert the new
181  // statements before and after (the children of) redPoint.
182  // If redPoint is a sequence node, then the new node are inserted
183  // inside that sequence node. Otherwise, a new sequence node is created.
184  //
185  // The transformed shape is:
186  //
187  // *extension( <- extension
188  // sequence(
189  // *filter() <- red_init in new or existing sequence
190  // redPoint
191  // *filter() <- red_update in new or existing sequence
192  // )
193  // )
194  //
195  // This tree structure typically appears when one does not include the
196  // innermost loop as part of an n-D tiling and mapping scheme but rather
197  // does (n-K)D tiling and placement and then another level of placement
198  // inside that.
199  isl::id insertReductionSync1D(
200  detail::ScheduleTree* redPoint,
201  isl::id updateId);
202 
203  // Given a sequence node in the schedule tree, insert
204  // synchronization before the child at position "pos".
205  // If "pos" is equal to the number of children, then
206  // the synchronization is added after the last child.
207  void insertSync(detail::ScheduleTree* seqNode, size_t pos);
208 
209  // Insert synchronization after the given subtree,
210  // creating a sequence node if needed.
211  void insertSyncAfter(detail::ScheduleTree* tree) {
212  insertExtensionLabelAfter(scheduleRoot(), tree, makeSyncId());
213  }
214 
215  size_t reductionUID() const {
216  static size_t count = 0;
217  return count++;
218  }
219  size_t syncUID() const {
220  static size_t count = 0;
221  return count++;
222  }
223 
224  isl::id makeSyncId() const {
225  auto ctx = domain().get_ctx();
226  return isl::id(ctx, std::string(kSyncIdPrefix) + std::to_string(syncUID()));
227  }
228 
229  static bool isSyncId(isl::id id) {
230  if (!id.has_name()) {
231  return false;
232  }
233  auto name = id.get_name();
234  if (name.find(kSyncIdPrefix) != 0) {
235  return false;
236  }
237  name = name.substr(std::string(kSyncIdPrefix).size());
238  char* end;
239  std::strtol(name.c_str(), &end, 10);
240  if (end - name.c_str() != name.size()) {
241  return false;
242  }
243  return true;
244  }
245 
246  static isl::id makeRefId(isl::ctx ctx) {
247  static thread_local size_t count = 0;
248  return isl::id(ctx, std::string("__tc_ref_") + std::to_string(count++));
249  }
250 
251  std::pair<isl::id, isl::id> makeReductionSpecialIds(isl::id updateId) {
252  auto uid = reductionUID();
253  auto treeSyncId = isl::id(
254  domain().get_ctx(), std::string("red_update") + std::to_string(uid));
255  auto reductionInitId = isl::id(
256  domain().get_ctx(), std::string("red_init") + std::to_string(uid));
257  CHECK_EQ(0, treeSyncUpdateMap.count(treeSyncId));
258  CHECK_EQ(0, defaultReductionInitMap.count(treeSyncId));
259 
260  treeSyncUpdateMap.emplace(treeSyncId, updateId);
261  defaultReductionInitMap.emplace(treeSyncId, reductionInitId);
262  return std::make_pair(treeSyncId, reductionInitId);
263  }
264 
265  bool isTreeSyncId(isl::id id) const {
266  return treeSyncUpdateMap.count(id) == 1;
267  }
268 
269  bool isDefaultReductionInitId(isl::id id) const {
270  for (const auto& p : defaultReductionInitMap) {
271  if (p.second == id) {
272  return true;
273  }
274  }
275  return false;
276  }
277 
278  isl::id getReductionUpdateForDefaultInit(isl::id id) const {
279  for (const auto& p : defaultReductionInitMap) {
280  if (p.second == id) {
281  return treeSyncUpdateMap.at(p.first);
282  }
283  }
284  CHECK(false) << "not found";
285  return id;
286  }
287 
288  bool isReductionUpdate(isl::id id) const {
289  for (const auto& kvp : treeSyncUpdateMap) {
290  if (id == kvp.second) {
291  return true;
292  }
293  }
294  return false;
295  }
296 
297  size_t reductionUpdatePos(isl::id id) const {
298  size_t pos = 0;
299  CHECK(isReductionUpdate(id));
300  for (const auto& kvp : treeSyncUpdateMap) {
301  if (id == kvp.second) {
302  return pos;
303  }
304  pos++;
305  }
306  return -1;
307  }
308 
309  void promoteEverythingAt(std::vector<size_t> pos);
310 
311  struct PromotedDecl {
312  isl::id tensorId;
313  std::vector<size_t> sizes;
314  };
315 
316  struct PromotionInfo {
317  std::shared_ptr<TensorReferenceGroup> group;
318  isl::union_map outerSchedule;
319  isl::id groupId;
320  };
321 
322  const std::unordered_map<isl::id, PromotedDecl, isl::IslIdIslHash>&
323  promotedDecls() const {
324  return promotedDecls_;
325  }
326 
327  const std::
328  unordered_map<isl::id, std::vector<PromotionInfo>, isl::IslIdIslHash>&
329  activePromotions() const {
330  return activePromotions_;
331  }
332 
333  detail::ScheduleTree* scheduleRoot() {
334  return scheduleTreeUPtr.get();
335  }
336 
337  const detail::ScheduleTree* scheduleRoot() const {
338  return scheduleTreeUPtr.get();
339  }
340 
341  // Create a Scop scheduled with a given scheduling strategy.
342  static std::unique_ptr<Scop> makeScheduled(
343  const Scop& scop,
344  const SchedulerOptionsView& schedulerOptions);
345 
346  // Tile the outermost band.
347  // Splits the band into tile loop band and point loop band where point loops
348  // have fixed trip counts specified in "tiling", and returns a pointer to the
349  // tile loop band.
350  detail::ScheduleTree* tileOuterBand(const TilingView& tiling);
351 
352  // Reschedule the schedule subtree rooted at "tree" with the
353  // given scheduler options.
354  void reschedule(
355  detail::ScheduleTree* tree,
356  const SchedulerOptionsView& schedulerOptions);
357 
358  // Find an input or an output argument given its name.
359  // Assumes such argument exists.
360  const Halide::OutputImageParam& findArgument(isl::id id) const;
361 
362  // Promote a tensor reference group to shared memory, inserting the copy
363  // statements below the given node. Inserts an Extension node below the give
364  // node, unless there is already another Extension node which introduces
365  // copies. The Extension node has a unique Sequence child, whose children
366  // perform copies from global memory, then main computation using the
367  // original nodes, then copies back to global memory. The caller is in
368  // charge of inserting the synchronization nodes.
369  //
370  // Creates the promoted array declaration in the internal list.
371  // If "forceLastExtentOdd" is set, the last extent in the declaration is
372  // incremented if it is even. This serves as a simple heuristic to reduce
373  // shared memory bank conflicts.
374  void promoteGroupToShared(
375  isl::id tensorId,
376  std::unique_ptr<TensorReferenceGroup>&& gr,
377  detail::ScheduleTree* tree,
378  const std::unordered_set<isl::id, isl::IslIdIslHash>& activeStmts,
379  isl::union_map schedule,
380  bool forceLastExtentOdd = false);
381 
382  // Given a tree node under which the promotion copy statements were
383  // introduced, insert syncthread statements before and after the copies.
384  // The tree should match the structure:
385  // any(
386  // extension(
387  // sequence(
388  // // <-- sync will be inserted here
389  // filter(any()), // filter that refers to read
390  // ...
391  // // <-- sync will be inserted here if filter above exists
392  // filter(any()), // at least one filter that does not refer to
393  // ... // read/write
394  // // <-- sync will be inserted here if filter below exists
395  // filter(any()), // filter that refers to write
396  // ...
397  // // <-- sync will be inserted here
398  // )))
399  //
400  void insertSyncsAroundCopies(detail::ScheduleTree* tree);
401 
402  private:
403  // Compute a schedule satisfying the given schedule constraints and
404  // taking into account the scheduler options.
405  // Note that some of the scheduler options have already been
406  // taken into account during the construction of the schedule constraints.
407  static std::unique_ptr<detail::ScheduleTree> computeSchedule(
408  isl::schedule_constraints constraints,
409  const SchedulerOptionsView& schedulerOptions);
410 
411  public:
412  // Halide stuff
413  struct {
414  std::vector<Halide::Internal::Parameter> params;
415  std::vector<std::string> idx, reductionIdx;
416  std::vector<Halide::ImageParam> inputs;
417  std::vector<Halide::OutputImageParam> outputs;
418  std::vector<halide2isl::Reduction> reductions;
419  std::unordered_map<isl::id, Halide::Internal::Stmt, isl::IslIdIslHash>
420  statements;
421  std::unordered_map<const Halide::Internal::IRNode*, isl::id> accesses;
422  } halide;
423 
424  // Poyhedral IR
425  //
426  // The domain is collected from the root of the ScheduleTree; no redundant
427  // state is kept.
428  // By analogy with generalized functions, the domain is the "support" part
429  // of the ScheduleTree "function".
430  isl::union_set& domain();
431  const isl::union_set domain() const;
432  // A globalParameterContext is kept. This represents (partial)
433  // parameter specialization coming from the outside.
434  // This may be further specialized before codegen.
435  // This globalParameterContext must not give rise to a context node in the
436  // schedule tree.
437  // This globalParameterContext is intersected with the domain of the
438  // ScheduleTree for best possible specialization of polyhedral decisions and
439  // transformations. By the analogy with generalized functions, the
440  // globalParameterContext becomes part of the "support" of the ScheduleTree
441  // "function".
442  // This globalParameterContext lives in a parameter space.
443  isl::set globalParameterContext; // TODO: not too happy about this name
444 
445  isl::union_map reads;
446  isl::union_map writes;
447 
448  private:
449  // By analogy with generalized functions, a ScheduleTree is a (piecewise
450  // affine) function operating on a support.
451  // The support is originally an isl::union_set corresponding to the union of
452  // the iteration domains of the statements in the Scop.
453  // The support must be the unique root node of the ScheduleTree and be of
454  // type: ScheduleTreeElemDomain.
455  std::unique_ptr<detail::ScheduleTree> scheduleTreeUPtr;
456 
457  public:
458  // For reduction matching purposes we keep the following maps
459  std::unordered_map<isl::id, isl::id, isl::IslIdIslHash> treeSyncUpdateMap;
460  std::unordered_map<isl::id, isl::id, isl::IslIdIslHash>
461  defaultReductionInitMap; // treeSyncId -> defaultInitId
462 
463  private:
464  // Memory promotion stuff
465  // tensorId -> number of mapped groups
466  std::unordered_map<isl::id, size_t, isl::IslIdIslHash> groupCounts_;
467  // groupId -> (tensorId, groupSizes)
468  std::unordered_map<isl::id, PromotedDecl, isl::IslIdIslHash> promotedDecls_;
469  // stmtId -> (group, partial schedule, groupId)
470  std::unordered_map<isl::id, std::vector<PromotionInfo>, isl::IslIdIslHash>
471  activePromotions_;
472 };
473 
474 std::ostream& operator<<(std::ostream& os, const Scop&);
475 
476 } // namespace polyhedral
477 } // namespace tc
Definition: tc2halide.h:29
isl::set makeSpecializationSet(isl::space space, const std::unordered_map< int, T > &paramValues)
Definition: islpp.h:314
std::ostream & operator<<(std::ostream &out, const MappingOptionsAsCpp &mo)
Definition: mapping_options_cpp_printer.h:79
std::shared_ptr< Tree > TreeRef
Definition: tree.h:44
ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Halide::Internal::Stmt &s)
Definition: islpp.h:260