Tensor Comprehensions
schedule_tree.h
#pragma once

#include <algorithm>
#include <iterator>
#include <memory>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tc/core/utils/vararg.h"
#include "tc/external/isl.h"

#include "glog/logging.h"

namespace tc {
namespace polyhedral {
namespace detail {

// Internal representation of polyhedral schedule information, wrapping a
// ScheduleTree, convertible to and from isl::schedule.
//
struct ScheduleTree;

} // namespace detail

using ScheduleTreeUPtr = std::unique_ptr<detail::ScheduleTree>;

namespace detail {

//
// Schedule Trees
//
// Memory model: tree uniquely owns its children, user owns the root,
// traversals are non-owning.
//
// ScheduleTree is a data store of the ScheduleXYZ API. It implements a
// mutable tree data structure, each ScheduleTree having a potentially empty
// list of children, with the following ownership semantics. A ScheduleTree
// object owns its children. When a child is added to or removed from the
// tree, the ownership is transferred from or to the caller. Users of
// ScheduleTree own the root of the tree and may apply any ownership policies,
// e.g. share the ownership. Users are guaranteed that, if they own a
// ScheduleTree, it is a root of a tree. Ownership rules are enforced through
// unique_ptr and move semantics. In particular, users are not expected to
// manipulate ScheduleTree objects by value.
//
// New ScheduleTree objects of various types or deep copies of the existing
// objects can be constructed using static factory functions, which transfer
// the ownership of the constructed object to the caller. These functions
// optionally take a list of subtrees that will become children of the newly
// constructed tree, which takes ownership.
//
// Tree structure can be changed by appending, inserting, detaching or swapping
// subtrees. Only trees owned by the user can be appended, inserted or swapped
// in, in which case the ownership is transferred to the parent tree object.
// The user is expected to own the root of the tree and the inserted tree, but
// not the insertion point. The ownership of a detached or swapped-out tree is
// transferred to the caller.
//
// ScheduleTrees must not have null children. This is checked with glog CHECK
// (active in all build modes, not only in debug builds) during construction
// and child manipulation.
//
//
// Internal structure: singly-linked tree (no parent pointer).
//
// Because the caller must own the root of the tree, it is always possible to
// find the parent or any ancestor of a tree by traversing the tree from the
// root. Subtrees are ordered and are identified by their position in the
// parent tree.
//
// Trees can be traversed, inspected and modified through raw non-owning
// pointers. PreorderDFS, PostorderDFS and BFS traversals are provided. Tree
// modification is in place and does not require the caller to own the
// ScheduleTree object.
//
// Tree modification functions are external to the ScheduleTree class and
// should only rely on the exposed API to avoid breaking the ownership and
// non-null guarantees. For the sake of consistency, modification functions
// should take a raw pointer to the root of the tree as the first argument,
// even if they do not use it, and a raw pointer to the subtree being
// manipulated. Transformation functions should account for the root pointer
// being relative, i.e. not being the actual root pointer owned by the caller,
// but rather some ancestor of the given node, above which the transformation
// has no effect (think of C++ standard library algorithms taking begin/end
// iterators).
//
//
// Well-formedness guarantees: non-null subtrees.
//
// ScheduleTree does NOT impose any structural requirements on the tree, e.g.
// those of isl. A tree with a null child is ill-formed.
//
// Note that we do not enforce the isl schedule tree invariants [1] at the API
// level. Instead, we provide a function checkValidIslSchedule() to verify
// whether a schedule that has a given ScheduleTree as root can be converted to
// an isl::schedule. Note that, for compatibility reasons, the current
// implementation may also create isl::schedules that do not maintain isl
// schedule tree guarantees. The behavior of isl calls on such schedules is
// undefined.
//
// The following isl invariants can be checked:
// 1. the root is a domain or extension node
// 2. only sequence/set nodes have multiple children, and these children are
//    filters
// 3. nodes do not refer to parameters that were not previously introduced by a
//    context or a domain node
// 4. nodes do not refer to inactive domain points, i.e. those that were
//    filtered away (warning only)
// 5. the union of filters contains all active domain elements
// 6. the domain of an expansion contains all active domain elements
// 7. the partial schedule of a band node is total for all active domain
//    elements
// 8. extension nodes do not introduce any elements that are already active
//    domain elements
// 9. (not enforced)
// 10. anchored nodes match the flattened space of the outer bands
//
// TODO(ftynse): implement bool checkValidIslSchedule() to check schedule
// structure without failing the run.
//
//
// [1] Verdoolaege, Guelton, Grosser & Cohen (2014). "Schedule trees". In
// IMPACT 2014.
struct ScheduleTree {
  friend std::ostream& tc::polyhedral::detail::operator<<(
      std::ostream&,
      const tc::polyhedral::detail::ScheduleTree&);

 private:
  ScheduleTree() = delete;
  ScheduleTree(
      isl::ctx ctx,
      std::vector<ScheduleTreeUPtr>&& children,
      detail::ScheduleTreeType type,
      std::unique_ptr<ScheduleTreeElemBase>&& elem)
      : ctx_(ctx), type_(type), elem_(std::move(elem)) {
    appendChildren(std::move(children));
  }
  ScheduleTree(const ScheduleTree& st);

 public:
  explicit ScheduleTree(isl::ctx ctx);

  bool operator==(const ScheduleTree& other) const;
  bool operator!=(const ScheduleTree& other) const {
    return !(*this == other);
  }

  // Swap the child at position pos with the given tree. On return, swappee
  // owns the former child.
  void swapChild(int pos, ScheduleTreeUPtr& swappee) {
    CHECK_GE(pos, 0) << "position out of children bounds";
    // Strict bound: there is no child at pos == children_.size(). The cast
    // is safe after the check above.
    CHECK_LT(static_cast<size_t>(pos), children_.size())
        << "position out of children bounds";
    CHECK(swappee.get()) << "Cannot swap in a null tree";
    std::swap(children_[pos], swappee);
  }

  // Child accessors (only in-place modification allowed)
  ScheduleTree* child(const std::vector<size_t>& positions);
  const ScheduleTree* child(const std::vector<size_t>& positions) const;
  size_t numChildren() const {
    return children_.size();
  }

  // Manipulators for the list of children.
  void insertChildren(size_t pos, std::vector<ScheduleTreeUPtr>&& children) {
    CHECK_GE(pos, 0) << "position out of children bounds";
    CHECK_LE(pos, children_.size()) << "position out of children bounds";
    for (const auto& c : children) {
      CHECK(c.get()) << "inserting null or moved-from child";
    }

    children_.insert(
        children_.begin() + pos,
        std::make_move_iterator(children.begin()),
        std::make_move_iterator(children.end()));
  }

  void insertChild(size_t pos, ScheduleTreeUPtr&& child) {
    // One cannot move from an initializer_list, so an actual temporary
    // vector is needed here.
    insertChildren(pos, vectorFromArgs(std::move(child)));
  }

  void appendChildren(std::vector<ScheduleTreeUPtr>&& children) {
    insertChildren(children_.size(), std::move(children));
  }

  void appendChild(ScheduleTreeUPtr&& child) {
    insertChild(children_.size(), std::move(child));
  }

  ScheduleTreeUPtr detachChild(size_t pos) {
    CHECK_GE(pos, 0) << "position out of children bounds";
    CHECK_LT(pos, children_.size()) << "position out of children bounds";

    ScheduleTreeUPtr child = std::move(children_[pos]);
    children_.erase(children_.begin() + pos);
    return child;
  }

  std::vector<ScheduleTreeUPtr> detachChildren() {
    std::vector<ScheduleTreeUPtr> tmpChildren;
    std::swap(tmpChildren, children_);
    return tmpChildren;
  }

  std::vector<ScheduleTreeUPtr> replaceChildren(
      std::vector<ScheduleTreeUPtr>&& children) {
    auto oldChildren = detachChildren();
    appendChildren(std::move(children));
    return oldChildren;
  }

  ScheduleTreeUPtr replaceChild(size_t pos, ScheduleTreeUPtr&& child) {
    CHECK_GE(pos, 0) << "position out of children bounds";
    CHECK_LT(pos, children_.size()) << "position out of children bounds";

    ScheduleTreeUPtr oldChild = std::move(children_[pos]);
    children_[pos] = std::move(child);
    return oldChild;
  }

  // Helper to avoid calling collect + filter for this common case
  std::vector<ScheduleTree*> children() {
    std::vector<ScheduleTree*> res;
    res.reserve(children_.size());
    for (auto& p : children_) {
      res.push_back(p.get());
    }
    return res;
  }
  std::vector<const ScheduleTree*> children() const {
    std::vector<const ScheduleTree*> res;
    res.reserve(children_.size());
    for (const auto& p : children_) {
      res.push_back(p.get());
    }
    return res;
  }

  ScheduleTree* ancestor(ScheduleTree* relativeRoot, size_t generation);
  const ScheduleTree* ancestor(
      const ScheduleTree* relativeRoot,
      size_t generation) const;
  // Returns the ancestors up to relativeRoot in a vector. The first element
  // of the result is relativeRoot, the last element of the result is the
  // parent of this ScheduleTree.
  std::vector<ScheduleTree*> ancestors(ScheduleTree* relativeRoot);
  std::vector<const ScheduleTree*> ancestors(
      const ScheduleTree* relativeRoot) const;

  std::vector<size_t> positionRelativeTo(
      const ScheduleTree* relativeRoot) const;

  inline size_t positionInParent(const ScheduleTree* parent) const {
    auto p = positionRelativeTo(parent);
    CHECK_EQ(1, p.size()) << *parent << " is not the parent of " << *this;
    return p[0];
  }

  size_t scheduleDepth(const ScheduleTree* relativeRoot) const;

  //
  // Factory functions
  //
  static ScheduleTreeUPtr makeBand(
      isl::multi_union_pw_aff mupa,
      std::vector<ScheduleTreeUPtr>&& children = {});

  static ScheduleTreeUPtr makeDomain(
      isl::union_set domain,
      std::vector<ScheduleTreeUPtr>&& children = {});

  static ScheduleTreeUPtr makeContext(
      isl::set context,
      std::vector<ScheduleTreeUPtr>&& children = {});

  static ScheduleTreeUPtr makeFilter(
      isl::union_set filter,
      std::vector<ScheduleTreeUPtr>&& children = {});

  template <typename MappingIdType>
  static inline ScheduleTreeUPtr makeMappingFilter(
      isl::union_set filter,
      const std::unordered_set<MappingIdType, typename MappingIdType::Hash>&
          mappingIds,
      std::vector<ScheduleTreeUPtr>&& children = {});

  static ScheduleTreeUPtr makeExtension(
      isl::union_map extension,
      std::vector<ScheduleTreeUPtr>&& children = {});

  template <typename... Args>
  static ScheduleTreeUPtr makeBand(
      isl::multi_union_pw_aff mupa,
      Args&&... args) {
    return makeBand(
        mupa, vectorFromArgs<ScheduleTreeUPtr>(std::forward<Args>(args)...));
  }

  template <typename... Args>
  static ScheduleTreeUPtr makeDomain(isl::union_set domain, Args&&... args) {
    return makeDomain(
        domain, vectorFromArgs<ScheduleTreeUPtr>(std::forward<Args>(args)...));
  }

  template <typename... Args>
  static ScheduleTreeUPtr makeContext(isl::set context, Args&&... args) {
    return makeContext(
        context, vectorFromArgs<ScheduleTreeUPtr>(std::forward<Args>(args)...));
  }

  template <typename... Args>
  static ScheduleTreeUPtr makeFilter(isl::union_set filter, Args&&... args) {
    return makeFilter(
        filter, vectorFromArgs<ScheduleTreeUPtr>(std::forward<Args>(args)...));
  }

  template <typename MappingIdType, typename... Args>
  static inline ScheduleTreeUPtr makeMappingFilter(
      isl::union_set filter,
      const std::unordered_set<MappingIdType, typename MappingIdType::Hash>&
          mappingIds,
      Args&&... args) {
    return makeMappingFilter(
        filter,
        mappingIds,
        vectorFromArgs<ScheduleTreeUPtr>(std::forward<Args>(args)...));
  }

  template <typename... Args>
  static ScheduleTreeUPtr makeExtension(
      isl::union_map extension,
      Args&&... args) {
    return makeExtension(
        extension,
        vectorFromArgs<ScheduleTreeUPtr>(std::forward<Args>(args)...));
  }

  template <typename... Args>
  static ScheduleTreeUPtr makeSet(Args&&... args) {
    return fromList<ScheduleTreeElemSet>(
        detail::ScheduleTreeType::Set, std::forward<Args>(args)...);
  }

  template <typename... Args>
  static ScheduleTreeUPtr makeSequence(Args&&... args) {
    return fromList<ScheduleTreeElemSequence>(
        detail::ScheduleTreeType::Sequence, std::forward<Args>(args)...);
  }

  // Flatten nested nodes of the same type.
  void flattenSequenceOrSet() {
    // This should be enforced by the type system...
    CHECK(
        type_ == detail::ScheduleTreeType::Sequence ||
        type_ == detail::ScheduleTreeType::Set);

    // Iterate over the changing list of children. If a child has the same list
    // type as its parent, replace it with its own children and traverse those
    // too.
    for (size_t i = 0; i < children_.size(); ++i) {
      if (children_[i]->type_ != type_) {
        continue;
      }
      auto grandChildren = children_[i]->detachChildren();
      detachChild(i);
      insertChildren(i, std::move(grandChildren));
      // Revisit position i. For i == 0 this wraps around (unsigned), and the
      // loop's ++i brings it back to 0.
      --i;
    }
  }

  // Requires at least one child: taking the first child as a separate
  // parameter disallows empty lists syntactically and provides the isl
  // context for the new node.
  template <typename T, typename Arg, typename... Args>
  static ScheduleTreeUPtr
  fromList(detail::ScheduleTreeType type, Arg&& arg, Args&&... args) {
    static_assert(
        std::is_base_of<ScheduleTreeElemBase, T>::value,
        "Can only construct elements derived from ScheduleTreeElemBase");
    static_assert(
        std::is_same<
            typename std::remove_reference<Arg>::type,
            ScheduleTreeUPtr>::value,
        "Arguments must be rvalue references to ScheduleTreeUPtr");

    auto ctx = arg->ctx_;
    std::vector<ScheduleTreeUPtr> children =
        vectorFromArgs(std::forward<Arg>(arg), std::forward<Args>(args)...);

    auto res = ScheduleTreeUPtr(new ScheduleTree(
        ctx,
        std::move(children),
        type,
        std::unique_ptr<ScheduleTreeElemBase>(new T)));

    if (type == detail::ScheduleTreeType::Sequence ||
        type == detail::ScheduleTreeType::Set) {
      res->flattenSequenceOrSet();
    }
    return res;
  }

  static ScheduleTreeUPtr makeScheduleTree(const ScheduleTree& tree) {
    return ScheduleTreeUPtr(new ScheduleTree(tree));
  }

  // Collect the nodes of "tree" in some arbitrary order.
  template <typename T>
  static std::vector<T> collect(T tree) {
    return collectDFSPreorder(tree);
  }
  // Collect the nodes of "tree" of the given type in some arbitrary order.
  template <typename T>
  static std::vector<T> collect(T tree, detail::ScheduleTreeType type) {
    return collectDFSPreorder(tree, type);
  }

  static std::vector<ScheduleTree*> collectDFSPostorder(ScheduleTree* tree);
  static std::vector<ScheduleTree*> collectDFSPreorder(ScheduleTree* tree);
  static std::vector<ScheduleTree*> collectDFSPostorder(
      ScheduleTree* tree,
      detail::ScheduleTreeType type);
  static std::vector<ScheduleTree*> collectDFSPreorder(
      ScheduleTree* tree,
      detail::ScheduleTreeType type);

  static std::vector<const ScheduleTree*> collectDFSPostorder(
      const ScheduleTree* tree);
  static std::vector<const ScheduleTree*> collectDFSPreorder(
      const ScheduleTree* tree);
  static std::vector<const ScheduleTree*> collectDFSPostorder(
      const ScheduleTree* tree,
      detail::ScheduleTreeType type);
  static std::vector<const ScheduleTree*> collectDFSPreorder(
      const ScheduleTree* tree,
      detail::ScheduleTreeType type);

  // View elem_ as the specified type.
  // Returns nullptr if this is not the proper type.
  // Inline implementation for now; it does not justify an extra -inl.h file.
  template <typename T>
  T* elemAs() {
    const ScheduleTree* t = this;
    return const_cast<T*>(t->elemAs<const T>());
  }
  template <typename T>
  const T* elemAs() const {
    static_assert(
        std::is_base_of<ScheduleTreeElemBase, T>::value,
        "Must call with a class derived from ScheduleTreeElemBase");
    if (type_ != T::NodeType) {
      return nullptr;
    }
    return static_cast<const T*>(
        const_cast<const ScheduleTreeElemBase*>(elem_.get()));
  }

  // View elem_ as the specified type.
  // Returns nullptr unless the node type is T::NodeType or one of the types
  // listed in T::NodeDerivedTypes, i.e. unless elem_ is a T or derives from T.
  // Inline implementation for now; it does not justify an extra -inl.h file.
  template <typename T>
  T* elemAsBase() {
    const ScheduleTree* t = this;
    return const_cast<T*>(t->elemAsBase<const T>());
  }
  template <typename T>
  const T* elemAsBase() const {
    static_assert(
        std::is_base_of<ScheduleTreeElemBase, T>::value,
        "Must call with a class derived from ScheduleTreeElemBase");
    // These T::NodeDerivedTypes are ugly; let's see if we can improve on them
    // in the future. But if we want dynamic typing and want to avoid
    // enumerations at each call site, which I claim we absolutely do, then we
    // are not left with many options.
    if (type_ != T::NodeType &&
        std::find(
            T::NodeDerivedTypes.begin(), T::NodeDerivedTypes.end(), type_) ==
            T::NodeDerivedTypes.end()) {
      return nullptr;
    }
    return static_cast<const T*>(
        const_cast<const ScheduleTreeElemBase*>(elem_.get()));
  }

  //
  // Data members
  //
 public:
  mutable isl::ctx ctx_;

 private:
  std::vector<ScheduleTreeUPtr> children_{};

 public:
  detail::ScheduleTreeType type_{detail::ScheduleTreeType::None};
  std::unique_ptr<ScheduleTreeElemBase> elem_{nullptr};
};

std::ostream& operator<<(std::ostream& os, const ScheduleTree& tree);

} // namespace detail

isl::union_set activeDomainPoints(
    const detail::ScheduleTree* root,
    const detail::ScheduleTree* node);
} // namespace polyhedral
} // namespace tc
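The memory model described in the header comment (a parent uniquely owns its children, inserting moves a subtree in, detaching or swapping hands it back to the caller) can be exercised without isl. The following is a minimal sketch on a hypothetical Node type, not TC's ScheduleTree, using only the standard library and throwing exceptions where the header uses glog CHECK. It also shows why swapChild needs the strict bound pos < size, like detachChild: position size has no child to swap with.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-in for ScheduleTree: a label plus uniquely owned
// children. Exceptions replace the glog CHECKs used in the real header.
struct Node {
  int label;
  std::vector<std::unique_ptr<Node>> children;

  explicit Node(int l) : label(l) {}

  // Takes ownership of child; pos == children.size() appends.
  void insertChild(std::size_t pos, std::unique_ptr<Node> child) {
    if (pos > children.size()) throw std::out_of_range("insert position");
    if (!child) throw std::invalid_argument("null child");
    children.insert(children.begin() + pos, std::move(child));
  }

  // Removes the child at pos and returns its ownership to the caller.
  std::unique_ptr<Node> detachChild(std::size_t pos) {
    if (pos >= children.size()) throw std::out_of_range("detach position");
    std::unique_ptr<Node> child = std::move(children[pos]);
    children.erase(children.begin() + pos);
    return child;
  }

  // Exchanges the child at pos with swappee, which then owns the old child.
  void swapChild(std::size_t pos, std::unique_ptr<Node>& swappee) {
    if (pos >= children.size()) throw std::out_of_range("swap position");
    if (!swappee) throw std::invalid_argument("null swappee");
    std::swap(children[pos], swappee);
  }
};

// Convenience factory so callers only ever hold nodes through unique_ptr.
std::unique_ptr<Node> makeNode(int label) {
  return std::unique_ptr<Node>(new Node(label));
}
```

A caller can keep a non-owning Node* to a child, swap a new subtree in, and get the old child back through the same unique_ptr; the raw pointer stays valid because the object itself never moves.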
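Without parent pointers, ancestors and relative positions are recovered by searching downward from a root the caller owns, as the "Internal structure" comment explains. The sketch below is a hypothetical illustration on a bare TreeNode type; pathFrom and ancestorsOf are made-up names, not TC's implementation of positionRelativeTo() or ancestors(). It records child indices during a depth-first search, then replays that path to list ancestors from relativeRoot (first) to the parent (last), the order documented for ancestors().

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical tree node: owned children, deliberately no parent pointer.
struct TreeNode {
  std::vector<std::unique_ptr<TreeNode>> children;

  // Appends a new empty child and returns a non-owning pointer to it.
  TreeNode* addChild() {
    children.push_back(std::unique_ptr<TreeNode>(new TreeNode()));
    return children.back().get();
  }
};

// Depth-first search below relativeRoot; on success, path holds the child
// indices leading from relativeRoot to target.
bool findPath(
    const TreeNode* relativeRoot,
    const TreeNode* target,
    std::vector<std::size_t>& path) {
  if (relativeRoot == target) {
    return true;
  }
  for (std::size_t i = 0; i < relativeRoot->children.size(); ++i) {
    path.push_back(i);
    if (findPath(relativeRoot->children[i].get(), target, path)) {
      return true;
    }
    path.pop_back();
  }
  return false;
}

// Child indices from relativeRoot down to target (empty if they coincide).
std::vector<std::size_t> pathFrom(
    const TreeNode* relativeRoot,
    const TreeNode* target) {
  std::vector<std::size_t> path;
  bool found = findPath(relativeRoot, target, path);
  assert(found && "target is not in the subtree of relativeRoot");
  (void)found;
  return path;
}

// Ancestors of target ordered from relativeRoot (first) to parent (last),
// obtained by walking the recorded path back down from relativeRoot.
std::vector<const TreeNode*> ancestorsOf(
    const TreeNode* relativeRoot,
    const TreeNode* target) {
  std::vector<const TreeNode*> result;
  const TreeNode* current = relativeRoot;
  for (std::size_t pos : pathFrom(relativeRoot, target)) {
    result.push_back(current);
    current = current->children[pos].get();
  }
  return result;
}
```

In this sketch each lookup is linear in the size of the subtree below relativeRoot, which is the cost of keeping no parent pointers; passing a closer relativeRoot narrows the search.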
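The splice-and-revisit loop in flattenSequenceOrSet() can be checked in isolation. This is a minimal sketch on hypothetical ListNode and Kind types standing in for ScheduleTree and ScheduleTreeType. It follows the same steps (take the grandchildren, drop the nested node, insert the grandchildren at its position, step the index back) and relies on the same unsigned wrap-around when the index is 0.

```cpp
#include <cassert>
#include <cstddef>
#include <iterator>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical node kinds standing in for detail::ScheduleTreeType.
enum class Kind { Leaf, Sequence, Set };

struct ListNode {
  Kind kind;
  int label;  // meaningful for leaves only; -1 for list nodes
  std::vector<std::unique_ptr<ListNode>> children;
  ListNode(Kind k, int l) : kind(k), label(l) {}
};

std::unique_ptr<ListNode> leaf(int label) {
  return std::unique_ptr<ListNode>(new ListNode(Kind::Leaf, label));
}

// Builds a list node owning the given children, in argument order.
template <typename... Children>
std::unique_ptr<ListNode> makeList(Kind kind, Children&&... children) {
  std::unique_ptr<ListNode> node(new ListNode(kind, -1));
  // Elements of a braced-init-list are evaluated left to right.
  int expand[] = {
      0, (node->children.push_back(std::forward<Children>(children)), 0)...};
  (void)expand;
  return node;
}

// Splices every child of the same kind as node into node, recursively.
void flattenSameKind(ListNode* node) {
  auto& cs = node->children;
  for (std::size_t i = 0; i < cs.size(); ++i) {
    if (cs[i]->kind != node->kind) {
      continue;
    }
    std::vector<std::unique_ptr<ListNode>> grand = std::move(cs[i]->children);
    cs.erase(cs.begin() + i);
    cs.insert(
        cs.begin() + i,
        std::make_move_iterator(grand.begin()),
        std::make_move_iterator(grand.end()));
    // Revisit position i, now holding the first spliced grandchild. For
    // i == 0 this wraps to SIZE_MAX, and the loop's ++i brings it back to 0.
    --i;
  }
}

// Labels of the direct children, with -1 marking non-leaf children.
std::vector<int> childLabels(const ListNode& node) {
  std::vector<int> out;
  for (const auto& c : node.children) {
    out.push_back(c->kind == Kind::Leaf ? c->label : -1);
  }
  return out;
}
```

A Set nested inside a Sequence is left in place, since only children whose kind equals the parent's are spliced, mirroring the type_ comparison in the header.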