19 #include <unordered_set>
27 namespace polyhedral {
// Scoped enumeration used as the type tag of schedule tree elements.
// Enumerators referenced elsewhere in this listing: None, Context, Domain,
// Extension, Filter, MappingFilter, Sequence, Set and Band.
30 enum class ScheduleTreeType {
// Abstract base of all schedule tree elements. It declares the two pure
// virtual operations every concrete element provides: write() and type().
45 struct ScheduleTreeElemBase {
// Static type tag of the base itself: None.
46 static constexpr detail::ScheduleTreeType NodeType =
47 detail::ScheduleTreeType::None;
// Factory functions returning an owning pointer to an element, built either
// from an isl schedule node or from a ScheduleTree. Only declared here.
// NOTE(review): presumably these create the concrete element matching the
// kind of the input node -- confirm against the definitions.
48 static std::unique_ptr<ScheduleTreeElemBase> make(isl::schedule_node node);
49 static std::unique_ptr<ScheduleTreeElemBase> make(
const ScheduleTree& st);
// Virtual destructor so that elements can be destroyed through a base pointer.
50 virtual ~ScheduleTreeElemBase() {}
// Writes the element to `os` and returns a stream reference.
51 virtual std::ostream& write(std::ostream& os)
const = 0;
// Returns the type tag of the concrete element.
52 virtual detail::ScheduleTreeType type()
const = 0;
// Schedule tree element tagged Context. Its payload `context_` is initialised
// from an isl::set.
55 struct ScheduleTreeElemContext :
public ScheduleTreeElemBase {
// Type tags listed as "derived" from this element; only None appears here.
// NOTE(review): presumably the tags of subclasses of this element -- confirm.
56 static constexpr std::initializer_list<detail::ScheduleTreeType>
57 NodeDerivedTypes{detail::ScheduleTreeType::None};
// Static type tag of this element kind.
58 static constexpr detail::ScheduleTreeType NodeType =
59 detail::ScheduleTreeType::Context;
// Default construction is disabled; use the set or copy constructor.
61 ScheduleTreeElemContext() =
delete;
// Copy constructor: initialises context_ from the other element's context_.
62 ScheduleTreeElemContext(
const ScheduleTreeElemContext& eb)
63 : context_(eb.context_) {}
// Builds the element from the isl set `s`, which becomes context_.
64 explicit ScheduleTreeElemContext(isl::set s) : context_(s) {}
65 virtual ~ScheduleTreeElemContext()
override {}
// Equality comparison; only declared here.
66 bool operator==(
const ScheduleTreeElemContext& other)
const;
// Inequality is defined as the negation of operator==.
67 bool operator!=(
const ScheduleTreeElemContext& other)
const {
68 return !(*
this == other);
// Writes this element to `os`; only declared here.
70 virtual std::ostream& write(std::ostream& os)
const override;
// Overrides ScheduleTreeElemBase::type(); the body is not part of this
// listing. NOTE(review): presumably returns NodeType -- confirm.
71 virtual detail::ScheduleTreeType type()
const override {
// Schedule tree element tagged Domain. Its payload is the isl::union_set
// `domain_`.
76 struct ScheduleTreeElemDomain :
public ScheduleTreeElemBase {
// Type tags listed as "derived" from this element; only None appears here.
// NOTE(review): presumably the tags of subclasses of this element -- confirm.
77 static constexpr std::initializer_list<detail::ScheduleTreeType>
78 NodeDerivedTypes{detail::ScheduleTreeType::None};
// Static type tag of this element kind.
79 static constexpr detail::ScheduleTreeType NodeType =
80 detail::ScheduleTreeType::Domain;
// Payload of the element.
81 isl::union_set domain_;
// Default construction is disabled; use the union_set or copy constructor.
82 ScheduleTreeElemDomain() =
delete;
// Copy constructor: initialises domain_ from the other element's domain_.
83 ScheduleTreeElemDomain(
const ScheduleTreeElemDomain& eb)
84 : domain_(eb.domain_) {}
// Builds the element from the union set `us`, which becomes domain_.
85 explicit ScheduleTreeElemDomain(isl::union_set us) : domain_(us) {}
86 virtual ~ScheduleTreeElemDomain()
override {}
// Equality comparison; only declared here.
87 bool operator==(
const ScheduleTreeElemDomain& other)
const;
// Inequality is defined as the negation of operator==.
88 bool operator!=(
const ScheduleTreeElemDomain& other)
const {
89 return !(*
this == other);
// Writes this element to `os`; only declared here.
91 virtual std::ostream& write(std::ostream& os)
const override;
// Overrides ScheduleTreeElemBase::type(); the body is not part of this
// listing. NOTE(review): presumably returns NodeType -- confirm.
92 virtual detail::ScheduleTreeType type()
const override {
// Schedule tree element tagged Extension. Its payload is the isl::union_map
// `extension_`.
97 struct ScheduleTreeElemExtension :
public ScheduleTreeElemBase {
// Type tags listed as "derived" from this element; only None appears here.
// NOTE(review): presumably the tags of subclasses of this element -- confirm.
98 static constexpr std::initializer_list<detail::ScheduleTreeType>
99 NodeDerivedTypes{detail::ScheduleTreeType::None};
// Static type tag of this element kind.
100 static constexpr detail::ScheduleTreeType NodeType =
101 detail::ScheduleTreeType::Extension;
// Payload of the element.
102 isl::union_map extension_;
// Default construction is disabled; use the union_map or copy constructor.
103 ScheduleTreeElemExtension() =
delete;
// Copy constructor: initialises extension_ from the other element's
// extension_.
104 ScheduleTreeElemExtension(
const ScheduleTreeElemExtension& eb)
105 : extension_(eb.extension_) {}
// Builds the element from the union map `m`, which becomes extension_.
106 explicit ScheduleTreeElemExtension(isl::union_map m) : extension_(m) {}
107 virtual ~ScheduleTreeElemExtension()
override {}
// Equality comparison; only declared here.
108 bool operator==(
const ScheduleTreeElemExtension& other)
const;
// Inequality is defined as the negation of operator==.
109 bool operator!=(
const ScheduleTreeElemExtension& other)
const {
110 return !(*
this == other);
// Writes this element to `os`; only declared here.
112 virtual std::ostream& write(std::ostream& os)
const override;
// Overrides ScheduleTreeElemBase::type(); the body is not part of this
// listing. NOTE(review): presumably returns NodeType -- confirm.
113 virtual detail::ScheduleTreeType type()
const override {
// Schedule tree element tagged Filter. Its payload is the isl::union_set
// `filter_`. ScheduleTreeElemMappingFilter derives from this struct.
118 struct ScheduleTreeElemFilter :
public ScheduleTreeElemBase {
// Type tags listed as "derived" from this element: MappingFilter, which is
// the NodeType of the ScheduleTreeElemMappingFilter subclass.
119 static constexpr std::initializer_list<detail::ScheduleTreeType>
120 NodeDerivedTypes{detail::ScheduleTreeType::MappingFilter};
// Static type tag of this element kind.
121 static constexpr detail::ScheduleTreeType NodeType =
122 detail::ScheduleTreeType::Filter;
// Payload of the element.
123 isl::union_set filter_;
// Default construction is disabled; use the union_set or copy constructor.
124 ScheduleTreeElemFilter() =
delete;
// Copy constructor: initialises filter_ from the other element's filter_.
125 ScheduleTreeElemFilter(
const ScheduleTreeElemFilter& eb)
126 : filter_(eb.filter_) {}
// Builds the element from the union set `s`, which becomes filter_.
127 explicit ScheduleTreeElemFilter(isl::union_set s) : filter_(s) {}
128 virtual ~ScheduleTreeElemFilter()
override {}
// Equality comparison; only declared here.
129 bool operator==(
const ScheduleTreeElemFilter& other)
const;
// Inequality is defined as the negation of operator==.
130 bool operator!=(
const ScheduleTreeElemFilter& other)
const {
131 return !(*
this == other);
// Writes this element to `os`; only declared here.
133 virtual std::ostream& write(std::ostream& os)
const override;
// Overrides ScheduleTreeElemBase::type(); the body is not part of this
// listing. NOTE(review): presumably returns NodeType -- confirm.
134 virtual detail::ScheduleTreeType type()
const override {
// Filter element tagged MappingFilter. On top of the inherited filter_ it
// keeps `mappingIds`, a set of mapping::MappingId values.
139 struct ScheduleTreeElemMappingFilter :
public ScheduleTreeElemFilter {
// Type tags listed as "derived" from this element; only None appears here.
// NOTE(review): presumably the tags of subclasses of this element -- confirm.
140 static constexpr std::initializer_list<detail::ScheduleTreeType>
141 NodeDerivedTypes{detail::ScheduleTreeType::None};
// Static type tag of this element kind.
142 static constexpr detail::ScheduleTreeType NodeType =
143 detail::ScheduleTreeType::MappingFilter;
// Default construction is disabled.
144 ScheduleTreeElemMappingFilter() =
delete;
// Copy constructor: re-initialises the Filter base from the other element's
// filter_ and copies its mappingIds.
145 ScheduleTreeElemMappingFilter(
const ScheduleTreeElemMappingFilter& eb)
146 : ScheduleTreeElemFilter(eb.filter_), mappingIds(eb.mappingIds) {}
// Validating constructor. It initialises the Filter base with `us` and
// mappingIds with `ids`, then checks the filter against the six ids
// BX, BY, BZ, TX, TY, TZ. Several lines of this constructor (part of the
// parameter list, the definition of `s`, and some control-flow lines) are
// not part of this listing.
// NOTE(review): `s` is presumably a set derived from `us` -- confirm.
147 ScheduleTreeElemMappingFilter(
149 const std::unordered_set<
151 typename mapping::MappingId::Hash>& ids)
152 : ScheduleTreeElemFilter(us), mappingIds(ids) {
155 for (
auto id : std::vector<mapping::MappingId>{BX, BY, BZ, TX, TY, TZ}) {
// An id recorded in mappingIds must occur exactly once in `ids` and must be
// found among the parameter dimensions of `s` (position >= 0).
156 if (mappingIds.count(
id) > 0) {
157 CHECK_EQ(1, ids.count(
id)) <<
"id: " <<
id <<
" mapped >1 times";
158 CHECK_LE(0, s.find_dim_by_id(isl::dim_type::param,
id))
159 <<
"unexpected missing id: " <<
id <<
" in filter: " << s;
// NOTE(review): the lines between here and the CHECK(false) below are
// incomplete in this listing; presumably this is the branch for ids that are
// NOT in mappingIds, and the expression on the `pos > 0` line is the
// condition guarding the failure -- confirm.
// NOTE(review): the check above accepts position 0 as "found"
// (CHECK_LE(0, ...)), whereas this test requires `pos > 0`, so a parameter at
// position 0 is never reported as involved. Confirm whether `pos >= 0` was
// intended.
161 auto pos = s.find_dim_by_id(isl::dim_type::param,
id);
163 pos > 0 && s.involves_dims(isl::dim_type::param, pos, 1);
// Collects the textual form of every id in `ids` for the error message.
// The loop variable shadows the outer `id`.
165 std::stringstream ss;
166 for (
auto id : ids) {
167 ss <<
id.to_str() <<
" ";
// Unconditional failure reporting the id, the filter and the id list.
171 CHECK(
false) <<
"unexpected involved id: " <<
id
172 <<
" in filter: " << s
173 <<
" but not present in filter id list: " << ss.str();
179 virtual ~ScheduleTreeElemMappingFilter()
override {}
// Equality comparison; only declared here.
180 bool operator==(
const ScheduleTreeElemMappingFilter& other)
const;
// Inequality is defined as the negation of operator==.
181 bool operator!=(
const ScheduleTreeElemMappingFilter& other)
const {
182 return !(*
this == other);
// Writes this element to `os`; only declared here.
184 virtual std::ostream& write(std::ostream& os)
const override;
// Overrides the inherited type(); the body is not part of this listing.
// NOTE(review): presumably returns NodeType -- confirm.
185 virtual detail::ScheduleTreeType type()
const override {
// Type of a data member declaration whose remaining lines are not part of
// this listing. NOTE(review): presumably the `mappingIds` member initialised
// by the constructors above -- confirm.
190 unordered_set<mapping::MappingId, typename mapping::MappingId::Hash>
// Schedule tree element tagged Sequence. No data members are visible in this
// listing, and both constructors have empty bodies.
194 struct ScheduleTreeElemSequence :
public ScheduleTreeElemBase {
// Type tags listed as "derived" from this element; only None appears here.
// NOTE(review): presumably the tags of subclasses of this element -- confirm.
195 static constexpr std::initializer_list<detail::ScheduleTreeType>
196 NodeDerivedTypes{detail::ScheduleTreeType::None};
// Static type tag of this element kind.
197 static constexpr detail::ScheduleTreeType NodeType =
198 detail::ScheduleTreeType::Sequence;
// Default constructor; does nothing.
199 explicit ScheduleTreeElemSequence() {}
// Copy constructor; `eb` is unused and nothing is copied.
200 ScheduleTreeElemSequence(
const ScheduleTreeElemSequence& eb) {}
201 virtual ~ScheduleTreeElemSequence()
override {}
// Equality comparison; only declared here.
202 bool operator==(
const ScheduleTreeElemSequence& other)
const;
// Inequality is defined as the negation of operator==.
203 bool operator!=(
const ScheduleTreeElemSequence& other)
const {
204 return !(*
this == other);
// Writes this element to `os`; only declared here.
206 virtual std::ostream& write(std::ostream& os)
const override;
// Overrides ScheduleTreeElemBase::type(); the body is not part of this
// listing. NOTE(review): presumably returns NodeType -- confirm.
207 virtual detail::ScheduleTreeType type()
const override {
// Schedule tree element tagged Set. No data members are visible in this
// listing, and both constructors have empty bodies.
212 struct ScheduleTreeElemSet :
public ScheduleTreeElemBase {
// Type tags listed as "derived" from this element; only None appears here.
// NOTE(review): presumably the tags of subclasses of this element -- confirm.
213 static constexpr std::initializer_list<detail::ScheduleTreeType>
214 NodeDerivedTypes{detail::ScheduleTreeType::None};
// Static type tag of this element kind.
215 static constexpr detail::ScheduleTreeType NodeType =
216 detail::ScheduleTreeType::Set;
// Default constructor; does nothing.
217 explicit ScheduleTreeElemSet() {}
// Copy constructor; `eb` is unused and nothing is copied.
218 ScheduleTreeElemSet(
const ScheduleTreeElemSet& eb) {}
219 virtual ~ScheduleTreeElemSet()
override {}
// Equality comparison; only declared here.
220 bool operator==(
const ScheduleTreeElemSet& other)
const;
// Inequality is defined as the negation of operator==.
221 bool operator!=(
const ScheduleTreeElemSet& other)
const {
222 return !(*
this == other);
// Writes this element to `os`; only declared here.
224 virtual std::ostream& write(std::ostream& os)
const override;
// Overrides ScheduleTreeElemBase::type(); the body is not part of this
// listing. NOTE(review): presumably returns NodeType -- confirm.
225 virtual detail::ScheduleTreeType type()
const override {
// Schedule tree element tagged Band. Its visible state is a permutable flag,
// an isl::multi_union_pw_aff and two per-entry boolean vectors
// (coincident_, unroll_).
230 struct ScheduleTreeElemBand :
public ScheduleTreeElemBase {
// Default constructor; members keep their in-class or default initial values.
232 ScheduleTreeElemBand() =
default;
// Type tags listed as "derived" from this element; only None appears here.
// NOTE(review): presumably the tags of subclasses of this element -- confirm.
235 static constexpr std::initializer_list<detail::ScheduleTreeType>
236 NodeDerivedTypes{detail::ScheduleTreeType::None};
// Static type tag of this element kind.
237 static constexpr detail::ScheduleTreeType NodeType =
238 detail::ScheduleTreeType::Band;
// Copy constructor. The initialisers visible here copy permutable_,
// coincident_ and unroll_.
// NOTE(review): one line of the initialiser list is not part of this
// listing; confirm that mupa_ is copied as well.
240 ScheduleTreeElemBand(
const ScheduleTreeElemBand& eb)
241 : permutable_(eb.permutable_),
243 coincident_(eb.coincident_),
244 unroll_(eb.unroll_) {}
245 virtual ~ScheduleTreeElemBand()
override {}
// Equality comparison; only declared here.
246 bool operator==(
const ScheduleTreeElemBand& other)
const;
// Inequality is defined as the negation of operator==.
247 bool operator!=(
const ScheduleTreeElemBand& other)
const {
248 return !(*
this == other);
// Writes this element to `os`; only declared here.
250 virtual std::ostream& write(std::ostream& os)
const override;
// Overrides ScheduleTreeElemBase::type(); the body is not part of this
// listing. NOTE(review): presumably returns NodeType -- confirm.
251 virtual detail::ScheduleTreeType type()
const override {
// Factory returning an owning pointer to a band built from `mupa`; only
// declared here.
259 static std::unique_ptr<ScheduleTreeElemBand> fromMultiUnionPwAff(
260 isl::multi_union_pw_aff mupa);
// NOTE(review): presumably the number of members (schedule dimensions) of
// the band -- confirm against the definition.
263 size_t nMember()
const;
// NOTE(review): presumably the number of leading members flagged coincident
// -- confirm against the definition.
266 size_t nOuterCoincident()
const;
// NOTE(review): presumably removes `n` members starting at position `pos`
// -- confirm against the definition.
272 void drop(
int pos,
int n);
// Permutable flag; false unless set otherwise.
275 bool permutable_{
false};
// Multi-dimensional affine function held by the band.
276 isl::multi_union_pw_aff mupa_;
// NOTE(review): presumably one coincidence flag per band member -- confirm.
278 std::vector<bool> coincident_;
// NOTE(review): presumably one unroll flag per band member -- confirm.
281 std::vector<bool> unroll_;
// Parameter list of a declaration whose name and return type are not part of
// this listing: it takes two element pointers and a type tag.
// NOTE(review): presumably an element equality helper -- confirm.
285 const ScheduleTreeElemBase* e1,
286 const ScheduleTreeElemBase* e2,
287 detail::ScheduleTreeType type);
// Stream insertion for an isl AST loop type.
289 std::ostream&
operator<<(std::ostream& os, isl::ast_loop_type lt);
// Stream insertion for a schedule tree type tag.
290 std::ostream&
operator<<(std::ostream& os, detail::ScheduleTreeType nt);
// Last parameter of a declaration whose beginning is not part of this
// listing: a vector of owned ScheduleTree pointers.
293 const std::vector<std::unique_ptr<ScheduleTree>>& vst);
// Stream insertion for any schedule tree element, taken by base reference.
294 std::ostream&
operator<<(std::ostream& os,
const ScheduleTreeElemBase& eb);
std::ostream & operator<<(std::ostream &out, const MappingOptionsAsCpp &mo)
Definition: mapping_options_cpp_printer.h:79
bool operator==(const std::vector< const DLTensor * > &inputsTensor, const std::vector< detail::TensorInfo > &inputsInfo)
Definition: isl_mu_wrappers.h:208
bool operator!=(isl::val v, long i)
Definition: islpp.h:103
#define USING_MAPPING_SHORT_NAMES(BX, BY, BZ, TX, TY, TZ)
Definition: mapping_types.h:118