Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
schedule_transforms-inl.h
Go to the documentation of this file.
1 
16 #pragma once
17 
18 namespace tc {
19 namespace polyhedral {
20 template <typename MappingIdType>
21 inline detail::ScheduleTree* insertMappingFilterAbove(
22  detail::ScheduleTree* root,
23  detail::ScheduleTree* tree,
24  isl::union_set filter,
25  const std::unordered_set<MappingIdType, typename MappingIdType::Hash>&
26  mappingIds) {
27  auto parent = tree->ancestor(root, 1);
28  auto childPos = tree->positionInParent(parent);
29  parent->insertChild(
30  childPos,
31  detail::ScheduleTree::makeMappingFilter(
32  filter, mappingIds, parent->detachChild(childPos)));
33  return parent->child({childPos});
34 }
35 
36 template <typename MappingIdType>
37 inline void insertMappingFilterBelow(
38  detail::ScheduleTree* tree,
39  isl::union_set filter,
40  const std::unordered_set<MappingIdType, typename MappingIdType::Hash>&
41  mappingIds) {
42  auto numChildren = tree->numChildren();
43  CHECK_LE(numChildren, 1);
44  tree->appendChild(detail::ScheduleTree::makeMappingFilter(
45  filter, mappingIds, tree->detachChildren()));
46 }
47 
48 template <typename MappingIdType>
49 inline detail::ScheduleTree* mapToParameterWithExtent(
50  detail::ScheduleTree* root,
51  detail::ScheduleTree* tree,
52  int pos,
53  MappingIdType id,
54  size_t extent) {
55  auto band = tree->elemAs<detail::ScheduleTreeElemBand>();
56  CHECK(band) << "expected a band, got " << *tree;
57  CHECK_GE(pos, 0) << "dimension underflow";
58  CHECK_LT(pos, band->nMember()) << "dimension overflow";
59  CHECK_NE(extent, 0) << "NYI: mapping to 0";
60 
61  auto domain = activeDomainPoints(root, tree).universe();
62 
63  // Introduce the "mapping" parameter after checking it is not already present
64  // in the schedule space.
65  auto space = band->mupa_.get_space();
66  int idPos = space.find_dim_by_id(isl::dim_type::param, id);
67  if (idPos != -1) {
68  for (auto upa : isl::MUPA(band->mupa_)) {
69  for (auto pa : upa) {
70  CHECK(not pa.pa.involves_dims(isl::dim_type::param, pos, 1));
71  }
72  }
73  }
74 
75  // Create mapping filter by equating the newly introduced
76  // parameter "id" to the "pos"-th schedule dimension modulo its extent.
77  auto upa =
78  band->mupa_.get_union_pw_aff(pos).mod_val(isl::val(tree->ctx_, extent));
79  upa = upa.sub(isl::union_pw_aff::param_on_domain(domain, id));
80  auto filter = upa.zero_union_set();
81  return insertMappingFilterAbove<MappingIdType>(root, tree, filter, {id})
82  ->child({0});
83 }
84 } // namespace polyhedral
85 } // namespace tc
Definition: isl_mu_wrappers.h:194