Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
tree_views.h
Go to the documentation of this file.
1 
16 #pragma once
17 #include "tc/lang/error_report.h"
18 #include "tc/lang/tree.h"
19 
20 namespace lang {
21 
24 struct TreeView {
25  explicit TreeView(const TreeRef& tree_) : tree_(tree_) {}
26  TreeRef tree() const {
27  return tree_;
28  }
29  const SourceRange& range() const {
30  return tree_->range();
31  }
32  operator TreeRef() const {
33  return tree_;
34  }
35 
36  protected:
37  TreeRef subtree(size_t i) const {
38  return tree_->tree(i);
39  }
41 };
42 
43 template <typename T>
45  ListViewIterator(TreeList::const_iterator it) : it(it) {}
46  bool operator!=(const ListViewIterator& rhs) const {
47  return it != rhs.it;
48  }
49  T operator*() const {
50  return T(*it);
51  }
52  void operator++() {
53  ++it;
54  }
55  void operator--() {
56  --it;
57  }
58 
59  private:
60  TreeList::const_iterator it;
61 };
62 
63 template <typename T>
64 struct ListView : public TreeView {
65  ListView(const TreeRef& tree) : TreeView(tree) {
66  tree->expect(TK_LIST);
67  }
70  iterator begin() const {
71  return iterator(tree_->trees().begin());
72  }
73  iterator end() const {
74  return iterator(tree_->trees().end());
75  }
76  T operator[](size_t i) const {
77  return T(tree_->trees().at(i));
78  }
79  TreeRef map(std::function<TreeRef(const T&)> fn) {
80  return tree_->map([&](TreeRef v) { return fn(T(v)); });
81  }
82  size_t size() const {
83  return tree_->trees().size();
84  }
85  static TreeRef create(const SourceRange& range, TreeList elements) {
86  return Compound::create(TK_LIST, range, std::move(elements));
87  }
88 };
89 
91 
92 template <typename T>
93 struct OptionView : public TreeView {
94  explicit OptionView(const TreeRef& tree) : TreeView(tree) {
95  tree->expect(TK_OPTION);
96  }
97  bool present() const {
98  return tree_->trees().size() > 0;
99  }
100  T get() const {
101  TC_ASSERT(tree_, present());
102  return T(tree_->trees()[0]);
103  }
104  TreeRef map(std::function<TreeRef(const T&)> fn) {
105  return tree_->map([&](TreeRef v) { return fn(T(v)); });
106  }
107 };
108 
109 struct Ident : public TreeView {
110  // each subclass of TreeView provides:
111  // 1. a constructor that takes a TreeRef, and matches it to the right type.
112  explicit Ident(const TreeRef& tree) : TreeView(tree) {
113  tree_->expect(TK_IDENT, 1);
114  }
115  // 2. accessors that get underlying information out of the object
116  // in this case, we return the name of the identifier, and handle the
117  // converstion to a string in the method
118  const std::string& name() const {
119  return subtree(0)->stringValue();
120  }
121 
122  // 3. a static method 'create' that creates the underlying TreeRef object
123  // for every TreeRef kind that has a TreeView, the parser always uses
124  // (e.g.) Ident::create rather than Compound::Create, this means that
125  // changes to the structure of Ident are always made right here rather
126  // than both in the parser and in this code
127  static TreeRef create(const SourceRange& range, const std::string& name) {
128  return Compound::create(TK_IDENT, range, {String::create(name)});
129  }
130 
131  private:
133 };
134 
135 template <int kind>
136 struct ApplyLike : public TreeView {
137  explicit ApplyLike(const TreeRef& tree) : TreeView(tree) {
138  tree_->expect(kind, 2);
139  }
140 
141  Ident name() const {
142  return Ident(subtree(0));
143  }
145  return ListView<TreeRef>(subtree(1));
146  }
147 
148  static TreeRef
150  return Compound::create(kind, range, {name, arguments});
151  }
152 };
155 
156 struct BuiltIn : public TreeView {
157  explicit BuiltIn(const TreeRef& tree) : TreeView(tree) {
158  tree_->expect(TK_BUILT_IN, 3);
159  }
160  const std::string& name() const {
161  return subtree(0)->stringValue();
162  }
164  return ListView<TreeRef>(subtree(1));
165  }
166 
167  TreeRef type() const {
168  return subtree(2);
169  }
170 
171  static TreeRef create(
172  const SourceRange& range,
173  const std::string& name,
175  TreeRef type) {
176  return Compound::create(
177  TK_BUILT_IN, range, {String::create(name), arguments, type});
178  }
179 
180  private:
183  TreeRef type_; // because Halide needs to know the output type
184 };
185 
186 struct TensorType : public TreeView {
187  explicit TensorType(const TreeRef& tree) : TreeView(tree) {
188  tree_->expect(TK_TENSOR_TYPE, 2);
189  }
190  static TreeRef
191  create(const SourceRange& range, TreeRef scalar_type_, TreeRef dims_) {
192  return Compound::create(TK_TENSOR_TYPE, range, {scalar_type_, dims_});
193  }
195  auto scalar_type_ = subtree(0);
196  if (scalar_type_->kind() == TK_IDENT)
197  throw ErrorReport(tree_)
198  << " TensorType has a symbolic ident " << Ident(scalar_type_).name()
199  << " rather than a concrete type";
200  return scalar_type_;
201  }
202  int scalarType() const {
203  return scalarTypeTree()->kind();
204  }
205  // either an Ident or a constant
207  return ListView<TreeRef>(subtree(1));
208  }
209 };
210 
211 struct Param : public TreeView {
212  explicit Param(const TreeRef& tree) : TreeView(tree) {
213  tree_->expect(TK_PARAM, 2);
214  }
216  return Compound::create(TK_PARAM, range, {ident, type});
217  }
218  // when the type of a field is statically know the accessors return
219  // the wrapped type. for instance here we know ident_ is an identifier
220  // so the accessor returns an Ident
221  // this means that clients can do p.ident().name() to get the name of the
222  // parameter.
223  Ident ident() const {
224  return Ident(subtree(0));
225  }
226  // may be TensorType or TK_INFERRED
227  TreeRef type() const {
228  return subtree(1);
229  }
230  bool typeIsInferred() const {
231  return type()->kind() == TK_INFERRED;
232  }
233  // helper for when you know the type is not inferred.
235  return TensorType(type());
236  }
237 };
238 
239 struct Equivalent : public TreeView {
240  explicit Equivalent(const TreeRef& tree) : TreeView(tree) {
241  tree_->expect(TK_EQUIVALENT, 2);
242  }
243  static TreeRef
244  create(const SourceRange& range, const std::string& name, TreeRef accesses) {
245  return Compound::create(
246  TK_EQUIVALENT, range, {String::create(name), accesses});
247  }
248  const std::string& name() const {
249  return subtree(0)->stringValue();
250  }
252  return ListView<TreeRef>(subtree(1));
253  }
254 };
255 
256 struct RangeConstraint : public TreeView {
257  explicit RangeConstraint(const TreeRef& tree) : TreeView(tree) {
258  tree->expect(TK_RANGE_CONSTRAINT, 3);
259  }
260  static TreeRef
262  return Compound::create(TK_RANGE_CONSTRAINT, range, {ident, start, end});
263  }
264  Ident ident() const {
265  return Ident(subtree(0));
266  }
267  TreeRef start() const {
268  return subtree(1);
269  }
270  TreeRef end() const {
271  return subtree(2);
272  }
273 };
274 
275 struct Comprehension : public TreeView {
276  explicit Comprehension(const TreeRef& tree) : TreeView(tree) {
277  tree_->expect(TK_COMPREHENSION, 7);
278  }
279  static TreeRef create(
280  const SourceRange& range,
281  TreeRef ident,
284  TreeRef rhs,
285  TreeRef range_constraints,
287  TreeRef reduction_variables) {
288  return Compound::create(
289  TK_COMPREHENSION,
290  range,
291  {ident,
292  indices,
293  assignment,
294  rhs,
295  range_constraints,
296  equivalent,
297  reduction_variables});
298  }
299  // when the type of a field is statically know the accessors return
300  // the wrapped type. for instance here we know ident_ is an identifier
301  // so the accessor returns an Ident
302  // this means that clients can do p.ident().name() to get the name of the
303  // parameter.
304  Ident ident() const {
305  return Ident(subtree(0));
306  }
308  return ListView<Ident>(subtree(1));
309  }
310  // kind == '=', TK_PLUS_EQ, TK_PLUS_EQ_B, etc.
311  TreeRef assignment() const {
312  return subtree(2);
313  }
314  TreeRef rhs() const {
315  return subtree(3);
316  }
317 
318  // where clauses are either RangeConstraints or Let bindings
320  return ListView<TreeRef>(subtree(4));
321  }
323  return OptionView<Equivalent>(subtree(5));
324  }
326  return ListView<Ident>(subtree(6));
327  }
328 };
329 
330 struct Def : public TreeView {
331  explicit Def(const TreeRef& tree) : TreeView(tree) {
332  tree->expect(TK_DEF, 4);
333  }
335  return Ident(subtree(0));
336  }
337  // ListView helps turn TK_LISTs into vectors of TreeViews
338  // so that we can, e.g., return lists of parameters
340  return ListView<Param>(subtree(1));
341  }
343  return ListView<Param>(subtree(2));
344  }
346  return ListView<Comprehension>(subtree(3));
347  }
348  static TreeRef create(
349  const SourceRange& range,
350  TreeRef name,
351  TreeRef paramlist,
352  TreeRef retlist,
353  TreeRef stmts_list) {
354  return Compound::create(
355  TK_DEF, range, {name, paramlist, retlist, stmts_list});
356  }
357 };
358 
359 struct Select : public TreeView {
360  explicit Select(const TreeRef& tree) : TreeView(tree) {
361  tree_->expect('.', 2);
362  }
363  Ident name() const {
364  return Ident(subtree(0));
365  }
366  int index() const {
367  return subtree(1)->doubleValue();
368  }
370  return Compound::create('.', range, {name, index});
371  }
372 };
373 
374 struct Const : public TreeView {
375  explicit Const(const TreeRef& tree) : TreeView(tree) {
376  tree_->expect(TK_CONST, 2);
377  }
378  double value() const {
379  return subtree(0)->doubleValue();
380  }
381  TreeRef type() const {
382  return subtree(1);
383  }
385  return Compound::create(TK_CONST, range, {value, type});
386  }
387 };
388 
389 struct Cast : public TreeView {
390  explicit Cast(const TreeRef& tree) : TreeView(tree) {
391  tree_->expect(TK_CAST, 2);
392  }
393  TreeRef value() const {
394  return subtree(0);
395  }
396  TreeRef type() const {
397  return subtree(1);
398  }
400  return Compound::create(TK_CAST, range, {value, type});
401  }
402 };
403 
404 struct Let : public TreeView {
405  explicit Let(const TreeRef& tree) : TreeView(tree) {
406  tree_->expect(TK_LET, 2);
407  }
408  Ident name() const {
409  return Ident(subtree(0));
410  }
411  TreeRef rhs() const {
412  return subtree(1);
413  }
415  return Compound::create(TK_LET, range, {name, rhs});
416  }
417 };
418 
419 struct Exists : public TreeView {
420  explicit Exists(const TreeRef& tree) : TreeView(tree) {
421  tree_->expect(TK_EXISTS, 1);
422  }
423  TreeRef exp() const {
424  return subtree(0);
425  }
427  return Compound::create(TK_EXISTS, range, {exp});
428  }
429 };
430 
431 } // namespace lang
static TreeRef create(const SourceRange &range, TreeRef name, TreeRef paramlist, TreeRef retlist, TreeRef stmts_list)
Definition: tree_views.h:348
TreeRef name_
Definition: tree_views.h:132
static TreeRef create(const SourceRange &range, const std::string &name, TreeRef accesses)
Definition: tree_views.h:244
size_t size() const
Definition: tree_views.h:82
static TreeRef create(const SourceRange &range, TreeRef name, TreeRef index)
Definition: tree_views.h:369
TreeRef assignment() const
Definition: tree_views.h:311
ListViewIterator< T > const_iterator
Definition: tree_views.h:69
Select(const TreeRef &tree)
Definition: tree_views.h:360
const std::string & name() const
Definition: tree_views.h:118
static TreeRef create(const SourceRange &range, TreeRef ident, TreeRef type)
Definition: tree_views.h:215
TreeRef rhs() const
Definition: tree_views.h:314
TreeRef rhs() const
Definition: tree_views.h:411
Exists(const TreeRef &tree)
Definition: tree_views.h:420
Ident name() const
Definition: tree_views.h:408
Ident name() const
Definition: tree_views.h:363
static TreeRef create(const SourceRange &range, const std::string &name, TreeRef arguments, TreeRef type)
Definition: tree_views.h:171
Definition: tree_views.h:256
TreeRef start() const
Definition: tree_views.h:267
int index() const
Definition: tree_views.h:366
TreeRef type_
Definition: tree_views.h:183
void operator--()
Definition: tree_views.h:55
TreeList::const_iterator it
Definition: tree_views.h:60
ListView< Param > returns() const
Definition: tree_views.h:342
Ident name()
Definition: tree_views.h:334
Definition: tree_views.h:374
ListViewIterator(TreeList::const_iterator it)
Definition: tree_views.h:45
Ident name() const
Definition: tree_views.h:141
Ident(const TreeRef &tree)
Definition: tree_views.h:112
TensorType tensorType() const
Definition: tree_views.h:234
TreeRef tree() const
Definition: tree_views.h:26
const std::string & name() const
Definition: tree_views.h:160
Definition: tree_views.h:24
Comprehension(const TreeRef &tree)
Definition: tree_views.h:276
Param(const TreeRef &tree)
Definition: tree_views.h:212
bool typeIsInferred() const
Definition: tree_views.h:230
T operator[](size_t i) const
Definition: tree_views.h:76
TreeRef type() const
Definition: tree_views.h:227
TreeRef name_
Definition: tree_views.h:181
static TreeRef create(const SourceRange &range, TreeRef ident, TreeRef start, TreeRef end)
Definition: tree_views.h:261
TreeRef arguments_
Definition: tree_views.h:182
Definition: tree_views.h:389
Definition: tree_views.h:275
ListViewIterator< T > iterator
Definition: tree_views.h:68
static TreeRef create(const SourceRange &range, TreeRef name, TreeRef arguments)
Definition: tree_views.h:149
Definition: tree_views.h:109
Definition: tree_views.h:359
ListView< TreeRef > dims() const
Definition: tree_views.h:206
static TreeRef create(Args &&...args)
Definition: tree.h:100
RangeConstraint(const TreeRef &tree)
Definition: tree_views.h:257
Definition: lexer.h:303
static TreeRef create(const SourceRange &range, TreeRef value, TreeRef type)
Definition: tree_views.h:399
static TreeRef create(const SourceRange &range, TreeRef name, TreeRef rhs)
Definition: tree_views.h:414
static TreeRef create(const SourceRange &range, TreeRef value, TreeRef type)
Definition: tree_views.h:384
Cast(const TreeRef &tree)
Definition: tree_views.h:390
ListView< TreeRef > whereClauses() const
Definition: tree_views.h:319
Definition: tree_views.h:330
ListView< TreeRef > arguments() const
Definition: tree_views.h:144
Definition: tree_views.h:136
Definition: tree_views.h:404
Const(const TreeRef &tree)
Definition: tree_views.h:375
void operator++()
Definition: tree_views.h:52
Let(const TreeRef &tree)
Definition: tree_views.h:405
OptionView< Equivalent > equivalent() const
Definition: tree_views.h:322
Ident ident() const
Definition: tree_views.h:264
const SourceRange & range() const
Definition: tree_views.h:29
TreeRef map(std::function< TreeRef(const T &)> fn)
Definition: tree_views.h:104
static TreeRef create(const SourceRange &range, TreeRef ident, TreeRef indices, TreeRef assignment, TreeRef rhs, TreeRef range_constraints, TreeRef equivalent, TreeRef reduction_variables)
Definition: tree_views.h:279
TreeRef exp() const
Definition: tree_views.h:423
TreeRef type() const
Definition: tree_views.h:167
TreeRef type() const
Definition: tree_views.h:396
TreeRef value() const
Definition: tree_views.h:393
BuiltIn(const TreeRef &tree)
Definition: tree_views.h:157
Definition: tree_views.h:44
ListView< Comprehension > statements() const
Definition: tree_views.h:345
std::vector< TreeRef > TreeList
Definition: tree.h:45
double value() const
Definition: tree_views.h:378
OptionView(const TreeRef &tree)
Definition: tree_views.h:94
ListView< Param > params() const
Definition: tree_views.h:339
TreeRef tree_
Definition: tree_views.h:40
Equivalent(const TreeRef &tree)
Definition: tree_views.h:240
T operator*() const
Definition: tree_views.h:49
ListView< Ident > indices() const
Definition: tree_views.h:307
static TreeRef create(const SourceRange &range, const std::string &name)
Definition: tree_views.h:127
TreeRef subtree(size_t i) const
Definition: tree_views.h:37
Definition: tree_views.h:93
static TreeRef create(const SourceRange &range, TreeRef scalar_type_, TreeRef dims_)
Definition: tree_views.h:191
static TreeRef create(const SourceRange &range, TreeList elements)
Definition: tree_views.h:85
iterator end() const
Definition: tree_views.h:73
ApplyLike(const TreeRef &tree)
Definition: tree_views.h:137
Definition: tree_views.h:64
ListView(const TreeRef &tree)
Definition: tree_views.h:65
Definition: tree_views.h:186
TreeView(const TreeRef &tree_)
Definition: tree_views.h:25
ListView< TreeRef > accesses() const
Definition: tree_views.h:251
ListView< Ident > reductionVariables() const
Definition: tree_views.h:325
Definition: tree_views.h:156
int scalarType() const
Definition: tree_views.h:202
Definition: tree_views.h:239
TreeRef scalarTypeTree() const
Definition: tree_views.h:194
Definition: error_report.h:22
Definition: tree_views.h:419
bool operator!=(const ListViewIterator &rhs) const
Definition: tree_views.h:46
static TreeRef create(int kind, const SourceRange &range_, TreeList &&trees_)
Definition: tree.h:155
Definition: tree_views.h:211
ListView< TreeRef > arguments() const
Definition: tree_views.h:163
const std::string & name() const
Definition: tree_views.h:248
TensorType(const TreeRef &tree)
Definition: tree_views.h:187
static TreeRef create(const SourceRange &range, TreeRef exp)
Definition: tree_views.h:426
iterator begin() const
Definition: tree_views.h:70
std::shared_ptr< Tree > TreeRef
Definition: tree.h:44
bool present() const
Definition: tree_views.h:97
Ident ident() const
Definition: tree_views.h:223
TreeRef map(std::function< TreeRef(const T &)> fn)
Definition: tree_views.h:79
Ident ident() const
Definition: tree_views.h:304
TreeRef end() const
Definition: tree_views.h:270
#define TC_ASSERT(ctx, cond)
Definition: error_report.h:55
TreeRef type() const
Definition: tree_views.h:381
Def(const TreeRef &tree)
Definition: tree_views.h:331