Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
schedule_tree_matcher-inl.h
Go to the documentation of this file.
1 
16 #pragma once
17 
20 
21 namespace tc {
22 namespace polyhedral {
23 
24 struct ScheduleTreeMatcher;
25 
26 bool matchOne(ScheduleTreeMatcher matcher, const detail::ScheduleTree* tree);
27 
28 // by-value everywhere for now, probably want something more memory-efficient
29 struct ScheduleTreeMatcher {
30  friend bool matchOne(ScheduleTreeMatcher, const detail::ScheduleTree*);
31 
32  template <typename... Args>
33  ScheduleTreeMatcher(detail::ScheduleTreeType type, Args... args)
34  : type_(type),
35  children_({args...}),
36  propertyMatcher_([](const detail::ScheduleTree*) { return true; }),
37  wildcard(false) {}
38 
39  detail::ScheduleTreeType type_;
40  std::vector<ScheduleTreeMatcher> children_;
41  std::function<bool(const detail::ScheduleTree*)> propertyMatcher_;
42  bool wildcard;
43 };
44 
45 // TODO: all this should probably go in a namespace
46 //
47 // This may also have a function<bool()> callback for consistency and finer
48 // grain control
49 template <typename... Args>
50 inline ScheduleTreeMatcher sequence(Args... children) {
51  return ScheduleTreeMatcher(detail::ScheduleTreeType::Sequence, children...);
52 }
53 
54 template <typename... Args>
55 inline ScheduleTreeMatcher domain(Args... children) {
56  return ScheduleTreeMatcher(detail::ScheduleTreeType::Domain, children...);
57 }
58 
59 template <typename... Args>
60 inline ScheduleTreeMatcher context(Args... children) {
61  return ScheduleTreeMatcher(detail::ScheduleTreeType::Context, children...);
62 }
63 
64 template <typename... Args>
65 inline ScheduleTreeMatcher filter(
66  std::function<bool(isl::union_set)> propertyMatcher,
67  Args... children) {
68  ScheduleTreeMatcher m(detail::ScheduleTreeType::Filter, children...);
69  m.propertyMatcher_ = [propertyMatcher](const detail::ScheduleTree* tree) {
70  return propertyMatcher(
71  tree->elemAs<detail::ScheduleTreeElemFilter>()->filter_);
72  };
73  return m;
74 }
75 
76 template <typename... Args>
77 inline ScheduleTreeMatcher filter(
78  std::function<bool(const detail::ScheduleTree* tree)> propertyMatcher,
79  Args... children) {
80  ScheduleTreeMatcher m(detail::ScheduleTreeType::Filter, children...);
81  m.propertyMatcher_ = propertyMatcher;
82  return m;
83 }
84 
85 // the enable_if horror is necessary to have proper overload resolution in cases
86 // filter(), filter([](...){}) and filter(filter())
87 template <
88  typename First,
89  typename... Args,
90  typename = typename std::enable_if<
91  std::is_same<First, ScheduleTreeMatcher>::value>::type>
92 inline ScheduleTreeMatcher filter(First first, Args... children) {
93  return ScheduleTreeMatcher(
94  detail::ScheduleTreeType::Filter, first, children...);
95 }
96 
97 inline ScheduleTreeMatcher filter() {
98  return ScheduleTreeMatcher(detail::ScheduleTreeType::Filter);
99 }
100 
101 // We could have mapping_filter restrict the property matcher but the
102 // extra-level of engineering sounds like a bad tradeoff, for now..
103 template <typename... Args>
104 inline ScheduleTreeMatcher mapping_filter(
105  std::function<bool(isl::union_set)> propertyMatcher,
106  Args... children) {
107  ScheduleTreeMatcher m(detail::ScheduleTreeType::MappingFilter, children...);
108  m.propertyMatcher_ = [propertyMatcher](const detail::ScheduleTree* tree) {
109  return propertyMatcher(
110  tree->elemAs<detail::ScheduleTreeElemMappingFilter>()->filter_);
111  };
112  return m;
113 }
114 
115 template <typename... Args>
116 inline ScheduleTreeMatcher mapping_filter(
117  std::function<bool(const detail::ScheduleTree* tree)> propertyMatcher,
118  Args... children) {
119  ScheduleTreeMatcher m(detail::ScheduleTreeType::MappingFilter, children...);
120  m.propertyMatcher_ = propertyMatcher;
121  return m;
122 }
123 
124 // the enable_if horror is necessary to have proper overload resolution in cases
125 // mapping_filter(), mapping_filter([](...){}) and
126 // mapping_filter(mapping_filter())
127 template <
128  typename First,
129  typename... Args,
130  typename = typename std::enable_if<
131  std::is_same<First, ScheduleTreeMatcher>::value>::type>
132 inline ScheduleTreeMatcher mapping_filter(First first, Args... children) {
133  return ScheduleTreeMatcher(
134  detail::ScheduleTreeType::MappingFilter, first, children...);
135 }
136 
137 inline ScheduleTreeMatcher mapping_filter() {
138  return ScheduleTreeMatcher(detail::ScheduleTreeType::MappingFilter);
139 }
140 
141 template <typename... Args>
142 inline ScheduleTreeMatcher band(
143  std::function<bool(
144  isl::multi_union_pw_aff mupa,
145  bool permutable,
146  std::vector<bool> coincident,
147  std::vector<bool> unroll)> propertyMatcher,
148  Args... children) {
149  ScheduleTreeMatcher m(detail::ScheduleTreeType::Band, children...);
150  m.propertyMatcher_ = [propertyMatcher](const detail::ScheduleTree* tree) {
151  auto band = tree->elemAs<detail::ScheduleTreeElemBand>();
152  return propertyMatcher(
153  band->mupa_, band->permutable_, band->coincident_, band->unroll_);
154  };
155  return m;
156 }
157 
158 template <
159  typename First,
160  typename... Args,
161  typename = typename std::enable_if<
162  std::is_same<First, ScheduleTreeMatcher>::value>::type>
163 inline ScheduleTreeMatcher band(First first, Args... children) {
164  return ScheduleTreeMatcher(
165  detail::ScheduleTreeType::Band, first, children...);
166 }
167 
168 inline ScheduleTreeMatcher band() {
169  return ScheduleTreeMatcher(detail::ScheduleTreeType::Band);
170 }
171 
172 template <typename... Args>
173 inline ScheduleTreeMatcher extension(
174  std::function<bool(isl::union_map)> propertyMatcher,
175  Args... children) {
176  ScheduleTreeMatcher m(detail::ScheduleTreeType::Extension, children...);
177  m.propertyMatcher_ = [propertyMatcher](const detail::ScheduleTree* tree) {
178  return propertyMatcher(
179  tree->elemAs<detail::ScheduleTreeElemExtension>()->extension_);
180  };
181  return m;
182 }
183 
184 template <
185  typename First,
186  typename... Args,
187  typename = typename std::enable_if<
188  std::is_same<First, ScheduleTreeMatcher>::value>::type>
189 inline ScheduleTreeMatcher extension(First first, Args... children) {
190  return ScheduleTreeMatcher(
191  detail::ScheduleTreeType::Extension, first, children...);
192 }
193 
194 inline ScheduleTreeMatcher extension() {
195  return ScheduleTreeMatcher(detail::ScheduleTreeType::Extension);
196 }
197 
198 // Wildcard ScheduleTreeMatcher can match any 1 or more nodes.
199 // Examples:
200 //
201 // * filter(
202 // any()) matches a filter node with any non-empty subtree
203 //
204 // * filter() matches a leaf filter
205 //
206 // * filter(filter(), any()) matches any subtree whose root is a filter and
207 // whose first child is a filter
208 inline ScheduleTreeMatcher any() {
209  ScheduleTreeMatcher m(detail::ScheduleTreeType::Any);
210  m.wildcard = true;
211  return m;
212 }
213 
214 inline bool matchOne(
215  ScheduleTreeMatcher matcher,
216  const detail::ScheduleTree* tree) {
217  if (!tree) {
218  return false;
219  }
220  if (matcher.wildcard) {
221  return true;
222  }
223  if (matcher.type_ != tree->type_) {
224  return false;
225  }
226  if (!matcher.propertyMatcher_(tree)) {
227  return false;
228  }
229  // Special casing children cases to avoid accessing invalid memory
230  // a. 0 children in either => the number of children need to match
231  if (matcher.children_.size() == 0 || tree->numChildren() == 0) {
232  if (matcher.children_.size() != tree->numChildren()) {
233  return false;
234  }
235  return true;
236  }
237  // b. matcher.children do not end in wildcard then all children must match
238  if (!matcher.children_.back().wildcard &&
239  matcher.children_.size() != tree->numChildren()) {
240  return false;
241  }
242  // c. whatever the case matcher cannot match if is has more children
243  if (matcher.children_.size() > tree->numChildren()) {
244  return false;
245  }
246  // No need to do a BFS here because we recurse anyway.
247  // Only match up to the number of children of the matcher because:
248  // 1. if matcher ends with "any", the remaining children are considered
249  // matched
250  // 2. otherwise we must have the same number of children or we would have
251  // exited just above.
252  // We still need to check well-formedness of the matcher (i.e. no wildcards
253  // except in the last position)
254  for (size_t i = 0; i < matcher.children_.size(); ++i) {
255  CHECK(!matcher.children_[i].wildcard || i == matcher.children_.size() - 1)
256  << "Error in matcher structure, wildcard must be the last child!";
257  if (!matchOne(matcher.children_[i], tree->child({i}))) {
258  return false;
259  }
260  }
261 
262  return true;
263 }
264 
265 // TODO: we may need non-const versions of these to allow for modification
266 // after matching, the property matchers should still take const though.
267 //
268 // FIXME: we are "using namespace detail", specification below is redundant
269 inline std::vector<const detail::ScheduleTree*> matchDFSPreorder(
270  ScheduleTreeMatcher matcher,
271  const detail::ScheduleTree* tree) {
272  std::vector<const detail::ScheduleTree*> res;
273  for (auto t : detail::ScheduleTree::collectDFSPreorder(tree)) {
274  if (matchOne(matcher, t)) {
275  res.push_back(t);
276  }
277  }
278  return res;
279 }
280 
281 // Look for matches in arbitrary order.
282 inline std::vector<const detail::ScheduleTree*> match(
283  ScheduleTreeMatcher matcher,
284  const detail::ScheduleTree* tree) {
285  return matchDFSPreorder(matcher, tree);
286 }
287 
288 } // namespace polyhedral
289 } // namespace tc