21 #include <unordered_map>
24 #include <dlpack/dlpack.h>
35 namespace polyhedral {
38 using ReductionDimSet = std::set<std::string>;
39 class TensorReferenceGroup;
49 static std::unique_ptr<Scop> makeScop(
57 static std::unique_ptr<Scop> makeScop(isl::ctx ctx,
const std::string& tc);
59 static std::unique_ptr<Scop> makeScop(
64 static std::unique_ptr<Scop> makeScop(
const Scop& scop) {
65 auto res = std::unique_ptr<Scop>(
new Scop());
66 res->globalParameterContext = scop.globalParameterContext;
67 res->halide = scop.halide;
68 res->reads = scop.reads;
69 res->writes = scop.writes;
70 res->scheduleTreeUPtr =
72 res->treeSyncUpdateMap = scop.treeSyncUpdateMap;
73 res->defaultReductionInitMap = scop.defaultReductionInitMap;
74 res->groupCounts_ = scop.groupCounts_;
75 res->promotedDecls_ = scop.promotedDecls_;
76 res->activePromotions_ = scop.activePromotions_;
81 inline void intersectContext(isl::set extraGlobalParameterContext) {
82 auto context = globalParameterContext & extraGlobalParameterContext;
83 globalParameterContext = context;
92 static std::unique_ptr<Scop> makeSpecializedScop(
94 isl::set extraGlobalParameterContext) {
95 CHECK(extraGlobalParameterContext.is_subset(scop.globalParameterContext))
96 <<
"expected extra context " << extraGlobalParameterContext
97 <<
" to be more specialized than " << scop.globalParameterContext;
98 auto res = makeScop(scop);
99 res->intersectContext(extraGlobalParameterContext);
115 void specializeToContext() {
116 domain() = domain().intersect_params(globalParameterContext);
117 reads = reads.intersect_params(globalParameterContext);
118 writes = writes.intersect_params(globalParameterContext);
125 template <
typename T>
126 isl::set makeContext(
const std::vector<T>& sizes = std::vector<T>())
const {
127 auto s = domain().get_space().params();
133 template <
typename T>
134 isl::set makeContext(
135 const std::unordered_map<int, T>& sizes =
136 std::unordered_map<int, T>())
const {
137 auto s = domain().get_space().params();
143 template <
typename T>
144 isl::set makeContext(
145 const std::unordered_map<std::string, T>& sizes =
146 std::unordered_map<std::string, T>())
const {
147 auto s = domain().get_space().params();
155 isl::set makeContextFromInputs(
156 const std::vector<const DLTensor*>& inputs)
const;
160 template <
typename T>
161 void fixParameters(
const std::unordered_map<std::string, T>& sizes) {
162 intersectContext(makeContext(sizes));
168 std::vector<long> getParameterValues(isl::set context)
const;
170 isl::id nextGroupIdForTensor(isl::id tensorId) {
171 auto ctx = domain().get_ctx();
172 std::stringstream ss;
173 ss <<
"_" << tensorId.get_name() <<
"_" << groupCounts_[tensorId]++;
174 return isl::id(ctx, ss.str());
199 isl::id insertReductionSync1D(
200 detail::ScheduleTree* redPoint,
207 void insertSync(detail::ScheduleTree* seqNode,
size_t pos);
211 void insertSyncAfter(detail::ScheduleTree* tree) {
212 insertExtensionLabelAfter(scheduleRoot(), tree, makeSyncId());
215 size_t reductionUID()
const {
216 static size_t count = 0;
219 size_t syncUID()
const {
220 static size_t count = 0;
224 isl::id makeSyncId()
const {
225 auto ctx = domain().get_ctx();
226 return isl::id(ctx, std::string(kSyncIdPrefix) + std::to_string(syncUID()));
229 static bool isSyncId(isl::id
id) {
230 if (!
id.has_name()) {
233 auto name =
id.get_name();
234 if (name.find(kSyncIdPrefix) != 0) {
237 name = name.substr(std::string(kSyncIdPrefix).size());
239 std::strtol(name.c_str(), &end, 10);
240 if (end - name.c_str() != name.size()) {
246 static isl::id makeRefId(isl::ctx ctx) {
247 static thread_local
size_t count = 0;
248 return isl::id(ctx, std::string(
"__tc_ref_") + std::to_string(count++));
251 std::pair<isl::id, isl::id> makeReductionSpecialIds(isl::id updateId) {
252 auto uid = reductionUID();
253 auto treeSyncId = isl::id(
254 domain().get_ctx(), std::string(
"red_update") + std::to_string(
uid));
255 auto reductionInitId = isl::id(
256 domain().get_ctx(), std::string(
"red_init") + std::to_string(
uid));
257 CHECK_EQ(0, treeSyncUpdateMap.count(treeSyncId));
258 CHECK_EQ(0, defaultReductionInitMap.count(treeSyncId));
260 treeSyncUpdateMap.emplace(treeSyncId, updateId);
261 defaultReductionInitMap.emplace(treeSyncId, reductionInitId);
262 return std::make_pair(treeSyncId, reductionInitId);
265 bool isTreeSyncId(isl::id
id)
const {
266 return treeSyncUpdateMap.count(
id) == 1;
269 bool isDefaultReductionInitId(isl::id
id)
const {
270 for (
const auto& p : defaultReductionInitMap) {
271 if (p.second ==
id) {
278 isl::id getReductionUpdateForDefaultInit(isl::id
id)
const {
279 for (
const auto& p : defaultReductionInitMap) {
280 if (p.second ==
id) {
281 return treeSyncUpdateMap.at(p.first);
284 CHECK(
false) <<
"not found";
288 bool isReductionUpdate(isl::id
id)
const {
289 for (
const auto& kvp : treeSyncUpdateMap) {
290 if (
id == kvp.second) {
297 size_t reductionUpdatePos(isl::id
id)
const {
299 CHECK(isReductionUpdate(
id));
300 for (
const auto& kvp : treeSyncUpdateMap) {
301 if (
id == kvp.second) {
309 void promoteEverythingAt(std::vector<size_t> pos);
311 struct PromotedDecl {
313 std::vector<size_t> sizes;
316 struct PromotionInfo {
317 std::shared_ptr<TensorReferenceGroup> group;
318 isl::union_map outerSchedule;
322 const std::unordered_map<isl::id, PromotedDecl, isl::IslIdIslHash>&
323 promotedDecls()
const {
324 return promotedDecls_;
329 activePromotions()
const {
330 return activePromotions_;
333 detail::ScheduleTree* scheduleRoot() {
334 return scheduleTreeUPtr.get();
337 const detail::ScheduleTree* scheduleRoot()
const {
338 return scheduleTreeUPtr.get();
342 static std::unique_ptr<Scop> makeScheduled(
344 const SchedulerOptionsView& schedulerOptions);
350 detail::ScheduleTree* tileOuterBand(
const TilingView& tiling);
355 detail::ScheduleTree* tree,
356 const SchedulerOptionsView& schedulerOptions);
360 const Halide::OutputImageParam& findArgument(isl::id
id)
const;
374 void promoteGroupToShared(
376 std::unique_ptr<TensorReferenceGroup>&& gr,
377 detail::ScheduleTree* tree,
378 const std::unordered_set<isl::id, isl::IslIdIslHash>& activeStmts,
379 isl::union_map schedule,
380 bool forceLastExtentOdd =
false);
400 void insertSyncsAroundCopies(detail::ScheduleTree* tree);
407 static std::unique_ptr<detail::ScheduleTree> computeSchedule(
408 isl::schedule_constraints constraints,
409 const SchedulerOptionsView& schedulerOptions);
414 std::vector<Halide::Internal::Parameter> params;
415 std::vector<std::string> idx, reductionIdx;
416 std::vector<Halide::ImageParam> inputs;
417 std::vector<Halide::OutputImageParam> outputs;
418 std::vector<halide2isl::Reduction> reductions;
419 std::unordered_map<isl::id, Halide::Internal::Stmt, isl::IslIdIslHash>
421 std::unordered_map<const Halide::Internal::IRNode*, isl::id> accesses;
430 isl::union_set& domain();
431 const isl::union_set domain()
const;
443 isl::set globalParameterContext;
445 isl::union_map reads;
446 isl::union_map writes;
455 std::unique_ptr<detail::ScheduleTree> scheduleTreeUPtr;
459 std::unordered_map<isl::id, isl::id, isl::IslIdIslHash> treeSyncUpdateMap;
460 std::unordered_map<isl::id, isl::id, isl::IslIdIslHash>
461 defaultReductionInitMap;
466 std::unordered_map<isl::id, size_t, isl::IslIdIslHash> groupCounts_;
468 std::unordered_map<isl::id, PromotedDecl, isl::IslIdIslHash> promotedDecls_;
474 std::ostream&
operator<<(std::ostream& os,
const Scop&);
Definition: tc2halide.h:29
isl::set makeSpecializationSet(isl::space space, const std::unordered_map< int, T > &paramValues)
Definition: islpp.h:314
std::ostream & operator<<(std::ostream &out, const MappingOptionsAsCpp &mo)
Definition: mapping_options_cpp_printer.h:79
std::shared_ptr< Tree > TreeRef
Definition: tree.h:44
ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Halide::Internal::Stmt &s)