20 #include <unordered_set>
28 #include "glog/logging.h"
31 namespace polyhedral {
41 using ScheduleTreeUPtr = std::unique_ptr<detail::ScheduleTree>;
138 struct ScheduleTree {
141 const tc::polyhedral::detail::ScheduleTree&);
144 ScheduleTree() =
delete;
147 std::vector<ScheduleTreeUPtr>&& children,
148 detail::ScheduleTreeType type,
149 std::unique_ptr<ScheduleTreeElemBase>&& elem)
150 : ctx_(ctx), type_(type), elem_(std::move(elem)) {
151 appendChildren(std::move(children));
153 ScheduleTree(
const ScheduleTree& st);
156 explicit ScheduleTree(isl::ctx ctx);
158 bool operator==(
const ScheduleTree& other)
const;
159 bool operator!=(
const ScheduleTree& other)
const {
160 return !(*
this == other);
164 void swapChild(
int pos, ScheduleTreeUPtr& swappee) {
165 CHECK_GE(pos, 0) <<
"position out of children bounds";
166 CHECK_LE(pos, children_.size()) <<
"position out of children bounds";
167 CHECK(swappee.get()) <<
"Cannot swap in a null tree";
168 std::swap(children_[pos], swappee);
172 ScheduleTree* child(
const std::vector<size_t>& positions);
173 const ScheduleTree* child(
const std::vector<size_t>& positions)
const;
174 size_t numChildren()
const {
175 return children_.size();
// Takes the trees in "children" and places them at position "pos" of this
// node's child list. "pos" may be equal to children_.size() (the bound is
// checked with CHECK_LE). Every incoming pointer must be non-null, otherwise
// the CHECK inside the loop fails.
179 void insertChildren(
size_t pos, std::vector<ScheduleTreeUPtr>&& children) {
// "pos" is unsigned, so this lower-bound check cannot fail.
180 CHECK_GE(pos, 0) <<
"position out of children bounds";
181 CHECK_LE(pos, children_.size()) <<
"position out of children bounds";
182 for (
const auto& c : children) {
183 CHECK(c.get()) <<
"inserting null or moved-from child";
// NOTE(review): the line opening this call is not visible in this listing;
// presumably these are the arguments of children_.insert(...). The move
// iterators make each element get moved rather than copied.
187 children_.begin() + pos,
188 std::make_move_iterator(children.begin()),
189 std::make_move_iterator(children.end()));
192 void insertChild(
size_t pos, ScheduleTreeUPtr&& child) {
198 void appendChildren(std::vector<ScheduleTreeUPtr>&& children) {
199 insertChildren(children_.size(), std::move(children));
202 void appendChild(ScheduleTreeUPtr&& child) {
203 insertChild(children_.size(), std::move(child));
// Removes the child at position "pos": its owning pointer is moved into a
// local and the vacated slot is erased from children_, so later children
// shift down by one. "pos" must index an existing child (CHECK_LT).
// NOTE(review): the return statement is not visible in this listing;
// presumably the local "child" is returned to the caller.
206 ScheduleTreeUPtr detachChild(
size_t pos) {
// "pos" is unsigned, so this lower-bound check cannot fail.
207 CHECK_GE(pos, 0) <<
"position out of children bounds";
208 CHECK_LT(pos, children_.size()) <<
"position out of children bounds";
210 ScheduleTreeUPtr child = std::move(children_[pos]);
211 children_.erase(children_.begin() + pos);
// Strips this node of all its children: children_ is swapped with an empty
// local vector, leaving children_ empty and the local holding the old list.
// NOTE(review): the return statement is not visible in this listing;
// presumably "tmpChildren" is returned.
215 std::vector<ScheduleTreeUPtr> detachChildren() {
216 std::vector<ScheduleTreeUPtr> tmpChildren;
217 std::swap(tmpChildren, children_);
// Swaps out the whole child list: the current children are detached first,
// then the trees in "children" are appended in their place.
// NOTE(review): the return statement is not visible in this listing;
// presumably "oldChildren" is returned.
221 std::vector<ScheduleTreeUPtr> replaceChildren(
222 std::vector<ScheduleTreeUPtr>&& children) {
223 auto oldChildren = detachChildren();
224 appendChildren(std::move(children));
// Puts "child" into slot "pos" after moving the previous occupant out into a
// local. "pos" must index an existing child (CHECK_LT). Unlike the insert
// helpers, no null check on "child" is visible here.
// NOTE(review): the return statement is not visible in this listing;
// presumably "oldChild" is returned.
228 ScheduleTreeUPtr replaceChild(
size_t pos, ScheduleTreeUPtr&& child) {
// "pos" is unsigned, so this lower-bound check cannot fail.
229 CHECK_GE(pos, 0) <<
"position out of children bounds";
230 CHECK_LT(pos, children_.size()) <<
"position out of children bounds";
232 ScheduleTreeUPtr oldChild = std::move(children_[pos]);
233 children_[pos] = std::move(child);
// Builds a vector of raw (non-owning) pointers to the direct children, in
// the order they are stored in children_. Ownership stays with this node.
// NOTE(review): the tail of the function is not visible in this listing;
// presumably "res" is returned.
238 std::vector<ScheduleTree*> children() {
239 std::vector<ScheduleTree*> res;
// Reserve up front so the push_back calls below do not reallocate.
240 res.reserve(children_.size());
241 for (
auto& p : children_) {
242 res.push_back(p.get());
// Const counterpart of children(): builds a vector of raw pointers-to-const
// to the direct children, in the order they are stored in children_.
// NOTE(review): the tail of the function is not visible in this listing;
// presumably "res" is returned.
246 std::vector<const ScheduleTree*> children()
const {
247 std::vector<const ScheduleTree*> res;
// Reserve up front so the push_back calls below do not reallocate.
248 res.reserve(children_.size());
249 for (
const auto& p : children_) {
250 res.push_back(p.get());
255 ScheduleTree* ancestor(ScheduleTree* relativeRoot,
size_t generation);
256 const ScheduleTree* ancestor(
257 const ScheduleTree* relativeRoot,
258 size_t generation)
const;
262 std::vector<ScheduleTree*> ancestors(ScheduleTree* relativeRoot);
263 std::vector<const ScheduleTree*> ancestors(
264 const ScheduleTree* relativeRoot)
const;
266 std::vector<size_t> positionRelativeTo(
267 const ScheduleTree* relativeRoot)
const;
// Looks up where this node sits relative to "parent" via positionRelativeTo
// and requires the resulting path to have exactly one entry; otherwise the
// CHECK fails with a message that streams both trees.
// NOTE(review): the return statement is not visible in this listing;
// presumably the single entry of "p" is returned.
269 inline size_t positionInParent(
const ScheduleTree* parent)
const {
270 auto p = positionRelativeTo(parent);
271 CHECK_EQ(1, p.size()) << *parent <<
" is not the parent of " << *
this;
275 size_t scheduleDepth(
const ScheduleTree* relativeRoot)
const;
280 static ScheduleTreeUPtr makeBand(
281 isl::multi_union_pw_aff mupa,
282 std::vector<ScheduleTreeUPtr>&& children = {});
284 static ScheduleTreeUPtr makeDomain(
285 isl::union_set domain,
286 std::vector<ScheduleTreeUPtr>&& children = {});
288 static ScheduleTreeUPtr makeContext(
290 std::vector<ScheduleTreeUPtr>&& children = {});
292 static ScheduleTreeUPtr makeFilter(
293 isl::union_set filter,
294 std::vector<ScheduleTreeUPtr>&& children = {});
296 template <
typename MappingIdType>
297 static inline ScheduleTreeUPtr makeMappingFilter(
298 isl::union_set filter,
299 const std::unordered_set<MappingIdType, typename MappingIdType::Hash>&
301 std::vector<ScheduleTreeUPtr>&& children = {});
303 static ScheduleTreeUPtr makeExtension(
304 isl::union_map extension,
305 std::vector<ScheduleTreeUPtr>&& children = {});
307 template <
typename... Args>
308 static ScheduleTreeUPtr makeBand(
309 isl::multi_union_pw_aff mupa,
312 mupa, vectorFromArgs<ScheduleTreeUPtr>(std::forward<Args>(args)...));
315 template <
typename... Args>
316 static ScheduleTreeUPtr makeDomain(isl::union_set domain, Args&&... args) {
318 domain, vectorFromArgs<ScheduleTreeUPtr>(std::forward<Args>(args)...));
321 template <
typename... Args>
322 static ScheduleTreeUPtr makeContext(isl::set context, Args&&... args) {
324 context, vectorFromArgs<ScheduleTreeUPtr>(std::forward<Args>(args)...));
327 template <
typename... Args>
328 static ScheduleTreeUPtr makeFilter(isl::union_set filter, Args&&... args) {
330 filter, vectorFromArgs<ScheduleTreeUPtr>(std::forward<Args>(args)...));
333 template <
typename MappingIdType,
typename... Args>
334 static inline ScheduleTreeUPtr makeMappingFilter(
335 isl::union_set filter,
336 const std::unordered_set<MappingIdType, typename MappingIdType::Hash>&
339 return makeMappingFilter(
342 vectorFromArgs<ScheduleTreeUPtr>(std::forward<Args>(args)...));
345 template <
typename... Args>
346 static ScheduleTreeUPtr makeExtension(
347 isl::union_map extension,
349 return makeExtension(
351 vectorFromArgs<ScheduleTreeUPtr>(std::forward<Args>(args)...));
354 template <
typename... Args>
355 static ScheduleTreeUPtr makeSet(Args&&... args) {
356 return fromList<ScheduleTreeElemSet>(
357 detail::ScheduleTreeType::Set, std::forward<Args>(args)...);
360 template <
typename... Args>
361 static ScheduleTreeUPtr makeSequence(Args&&... args) {
362 return fromList<ScheduleTreeElemSequence>(
363 detail::ScheduleTreeType::Sequence, std::forward<Args>(args)...);
// Flattening step for Sequence/Set nodes. Walks children_ by index; for each
// child its node type is compared against this node's type_, and the child's
// own children are detached and inserted into this node at the same index.
// NOTE(review): several lines of this function are missing from this listing
// (the opening of the type check, the body of the "if", and whatever follows
// the detach/insert pair); presumably a child of a different type is skipped
// and a flattened child is removed. Confirm against the full source.
367 void flattenSequenceOrSet() {
// Condition restricting this function to Sequence or Set nodes.
370 type_ == detail::ScheduleTreeType::Sequence ||
371 type_ == detail::ScheduleTreeType::Set);
// children_.size() is re-read on every iteration because the loop body
// changes the child list.
375 for (
size_t i = 0; i < children_.size(); ++i) {
376 if (children_[i]->type_ != type_) {
379 auto grandChildren = children_[i]->detachChildren();
381 insertChildren(i, std::move(grandChildren));
387 template <
typename T,
typename Arg,
typename... Args>
388 static ScheduleTreeUPtr
389 fromList(detail::ScheduleTreeType type, Arg&& arg, Args&&... args) {
391 std::is_base_of<ScheduleTreeElemBase, T>::value,
392 "Can only construct elements derived from ScheduleTreeElemBase");
395 typename std::remove_reference<Arg>::type,
396 ScheduleTreeUPtr>::value,
397 "Arguments must be rvalue references to ScheduleTreeUPtr");
399 auto ctx = arg->ctx_;
400 std::vector<ScheduleTreeUPtr> children =
401 vectorFromArgs(std::forward<Arg>(arg), std::forward<Args>(args)...);
403 auto res = ScheduleTreeUPtr(
new ScheduleTree(
407 std::unique_ptr<ScheduleTreeElemBase>(
new T)));
409 if (type == detail::ScheduleTreeType::Sequence ||
410 type == detail::ScheduleTreeType::Set) {
411 res->flattenSequenceOrSet();
417 return ScheduleTreeUPtr(
new ScheduleTree(tree));
// Gathers the nodes reachable from "tree"; the traversal is delegated to
// collectDFSPreorder.
template <typename T>
static std::vector<T> collect(T tree) {
  auto nodes = collectDFSPreorder(tree);
  return nodes;
}
426 template <
typename T>
427 static std::vector<T> collect(T tree, detail::ScheduleTreeType type) {
428 return collectDFSPreorder(tree, type);
431 static std::vector<ScheduleTree*> collectDFSPostorder(ScheduleTree* tree);
432 static std::vector<ScheduleTree*> collectDFSPreorder(ScheduleTree* tree);
433 static std::vector<ScheduleTree*> collectDFSPostorder(
435 detail::ScheduleTreeType type);
436 static std::vector<ScheduleTree*> collectDFSPreorder(
438 detail::ScheduleTreeType type);
440 static std::vector<const ScheduleTree*> collectDFSPostorder(
441 const ScheduleTree* tree);
442 static std::vector<const ScheduleTree*> collectDFSPreorder(
443 const ScheduleTree* tree);
444 static std::vector<const ScheduleTree*> collectDFSPostorder(
445 const ScheduleTree* tree,
446 detail::ScheduleTreeType type);
447 static std::vector<const ScheduleTree*> collectDFSPreorder(
448 const ScheduleTree* tree,
449 detail::ScheduleTreeType type);
454 template <
typename T>
456 const ScheduleTree* t =
this;
457 return const_cast<T*
>(t->elemAs<
const T>());
459 template <
typename T>
460 const T* elemAs()
const {
462 std::is_base_of<ScheduleTreeElemBase, T>::value,
463 "Must call with a class derived from ScheduleTreeElemBase");
464 if (type_ != T::NodeType) {
467 return static_cast<const T*
>(
468 const_cast<const ScheduleTreeElemBase*
>(elem_.get()));
475 template <
typename T>
477 const ScheduleTree* t =
this;
478 return const_cast<T*
>(t->elemAsBase<
const T>());
480 template <
typename T>
481 const T* elemAsBase()
const {
483 std::is_base_of<ScheduleTreeElemBase, T>::value,
484 "Must call with a class derived from ScheduleTreeElemBase");
489 if (type_ != T::NodeType &&
491 T::NodeDerivedTypes.begin(), T::NodeDerivedTypes.end(), type_) ==
492 T::NodeDerivedTypes.end()) {
495 return static_cast<const T*
>(
496 const_cast<const ScheduleTreeElemBase*
>(elem_.get()));
// isl context associated with this tree; declared mutable, so it can be
// used from const member functions.
503 mutable isl::ctx ctx_;
// Owning pointers to the direct children of this node; empty by default.
506 std::vector<ScheduleTreeUPtr> children_{};
// Kind of this node; defaults to None.
509 detail::ScheduleTreeType type_{detail::ScheduleTreeType::None};
// Owning pointer to the node payload (a ScheduleTreeElemBase); null by
// default.
510 std::unique_ptr<ScheduleTreeElemBase> elem_{
nullptr};
513 std::ostream&
operator<<(std::ostream& os,
const ScheduleTree& tree);
517 isl::union_set activeDomainPoints(
518 const detail::ScheduleTree* root,
519 const detail::ScheduleTree* node);
std::ostream & operator<<(std::ostream &out, const MappingOptionsAsCpp &mo)
Definition: mapping_options_cpp_printer.h:79
std::vector< Arg > vectorFromArgs()
Definition: vararg.h:45
bool operator==(const std::vector< const DLTensor * > &inputsTensor, const std::vector< detail::TensorInfo > &inputsInfo)
bool operator!=(isl::val v, long i)
Definition: islpp.h:103
ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Halide::Internal::Stmt &s)