Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
memory_promotion.h
Go to the documentation of this file.
1 
16 #pragma once
17 
#include <iostream>

#include "tc/external/isl.h"

#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
27 
28 namespace tc {
29 namespace polyhedral {
30 
// Direction of a tensor access: reading tensor elements or writing them.
enum class AccessType : short {
  Read,
  Write,
};
32 
33 // A single dimension of the ScopedFootprint.
34 // The scope is defined by a specific position in a schedule tree (const
35 // ScheduleTree*), the user is responsible for maintaining the correspondance
36 // between schedule tree positions and footprints.
37 // Overapproximates one dimension by its lower bound, affine function of
38 // parameters and schedule dimensions visible around the scope, and by a
39 // constant size.
40 struct ScopedFootprintDim {
41  public:
42  ScopedFootprintDim(isl::aff lb, isl::val s) : lowerBound(lb), size(s) {}
43 
44  public:
45  isl::aff lowerBound;
46  isl::val size;
47 };
48 
49 // Rectangular overapproximation of a tensor elements accessed through a single
50 // reference. Each dimension is described independently.
51 // The scope is defined by a specific position in a schedule tree (const
52 // ScheduleTree*), the user is responsible for maintaining the correspondance
53 // between schedule tree positions and footprints.
54 struct ScopedFootprint : std::vector<ScopedFootprintDim> {
55  isl::set footprint(isl::set domain) const;
56  isl::multi_aff lowerBounds() const;
57 };
58 
59 // Descriptor of tensor reference in a Scop.
60 // May be scoped to a specific position in a schedule tree, the user is
61 // responsible for maintaining the correspondance between schedule tree
62 // positions and scoped access relations.
63 class TensorReference {
64  public:
65  bool isRead() const {
66  return type == AccessType::Read;
67  }
68 
69  bool isWrite() const {
70  return type == AccessType::Write;
71  }
72 
73  public:
74  // Original access relation in terms of the Scop domain.
75  isl::map originalAccess;
76 
77  // Access relation in terms of partial schedule at the point where the
78  // reference group is introduced in the tree.
79  isl::map scopedAccess;
80 
81  // Access direction (read or write).
82  AccessType type;
83 
84  // Unique identifier of a reference in the Scop.
85  isl::id refId;
86 };
87 
88 class TensorReferenceGroup;
89 using TensorGroupsInfo = std::vector<std::unique_ptr<TensorReferenceGroup>>;
90 typedef std::unordered_map<isl::id, TensorGroupsInfo, isl::IslIdIslHash>
91  TensorGroups;
92 
93 // A group of tensor references that must be handled together during memory
94 // promotion. In particular, references that access the same tensor element,
95 // and at least one of them modifies it, should be placed in the shared/private
96 // memory together to avoid inconsistent values.
97 //
98 // Scoped to a specific position in a schedule tree, the user is responsible
99 // for maintaing the correspondance between schedule tree positions and scoped
100 // access relations of each reference as well as scoped footprints.
101 class TensorReferenceGroup {
102  private:
103  TensorReferenceGroup() {}
104 
105  public:
106  static TensorGroups accessedBySubtree(
107  const detail::ScheduleTree* tree,
108  const Scop& scop);
109 
110  bool isReadOnly() const;
111 
112  // Sets of tensor elements accessed below the scoping point.
113  isl::set writeFootprint() const;
114  isl::set readFootprint() const;
115  isl::set footprint() const {
116  return writeFootprint().unite(readFootprint());
117  }
118 
119  // Access relations in terms of partial schedule of the scoping point.
120  isl::map scopedWrites() const;
121  isl::map scopedReads() const;
122  isl::map scopedAccesses() const {
123  return scopedWrites().unite(scopedReads());
124  }
125 
126  // Access relations in terms of Scop domain elements.
127  // The resulting union relations have different domain spaces but identical
128  // range spaces.
129  isl::union_map originalWrites() const;
130  isl::union_map originalReads() const;
131  isl::union_map originalAccesses() const {
132  return originalWrites().unite(originalReads());
133  }
134 
135  // Rectangular overapproximation of the set of tensor elements accessed below
136  // the scoping point.
137  isl::set approximateFootprint() const {
138  return approximation.footprint(scopedAccesses().domain());
139  }
140 
141  isl::multi_aff promotion() const;
142  isl::set promotedFootprint() const;
143 
144  std::vector<size_t> approximationSizes() const;
145 
146  std::unordered_set<isl::id, isl::IslIdIslHash> referenceIds() const;
147 
148  static std::unique_ptr<TensorReferenceGroup> makeJoint(
149  std::unique_ptr<TensorReferenceGroup>&& g1,
150  std::unique_ptr<TensorReferenceGroup>&& g2);
151  static std::unique_ptr<TensorReferenceGroup> makeSingleton(
152  isl::map originalAccess,
153  isl::map scopedAccess,
154  AccessType type);
155 
156  public:
157  std::vector<std::unique_ptr<TensorReference>> references;
158  ScopedFootprint approximation;
159 };
160 
161 inline std::ostream& operator<<(std::ostream& os, const ScopedFootprint& fp) {
162  int i = 0;
163  for (const auto& f : fp) {
164  if (i++ == 0) {
165  os << "{\n";
166  }
167  os << f.lowerBound << " of size " << f.size << "\n";
168  }
169  os << "}";
170  return os;
171 }
172 
173 inline std::ostream& operator<<(std::ostream& os, const TensorReference& tr) {
174  os << ((tr.isRead()) ? "rd" : "wr") << " scopedAccess: " << tr.scopedAccess;
175  ;
176  return os;
177 }
178 
179 inline std::ostream& operator<<(
180  std::ostream& os,
181  const TensorReferenceGroup& tg) {
182  os << " with footprint BB: " << tg.approximation << " ";
183  for (const auto& tr : tg.references) {
184  os << *tr << " ";
185  }
186  return os;
187 }
188 
189 inline std::ostream& operator<<(std::ostream& os, const TensorGroupsInfo& ti) {
190  for (const auto& tg : ti) {
191  os << *tg << " ";
192  }
193  return os;
194 }
195 
196 inline std::ostream& operator<<(std::ostream& os, const TensorGroups& tg) {
197  int i = 0;
198  for (const auto& kvp : tg) {
199  os << "id: " << kvp.first << "; acc: " << kvp.second;
200  if (++i < tg.size()) {
201  os << std::endl;
202  }
203  }
204  return os;
205 }
206 
207 detail::ScheduleTree* insertCopiesUnder(
208  Scop& scop,
209  detail::ScheduleTree* tree,
210  const TensorReferenceGroup& group,
211  isl::id tensorId,
212  isl::id groupId = isl::id());
213 } // namespace polyhedral
214 } // namespace tc
std::ostream & operator<<(std::ostream &out, const MappingOptionsAsCpp &mo)
Definition: mapping_options_cpp_printer.h:79