Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
parser.h
Go to the documentation of this file.
1 
16 #pragma once
17 #include "tc/lang/lexer.h"
18 #include "tc/lang/tree.h"
19 #include "tc/lang/tree_views.h"
20 
21 namespace lang {
22 
23 struct Parser {
24  Parser(const std::string& str) : L(str), shared(sharedParserData()) {}
25 
27  auto t = L.expect(TK_IDENT);
28  // whenever we parse something that has a TreeView type we always
29  // use its create method so that the accessors and the constructor
30  // of the Compound tree are in the same place.
31  return Ident::create(t.range, t.text());
32  }
34  auto t = L.expect(TK_NUMBER);
35  auto type = (t.text().find('.') != std::string::npos ||
36  t.text().find('e') != std::string::npos)
37  ? TK_FLOAT
38  : TK_INT32;
39  return Const::create(t.range, d(t.doubleValue()), c(type, t.range, {}));
40  }
41  // things like a 1.0 or a(4) that are not unary/binary expressions
42  // and have higher precedence than all of them
44  TreeRef prefix;
45  if (L.cur().kind == TK_NUMBER) {
46  prefix = parseConst();
47  } else if (L.cur().kind == '(') {
48  L.next();
49  prefix = parseExp();
50  L.expect(')');
51  } else if (shared.isScalarType(L.cur().kind)) {
52  // cast operation float(4 + a)
53  auto type = parseScalarType();
54  L.expect('(');
55  auto value = parseExp();
56  L.expect(')');
57  return Cast::create(type->range(), value, type);
58  } else {
59  prefix = parseIdent();
60  auto range = L.cur().range;
61  if (L.cur().kind == '(') {
62  prefix = Apply::create(range, prefix, parseExpList());
63  } else if (L.nextIf('.')) {
64  auto t = L.expect(TK_NUMBER);
65  prefix = Select::create(range, prefix, d(t.doubleValue()));
66  }
67  }
68 
69  return prefix;
70  }
71  TreeRef
72  parseTrinary(TreeRef cond, const SourceRange& range, int binary_prec) {
73  auto true_branch = parseExp();
74  L.expect(':');
75  auto false_branch = parseExp(binary_prec);
76  return c('?', range, {cond, true_branch, false_branch});
77  }
78  // parse the longest expression whose binary operators have
79  // precedence strictly greater than 'precedence'
80  // precedence == 0 will parse _all_ expressions
81  // this is the core loop of 'top-down precedence parsing'
82  TreeRef parseExp(int precedence = 0) {
83  TreeRef prefix = nullptr;
84  int unary_prec;
85  if (shared.isUnary(L.cur().kind, &unary_prec)) {
86  auto kind = L.cur().kind;
87  auto pos = L.cur().range;
88  L.next();
89  prefix = c(kind, pos, {parseExp(unary_prec)});
90  } else {
91  prefix = parseBaseExp();
92  }
93  int binary_prec;
94  while (shared.isBinary(L.cur().kind, &binary_prec)) {
95  if (binary_prec <= precedence) // not allowed to parse something which is
96  // not greater than 'precedenc'
97  break;
98 
99  int kind = L.cur().kind;
100  auto pos = L.cur().range;
101  L.next();
102  if (shared.isRightAssociative(kind))
103  binary_prec--;
104 
105  // special case for trinary operator
106  if (kind == '?') {
107  prefix = parseTrinary(prefix, pos, binary_prec);
108  continue;
109  }
110 
111  prefix = c(kind, pos, {prefix, parseExp(binary_prec)});
112  }
113  return prefix;
114  }
115  TreeRef
116  parseList(int begin, int sep, int end, std::function<TreeRef(int)> parse) {
117  auto r = L.cur().range;
118  L.expect(begin);
119  TreeList elements;
120  if (L.cur().kind != end) {
121  int i = 0;
122  do {
123  elements.push_back(parse(i++));
124  } while (L.nextIf(sep));
125  }
126  L.expect(end);
127  return List::create(r, std::move(elements));
128  }
129  TreeRef parseNonEmptyList(int sep, std::function<TreeRef(int)> parse) {
130  TreeList elements;
131  int i = 0;
132  do {
133  elements.push_back(parse(i++));
134  } while (L.nextIf(sep));
135  auto range = elements.at(0)->range();
136  return List::create(range, std::move(elements));
137  }
139  return parseList('(', ',', ')', [&](int i) { return parseExp(); });
140  }
142  return parseList('(', ',', ')', [&](int i) { return parseIdent(); });
143  }
145  auto id = parseIdent();
146  L.expect(TK_IN);
147  auto l = parseExp();
148  L.expect(':');
149  auto r = parseExp();
150  return RangeConstraint::create(id->range(), id, l, r);
151  }
153  auto ident = parseIdent();
154  L.expect('=');
155  auto exp = parseExp();
156  return Let::create(ident->range(), ident, exp);
157  }
159  auto lookahead = L.lookahead();
160  if (lookahead.kind == '=') {
161  return parseLetBinding();
162  } else if (lookahead.kind == TK_IN) {
163  return parseRangeConstraint();
164  } else {
165  L.expect(TK_EXISTS);
166  auto exp = parseExp();
167  return Exists::create(exp->range(), {exp});
168  }
169  }
171  if (L.cur().kind == TK_IDENT) {
172  auto ident = parseIdent();
173  return Param::create(
174  ident->range(), ident, c(TK_INFERRED, ident->range(), {}));
175  }
176  auto typ = parseType();
177  auto ident = parseIdent();
178  return Param::create(typ->range(), ident, typ);
179  }
181  if (L.nextIf(TK_WHERE)) {
182  return parseNonEmptyList(',', [&](int i) { return parseWhereClause(); });
183  }
184  return List::create(L.cur().range, {});
185  }
187  auto r = L.cur().range;
188  if (L.nextIf(TK_EQUIVALENT)) {
189  auto name = L.expect(TK_IDENT);
190  auto accesses = parseExpList();
191  return c(TK_OPTION, r, {Equivalent::create(r, name.text(), accesses)});
192  }
193  return c(TK_OPTION, r, {});
194  }
195  // =, +=, +=!, etc.
197  switch (L.cur().kind) {
198  case TK_PLUS_EQ:
199  case TK_TIMES_EQ:
200  case TK_MIN_EQ:
201  case TK_MAX_EQ:
202  case TK_PLUS_EQ_B:
203  case TK_TIMES_EQ_B:
204  case TK_MIN_EQ_B:
205  case TK_MAX_EQ_B:
206  case '=':
207  return c(L.next().kind, L.cur().range, {});
208  default:
209  L.reportError("a valid assignment operator");
210  // unreachable, silence warnings
211  return nullptr;
212  }
213  }
215  auto ident = parseIdent();
216  TreeRef list = parseOptionalIdentList();
217  auto assign = parseAssignment();
218  auto rhs = parseExp();
219  TreeRef equivalent_statement = parseEquivalent();
220  TreeRef range_statements = parseWhereClauses();
221  TreeRef empty_reduction_variables = c(TK_LIST, ident->range(), {});
222  return Comprehension::create(
223  ident->range(),
224  ident,
225  list,
226  assign,
227  rhs,
228  range_statements,
229  equivalent_statement,
230  empty_reduction_variables);
231  }
233  if (shared.isScalarType(L.cur().kind)) {
234  auto t = L.next();
235  return c(t.kind, t.range, {});
236  }
237  L.reportError("a scalar type");
238  return nullptr;
239  }
241  TreeRef list = nullptr;
242  if (L.cur().kind == '(') {
243  list = parseIdentList();
244  } else {
245  list = List::create(L.cur().range, {});
246  }
247  return list;
248  }
250  return parseList('(', ',', ')', [&](int i) {
251  if (L.cur().kind == TK_NUMBER) {
252  return parseConst();
253  } else {
254  return parseIdent();
255  }
256  });
257  }
259  TreeRef list = nullptr;
260  if (L.cur().kind == '(') {
261  list = parseDimList();
262  } else {
263  list = List::create(L.cur().range, {});
264  }
265  return list;
266  }
268  auto st = parseScalarType();
269  auto list = parseOptionalDimList();
270  return TensorType::create(st->range(), st, list);
271  }
273  L.expect(TK_DEF);
274  auto name = parseIdent();
275  auto paramlist =
276  parseList('(', ',', ')', [&](int i) { return parseParam(); });
277  L.expect(TK_ARROW);
278  auto retlist =
279  parseList('(', ',', ')', [&](int i) { return parseParam(); });
280  L.expect('{');
281  auto r = L.cur().range;
282  TreeList stmts;
283  while (!L.nextIf('}')) {
284  stmts.push_back(parseStmt());
285  }
286  auto stmts_list = List::create(r, std::move(stmts));
287  return Def::create(name->range(), name, paramlist, retlist, stmts_list);
288  }
289 
291 
292  private:
293  // short helpers to create nodes
294  TreeRef d(double v) {
295  return Number::create(v);
296  }
297  TreeRef s(const std::string& s) {
298  return String::create(s);
299  }
300  TreeRef c(int kind, const SourceRange& range, TreeList&& trees) {
301  return Compound::create(kind, range, std::move(trees));
302  }
304 };
305 } // namespace lang
static TreeRef create(const SourceRange &range, TreeRef name, TreeRef paramlist, TreeRef retlist, TreeRef stmts_list)
Definition: tree_views.h:348
static TreeRef create(const SourceRange &range, const std::string &name, TreeRef accesses)
Definition: tree_views.h:244
static TreeRef create(const SourceRange &range, TreeRef name, TreeRef index)
Definition: tree_views.h:369
TreeRef parseExpList()
Definition: parser.h:138
static TreeRef create(const SourceRange &range, TreeRef ident, TreeRef type)
Definition: tree_views.h:215
TreeRef s(const std::string &s)
Definition: parser.h:297
TreeRef parseDimList()
Definition: parser.h:249
static TreeRef create(Args &&...args)
Definition: tree.h:113
TreeRef parseOptionalDimList()
Definition: parser.h:258
TreeRef parseStmt()
Definition: parser.h:214
TreeRef parseEquivalent()
Definition: parser.h:186
TreeRef parseLetBinding()
Definition: parser.h:152
TreeRef parseAssignment()
Definition: parser.h:196
TreeRef parseNonEmptyList(int sep, std::function< TreeRef(int)> parse)
Definition: parser.h:129
Parser(const std::string &str)
Definition: parser.h:24
TreeRef parseBaseExp()
Definition: parser.h:43
Lexer L
Definition: parser.h:290
TreeRef parseIdent()
Definition: parser.h:26
static TreeRef create(const SourceRange &range, TreeRef ident, TreeRef start, TreeRef end)
Definition: tree_views.h:261
Definition: lexer.h:127
Token expect(int kind)
Definition: lexer.h:405
static TreeRef create(const SourceRange &range, TreeRef name, TreeRef arguments)
Definition: tree_views.h:149
static TreeRef create(Args &&...args)
Definition: tree.h:100
Definition: lexer.h:303
static TreeRef create(const SourceRange &range, TreeRef value, TreeRef type)
Definition: tree_views.h:399
static TreeRef create(const SourceRange &range, TreeRef name, TreeRef rhs)
Definition: tree_views.h:414
static TreeRef create(const SourceRange &range, TreeRef value, TreeRef type)
Definition: tree_views.h:384
TreeRef parseRangeConstraint()
Definition: parser.h:144
TreeRef parseType()
Definition: parser.h:267
TreeRef parseIdentList()
Definition: parser.h:141
TreeRef parseFunction()
Definition: parser.h:272
SharedParserData & shared
Definition: parser.h:303
TreeRef parseWhereClauses()
Definition: parser.h:180
TreeRef parseWhereClause()
Definition: parser.h:158
TreeRef parseParam()
Definition: parser.h:170
TreeRef parseExp(int precedence=0)
Definition: parser.h:82
static TreeRef create(const SourceRange &range, TreeRef ident, TreeRef indices, TreeRef assignment, TreeRef rhs, TreeRef range_constraints, TreeRef equivalent, TreeRef reduction_variables)
Definition: tree_views.h:279
Definition: parser.h:23
TreeRef parseOptionalIdentList()
Definition: parser.h:240
std::vector< TreeRef > TreeList
Definition: tree.h:45
TreeRef parseScalarType()
Definition: parser.h:232
Definition: lexer.h:370
static TreeRef create(const SourceRange &range, const std::string &name)
Definition: tree_views.h:127
static TreeRef create(const SourceRange &range, TreeRef scalar_type_, TreeRef dims_)
Definition: tree_views.h:191
static TreeRef create(const SourceRange &range, TreeList elements)
Definition: tree_views.h:85
TreeRef c(int kind, const SourceRange &range, TreeList &&trees)
Definition: parser.h:300
SharedParserData & sharedParserData()
TreeRef parseTrinary(TreeRef cond, const SourceRange &range, int binary_prec)
Definition: parser.h:72
static TreeRef create(int kind, const SourceRange &range_, TreeList &&trees_)
Definition: tree.h:155
TreeRef d(double v)
Definition: parser.h:294
static TreeRef create(const SourceRange &range, TreeRef exp)
Definition: tree_views.h:426
std::shared_ptr< Tree > TreeRef
Definition: tree.h:44
TreeRef parseConst()
Definition: parser.h:33
TreeRef parseList(int begin, int sep, int end, std::function< TreeRef(int)> parse)
Definition: parser.h:116