Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
schedule_tree_elem.h
Go to the documentation of this file.
1 
16 #pragma once
17 
18 #include <memory>
19 #include <unordered_set>
20 #include <vector>
21 
22 #include "tc/external/isl.h"
23 
25 
26 namespace tc {
27 namespace polyhedral {
28 namespace detail {
29 
// Tag identifying the kind of a schedule tree element. Each
// ScheduleTreeElem* struct in this header exposes its own tag as the static
// member NodeType and returns it from type().
enum class ScheduleTreeType {
  None, // NodeType of ScheduleTreeElemBase; also the sole NodeDerivedTypes
        // entry of the elements that list no derived kind.
  Band,
  Context,
  Domain,
  Extension,
  Filter,
  Sequence,
  Set,
  MappingFilter, // Listed in ScheduleTreeElemFilter::NodeDerivedTypes.
  Any, // NOTE(review): not used in this header; presumably a wildcard for
       // matching code elsewhere -- confirm.
};
42 
43 struct ScheduleTree;
44 
45 struct ScheduleTreeElemBase {
46  static constexpr detail::ScheduleTreeType NodeType =
47  detail::ScheduleTreeType::None;
48  static std::unique_ptr<ScheduleTreeElemBase> make(isl::schedule_node node);
49  static std::unique_ptr<ScheduleTreeElemBase> make(const ScheduleTree& st);
50  virtual ~ScheduleTreeElemBase() {}
51  virtual std::ostream& write(std::ostream& os) const = 0;
52  virtual detail::ScheduleTreeType type() const = 0;
53 };
54 
55 struct ScheduleTreeElemContext : public ScheduleTreeElemBase {
56  static constexpr std::initializer_list<detail::ScheduleTreeType>
57  NodeDerivedTypes{detail::ScheduleTreeType::None};
58  static constexpr detail::ScheduleTreeType NodeType =
59  detail::ScheduleTreeType::Context;
60  isl::set context_;
61  ScheduleTreeElemContext() = delete;
62  ScheduleTreeElemContext(const ScheduleTreeElemContext& eb)
63  : context_(eb.context_) {}
64  explicit ScheduleTreeElemContext(isl::set s) : context_(s) {}
65  virtual ~ScheduleTreeElemContext() override {}
66  bool operator==(const ScheduleTreeElemContext& other) const;
67  bool operator!=(const ScheduleTreeElemContext& other) const {
68  return !(*this == other);
69  }
70  virtual std::ostream& write(std::ostream& os) const override;
71  virtual detail::ScheduleTreeType type() const override {
72  return NodeType;
73  }
74 };
75 
76 struct ScheduleTreeElemDomain : public ScheduleTreeElemBase {
77  static constexpr std::initializer_list<detail::ScheduleTreeType>
78  NodeDerivedTypes{detail::ScheduleTreeType::None};
79  static constexpr detail::ScheduleTreeType NodeType =
80  detail::ScheduleTreeType::Domain;
81  isl::union_set domain_;
82  ScheduleTreeElemDomain() = delete;
83  ScheduleTreeElemDomain(const ScheduleTreeElemDomain& eb)
84  : domain_(eb.domain_) {}
85  explicit ScheduleTreeElemDomain(isl::union_set us) : domain_(us) {}
86  virtual ~ScheduleTreeElemDomain() override {}
87  bool operator==(const ScheduleTreeElemDomain& other) const;
88  bool operator!=(const ScheduleTreeElemDomain& other) const {
89  return !(*this == other);
90  }
91  virtual std::ostream& write(std::ostream& os) const override;
92  virtual detail::ScheduleTreeType type() const override {
93  return NodeType;
94  }
95 };
96 
97 struct ScheduleTreeElemExtension : public ScheduleTreeElemBase {
98  static constexpr std::initializer_list<detail::ScheduleTreeType>
99  NodeDerivedTypes{detail::ScheduleTreeType::None};
100  static constexpr detail::ScheduleTreeType NodeType =
101  detail::ScheduleTreeType::Extension;
102  isl::union_map extension_;
103  ScheduleTreeElemExtension() = delete;
104  ScheduleTreeElemExtension(const ScheduleTreeElemExtension& eb)
105  : extension_(eb.extension_) {}
106  explicit ScheduleTreeElemExtension(isl::union_map m) : extension_(m) {}
107  virtual ~ScheduleTreeElemExtension() override {}
108  bool operator==(const ScheduleTreeElemExtension& other) const;
109  bool operator!=(const ScheduleTreeElemExtension& other) const {
110  return !(*this == other);
111  }
112  virtual std::ostream& write(std::ostream& os) const override;
113  virtual detail::ScheduleTreeType type() const override {
114  return NodeType;
115  }
116 };
117 
118 struct ScheduleTreeElemFilter : public ScheduleTreeElemBase {
119  static constexpr std::initializer_list<detail::ScheduleTreeType>
120  NodeDerivedTypes{detail::ScheduleTreeType::MappingFilter};
121  static constexpr detail::ScheduleTreeType NodeType =
122  detail::ScheduleTreeType::Filter;
123  isl::union_set filter_;
124  ScheduleTreeElemFilter() = delete;
125  ScheduleTreeElemFilter(const ScheduleTreeElemFilter& eb)
126  : filter_(eb.filter_) {}
127  explicit ScheduleTreeElemFilter(isl::union_set s) : filter_(s) {}
128  virtual ~ScheduleTreeElemFilter() override {}
129  bool operator==(const ScheduleTreeElemFilter& other) const;
130  bool operator!=(const ScheduleTreeElemFilter& other) const {
131  return !(*this == other);
132  }
133  virtual std::ostream& write(std::ostream& os) const override;
134  virtual detail::ScheduleTreeType type() const override {
135  return NodeType;
136  }
137 };
138 
139 struct ScheduleTreeElemMappingFilter : public ScheduleTreeElemFilter {
140  static constexpr std::initializer_list<detail::ScheduleTreeType>
141  NodeDerivedTypes{detail::ScheduleTreeType::None};
142  static constexpr detail::ScheduleTreeType NodeType =
143  detail::ScheduleTreeType::MappingFilter;
144  ScheduleTreeElemMappingFilter() = delete;
145  ScheduleTreeElemMappingFilter(const ScheduleTreeElemMappingFilter& eb)
146  : ScheduleTreeElemFilter(eb.filter_), mappingIds(eb.mappingIds) {}
147  ScheduleTreeElemMappingFilter(
148  isl::union_set us,
149  const std::unordered_set<
150  mapping::MappingId,
151  typename mapping::MappingId::Hash>& ids)
152  : ScheduleTreeElemFilter(us), mappingIds(ids) {
153  USING_MAPPING_SHORT_NAMES(BX, BY, BZ, TX, TY, TZ);
154  for (auto s : isl::UNION_SET(us)) {
155  for (auto id : std::vector<mapping::MappingId>{BX, BY, BZ, TX, TY, TZ}) {
156  if (mappingIds.count(id) > 0) {
157  CHECK_EQ(1, ids.count(id)) << "id: " << id << " mapped >1 times";
158  CHECK_LE(0, s.find_dim_by_id(isl::dim_type::param, id))
159  << "unexpected missing id: " << id << " in filter: " << s;
160  } else {
161  auto pos = s.find_dim_by_id(isl::dim_type::param, id);
162  bool involved =
163  pos > 0 && s.involves_dims(isl::dim_type::param, pos, 1);
164  if (involved) {
165  std::stringstream ss;
166  for (auto id : ids) {
167  ss << id.to_str() << " ";
168  }
169  // TODO: will need to relax this if we map the same loop
170  // iteratively without stripmining it beforehand
171  CHECK(false) << "unexpected involved id: " << id
172  << " in filter: " << s
173  << " but not present in filter id list: " << ss.str();
174  }
175  }
176  }
177  }
178  }
179  virtual ~ScheduleTreeElemMappingFilter() override {}
180  bool operator==(const ScheduleTreeElemMappingFilter& other) const;
181  bool operator!=(const ScheduleTreeElemMappingFilter& other) const {
182  return !(*this == other);
183  }
184  virtual std::ostream& write(std::ostream& os) const override;
185  virtual detail::ScheduleTreeType type() const override {
186  return NodeType;
187  }
188 
189  const std::
190  unordered_set<mapping::MappingId, typename mapping::MappingId::Hash>
191  mappingIds;
192 };
193 
194 struct ScheduleTreeElemSequence : public ScheduleTreeElemBase {
195  static constexpr std::initializer_list<detail::ScheduleTreeType>
196  NodeDerivedTypes{detail::ScheduleTreeType::None};
197  static constexpr detail::ScheduleTreeType NodeType =
198  detail::ScheduleTreeType::Sequence;
199  explicit ScheduleTreeElemSequence() {}
200  ScheduleTreeElemSequence(const ScheduleTreeElemSequence& eb) {}
201  virtual ~ScheduleTreeElemSequence() override {}
202  bool operator==(const ScheduleTreeElemSequence& other) const;
203  bool operator!=(const ScheduleTreeElemSequence& other) const {
204  return !(*this == other);
205  }
206  virtual std::ostream& write(std::ostream& os) const override;
207  virtual detail::ScheduleTreeType type() const override {
208  return NodeType;
209  }
210 };
211 
212 struct ScheduleTreeElemSet : public ScheduleTreeElemBase {
213  static constexpr std::initializer_list<detail::ScheduleTreeType>
214  NodeDerivedTypes{detail::ScheduleTreeType::None};
215  static constexpr detail::ScheduleTreeType NodeType =
216  detail::ScheduleTreeType::Set;
217  explicit ScheduleTreeElemSet() {}
218  ScheduleTreeElemSet(const ScheduleTreeElemSet& eb) {}
219  virtual ~ScheduleTreeElemSet() override {}
220  bool operator==(const ScheduleTreeElemSet& other) const;
221  bool operator!=(const ScheduleTreeElemSet& other) const {
222  return !(*this == other);
223  }
224  virtual std::ostream& write(std::ostream& os) const override;
225  virtual detail::ScheduleTreeType type() const override {
226  return NodeType;
227  }
228 };
229 
230 struct ScheduleTreeElemBand : public ScheduleTreeElemBase {
231  private:
232  ScheduleTreeElemBand() = default;
233 
234  public:
235  static constexpr std::initializer_list<detail::ScheduleTreeType>
236  NodeDerivedTypes{detail::ScheduleTreeType::None};
237  static constexpr detail::ScheduleTreeType NodeType =
238  detail::ScheduleTreeType::Band;
239 
240  ScheduleTreeElemBand(const ScheduleTreeElemBand& eb)
241  : permutable_(eb.permutable_),
242  mupa_(eb.mupa_),
243  coincident_(eb.coincident_),
244  unroll_(eb.unroll_) {}
245  virtual ~ScheduleTreeElemBand() override {}
246  bool operator==(const ScheduleTreeElemBand& other) const;
247  bool operator!=(const ScheduleTreeElemBand& other) const {
248  return !(*this == other);
249  }
250  virtual std::ostream& write(std::ostream& os) const override;
251  virtual detail::ScheduleTreeType type() const override {
252  return NodeType;
253  }
254 
255  // First replace "mupa" by its greatest integer part to ensure that the
256  // schedule is always integral.
257  // The band is not marked permutable, the dimensions are not marked
258  // coincident and are not marked for unrolling.
259  static std::unique_ptr<ScheduleTreeElemBand> fromMultiUnionPwAff(
260  isl::multi_union_pw_aff mupa);
261 
262  // Return the number of scheduling dimensions in the band
263  size_t nMember() const;
264 
265  // Return the number of outer coincident members in the band.
266  size_t nOuterCoincident() const;
267 
268  // Drop the "n" dimensions starting at "pos" from "band".
269  // We apply the transformation even if "n" is zero to ensure consistent
270  // behavior with respect to changes in the schedule space.
271  // The caller is responsible for updating the isolate option (Note: why?)
272  void drop(int pos, int n);
273 
274  public:
275  bool permutable_{false};
276  isl::multi_union_pw_aff mupa_;
277 
278  std::vector<bool> coincident_;
279  // For each member, should the corresponding loop in the generated code
280  // be (fully) unrolled?
281  std::vector<bool> unroll_;
282 };
283 
284 bool elemEquals(
285  const ScheduleTreeElemBase* e1,
286  const ScheduleTreeElemBase* e2,
287  detail::ScheduleTreeType type);
288 
289 std::ostream& operator<<(std::ostream& os, isl::ast_loop_type lt);
290 std::ostream& operator<<(std::ostream& os, detail::ScheduleTreeType nt);
291 std::ostream& operator<<(
292  std::ostream& os,
293  const std::vector<std::unique_ptr<ScheduleTree>>& vst);
294 std::ostream& operator<<(std::ostream& os, const ScheduleTreeElemBase& eb);
295 
296 } // namespace detail
297 } // namespace polyhedral
298 } // namespace tc
std::ostream & operator<<(std::ostream &out, const MappingOptionsAsCpp &mo)
Definition: mapping_options_cpp_printer.h:79
bool operator==(const std::vector< const DLTensor * > &inputsTensor, const std::vector< detail::TensorInfo > &inputsInfo)
Definition: isl_mu_wrappers.h:208
bool operator!=(isl::val v, long i)
Definition: islpp.h:103
#define USING_MAPPING_SHORT_NAMES(BX, BY, BZ, TX, TY, TZ)
Definition: mapping_types.h:118