Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
lexer.h
Go to the documentation of this file.
1 
16 #pragma once
17 #include <assert.h>
18 #include <algorithm>
19 #include <iostream>
20 #include <memory>
21 #include <sstream>
22 #include <string>
23 #include <unordered_map>
24 #include <vector>
25 
26 namespace lang {
27 
// single character tokens are just the character itself '+'
// multi-character tokens need an entry here
// if the third entry is not the empty string, it is used
// in the lexer to match this token.

// These kinds are also used in Tree.h as the kind of the AST node.
// Some kinds TK_APPLY, TK_LIST are only used in the AST and are not seen in the
// lexer.

// X-macro table of (enum_name, human_readable_name, lexer_spelling).
// It is expanded with different definitions of `_` to generate the TokenKind
// enum, kindToString(), and the token-matching trie entries, keeping the
// three in sync from a single list.
#define TC_FORALL_TOKEN_KINDS(_) \
  _(TK_EOF, "eof", "") \
  _(TK_NUMBER, "number", "") \
  _(TK_BOOL_VALUE, "bool_value", "") \
  _(TK_MIN, "min", "min") \
  _(TK_MAX, "max", "max") \
  _(TK_WHERE, "where", "where") \
  _(TK_FLOAT, "float", "float") \
  _(TK_DOUBLE, "double", "double") \
  _(TK_DEF, "def", "def") \
  _(TK_ARROW, "arrow", "->") \
  _(TK_EQUIVALENT, "equivalent", "<=>") \
  _(TK_IDENT, "ident", "") \
  _(TK_STRING, "string", "") \
  _(TK_CONST, "const", "") \
  _(TK_LIST, "list", "") \
  _(TK_OPTION, "option", "") \
  _(TK_APPLY, "apply", "") \
  _(TK_COMPREHENSION, "comprehension", "") \
  _(TK_TENSOR_TYPE, "tensor_type", "") \
  _(TK_RANGE_CONSTRAINT, "range_constraint", "") \
  _(TK_PARAM, "param", "") \
  _(TK_INFERRED, "inferred", "") \
  _(TK_ACCESS, "access", "") \
  _(TK_BUILT_IN, "built-in", "") \
  _(TK_PLUS_EQ, "plus_eq", "+=") \
  _(TK_TIMES_EQ, "times_eq", "*=") \
  _(TK_MIN_EQ, "min_eq", "min=") \
  _(TK_MAX_EQ, "max_eq", "max=") \
  _(TK_PLUS_EQ_B, "plus_eq_b", "+=!") \
  _(TK_TIMES_EQ_B, "times_eq_b", "*=!") \
  _(TK_MIN_EQ_B, "min_eq_b", "min=!") \
  _(TK_MAX_EQ_B, "max_eq_b", "max=!") \
  _(TK_INT8, "int8", "int8") \
  _(TK_INT16, "int16", "int16") \
  _(TK_INT32, "int32", "int32") \
  _(TK_INT64, "int64", "int64") \
  _(TK_UINT8, "uint8", "uint8") \
  _(TK_UINT16, "uint16", "uint16") \
  _(TK_UINT32, "uint32", "uint32") \
  _(TK_UINT64, "uint64", "uint64") \
  _(TK_BOOL, "bool", "bool") \
  _(TK_CAST, "cast", "") \
  _(TK_IN, "in", "in") \
  _(TK_GE, "ge", ">=") \
  _(TK_LE, "le", "<=") \
  _(TK_EQ, "eq", "==") \
  _(TK_NE, "neq", "!=") \
  _(TK_AND, "and", "&&") \
  _(TK_OR, "or", "||") \
  _(TK_LET, "let", "") \
  _(TK_EXISTS, "exists", "exists")
89 
90 static const char* valid_single_char_tokens = "+-*/()[]?:,={}><!";
91 
92 enum TokenKind {
93  // we use characters to represent themselves so skip all valid characters
94  // before
95  // assigning enum values to multi-char tokens.
97 #define DEFINE_TOKEN(tok, _, _2) tok,
99 #undef DEFINE_TOKEN
100 };
101 
102 std::string kindToString(int kind);
103 
// nested hash tables that indicate char-by-char what is a valid token.
struct TokenTrie;
using TokenTrieRef = std::unique_ptr<TokenTrie>;
struct TokenTrie {
  TokenTrie() : kind(0) {}
  // Register token id `tok` for the spelling `str`, creating one trie node
  // per character along the way. Asserts that no token was already stored
  // at exactly `str`.
  void insert(const char* str, int tok) {
    TokenTrie* node = this;
    for (; *str != '\0'; ++str) {
      auto& child = node->children[*str];
      if (child == nullptr) {
        child.reset(new TokenTrie());
      }
      node = child.get();
    }
    assert(node->kind == 0);
    node->kind = tok;
  }
  int kind; // 0 == invalid token
  std::unordered_map<char, TokenTrieRef> children;
};
124 
125 // stuff that is shared against all TC lexers/parsers and is initialized only
126 // once.
129  // listed in increasing order of precedence
130  std::vector<std::vector<int>> binary_ops = {
131  {'?'},
132  {TK_OR},
133  {TK_AND},
134  {'>', '<', TK_LE, TK_GE, TK_EQ, TK_NE},
135  {'+', '-'},
136  {'*', '/'},
137  };
138  std::vector<std::vector<int>> unary_ops = {
139  {'-', '!'},
140  };
141 
142  std::stringstream ss;
143  for (const char* c = valid_single_char_tokens; *c; c++) {
144  const char str[] = {*c, '\0'};
145  head->insert(str, *c);
146  }
147 
148 #define ADD_CASE(tok, _, tokstring) \
149  if (*tokstring != '\0') { \
150  head->insert(tokstring, tok); \
151  }
153 #undef ADD_CASE
154 
155  // precedence starts at 1 so that there is always a 0 precedence
156  // less than any other precedence
157  int prec = 1;
158  for (auto& group : binary_ops) {
159  for (auto& element : group) {
160  binary_prec[element] = prec;
161  }
162  prec++;
163  }
164  // unary ops
165  for (auto& group : unary_ops) {
166  for (auto& element : group) {
167  unary_prec[element] = prec;
168  }
169  prec++;
170  }
171  }
172  bool isNumber(const std::string& str, size_t start, size_t* len) {
173  char first = str[start];
174  // strtod allows numbers to start with + or -
175  // http://en.cppreference.com/w/cpp/string/byte/strtof
176  // but we want only the number part, otherwise 1+3 will turn into two
177  // adjacent numbers in the lexer
178  if (first == '-' || first == '+')
179  return false;
180  const char* startptr = str.c_str() + start;
181  char* endptr;
182  std::strtod(startptr, &endptr);
183  *len = endptr - startptr;
184  return *len > 0;
185  }
  // find the longest match of str.substring(pos) against a token, return true
  // if successful
  // filling in kind, start,and len
  bool match(
      const std::string& str,
      size_t pos,
      int* kind,
      size_t* start,
      size_t* len) {
    // skip whitespace
    while (pos < str.size() && isspace(str[pos]))
      pos++;
    // skip comments
    if (pos < str.size() && str[pos] == '#') {
      while (pos < str.size() && str[pos] != '\n')
        pos++;
      // tail call, handle whitespace and more comments
      return match(str, pos, kind, start, len);
    }
    *start = pos;
    if (pos == str.size()) {
      *kind = TK_EOF;
      *len = 0;
      return true;
    }
    // check for a valid number
    if (isNumber(str, pos, len)) {
      *kind = TK_NUMBER;
      return true;
    }
    // check for either an ident or a token
    // ident tracks whether what we have scanned so far could be an identifier
    // matched indicates if we have found any match.
    bool matched = false;
    bool ident = true;
    TokenTrie* cur = head.get();
    // scan forward one character at a time, recording the latest (and hence
    // longest) match seen so far in *kind/*len; the loop stops once neither
    // an identifier continuation nor a trie path remains possible
    for (size_t i = 0; pos + i < str.size() && (ident || cur != nullptr); i++) {
      ident = ident && validIdent(i, str[pos + i]);
      if (ident) {
        matched = true;
        *len = i + 1;
        *kind = TK_IDENT;
      }
      // check for token second, so that e.g. 'max' matches the token TK_MAX
      // rather the
      // identifier 'max'
      if (cur) {
        auto it = cur->children.find(str[pos + i]);
        cur = (it == cur->children.end()) ? nullptr : it->second.get();
        // only trie nodes with a nonzero kind terminate a real token
        if (cur && cur->kind != 0) {
          matched = true;
          *len = i + 1;
          *kind = cur->kind;
        }
      }
    }
    return matched;
  }
244  bool isUnary(int kind, int* prec) {
245  auto it = unary_prec.find(kind);
246  if (it != unary_prec.end()) {
247  *prec = it->second;
248  return true;
249  }
250  return false;
251  }
252  bool isBinary(int kind, int* prec) {
253  auto it = binary_prec.find(kind);
254  if (it != binary_prec.end()) {
255  *prec = it->second;
256  return true;
257  }
258  return false;
259  }
260  bool isRightAssociative(int kind) {
261  switch (kind) {
262  case '?':
263  return true;
264  default:
265  return false;
266  }
267  }
268  bool isScalarType(int kind) {
269  switch (kind) {
270  case TK_INT8:
271  case TK_INT16:
272  case TK_INT32:
273  case TK_INT64:
274  case TK_UINT8:
275  case TK_UINT16:
276  case TK_UINT32:
277  case TK_UINT64:
278  case TK_BOOL:
279  case TK_FLOAT:
280  case TK_DOUBLE:
281  return true;
282  default:
283  return false;
284  }
285  }
286 
287  private:
288  bool validIdent(size_t i, char n) {
289  return isalpha(n) || n == '_' || (i > 0 && isdigit(n));
290  }
292  std::unordered_map<int, int>
293  unary_prec; // map from token to its unary precedence
294  std::unordered_map<int, int>
295  binary_prec; // map from token to its binary precedence
296 };
297 
299 
// a range of a shared string 'file_' with functions to help debug by highlight
// that
// range.
struct SourceRange {
  // NOTE(review): the constructor's signature line (symbol index:
  // lexer.h:304) was dropped by the documentation extraction; restored.
  SourceRange(
      const std::shared_ptr<std::string>& file_,
      size_t start_,
      size_t end_)
      : file_(file_), start_(start_), end_(end_) {}
  // the substring of the file covered by this range
  const std::string text() const {
    return file().substr(start(), end() - start());
  }
  size_t size() const {
    return end() - start();
  }
  // Print the file up through the end of the range's last line, then
  // underline the range with '~' and mark it with "<--- HERE" before
  // echoing the remainder of the file.
  void highlight(std::ostream& out) const {
    const std::string& str = file();
    size_t begin = start();
    size_t end = start();
    // widen [begin, end) to whole-line boundaries around the range start
    while (begin > 0 && str[begin - 1] != '\n')
      --begin;
    while (end < str.size() && str[end] != '\n')
      ++end;
    out << str.substr(0, end) << "\n";
    out << std::string(start() - begin, ' ');
    // clamp the underline to the current line; "..." marks a truncation
    size_t len = std::min(size(), end - start());
    out << std::string(len, '~')
        << (len < size() ? "... <--- HERE" : " <--- HERE");
    out << str.substr(end);
    if (str.size() > 0 && str.back() != '\n')
      out << "\n";
  }
  const std::string& file() const {
    return *file_;
  }
  const std::shared_ptr<std::string>& file_ptr() const {
    return file_;
  }
  size_t start() const {
    return start_;
  }
  size_t end() const {
    return end_;
  }

 private:
  std::shared_ptr<std::string> file_;
  size_t start_;
  size_t end_;
};
350 
351 struct Token {
352  int kind;
354  Token(int kind, const SourceRange& range) : kind(kind), range(range) {}
355  double doubleValue() {
356  assert(TK_NUMBER == kind);
357  size_t idx;
358  double r = std::stod(text(), &idx);
359  assert(idx == range.size());
360  return r;
361  }
362  std::string text() {
363  return range.text();
364  }
365  std::string kindString() const {
366  return kindToString(kind);
367  }
368 };
369 
370 struct Lexer {
371  std::shared_ptr<std::string> file;
372  Lexer(const std::string& str)
373  : file(std::make_shared<std::string>(str)),
374  pos(0),
375  cur_(TK_EOF, SourceRange(file, 0, 0)),
377  next();
378  }
379  bool nextIf(int kind) {
380  if (cur_.kind != kind)
381  return false;
382  next();
383  return true;
384  }
386  if (!lookahead_) {
387  lookahead_.reset(new Token(lex()));
388  }
389  return *lookahead_;
390  }
392  auto r = cur_;
393  if (lookahead_) {
394  cur_ = *lookahead_;
395  lookahead_.reset();
396  } else {
397  cur_ = lex();
398  }
399  return r;
400  }
401  void reportError(const std::string& what, const Token& t);
402  void reportError(const std::string& what) {
403  reportError(what, cur_);
404  }
405  Token expect(int kind) {
406  if (cur_.kind != kind) {
407  reportError(kindToString(kind));
408  }
409  return next();
410  }
411  Token& cur() {
412  return cur_;
413  }
414 
415  private:
416  Token lex() {
417  int kind;
418  size_t start;
419  size_t length;
420  assert(file);
421  if (!shared.match(*file, pos, &kind, &start, &length)) {
422  reportError(
423  "a valid token",
424  Token((*file)[start], SourceRange(file, start, start + 1)));
425  }
426  auto t = Token(kind, SourceRange(file, start, start + length));
427  pos = start + length;
428  return t;
429  }
430  size_t pos;
432  std::unique_ptr<Token> lookahead_;
434 };
435 } // namespace lang
std::unordered_map< char, TokenTrieRef > children
Definition: lexer.h:122
bool isUnary(int kind, int *prec)
Definition: lexer.h:244
Token lex()
Definition: lexer.h:416
void highlight(std::ostream &out) const
Definition: lexer.h:315
size_t size() const
Definition: lexer.h:312
bool isBinary(int kind, int *prec)
Definition: lexer.h:252
Token lookahead()
Definition: lexer.h:385
std::unique_ptr< Token > lookahead_
Definition: lexer.h:432
size_t end_
Definition: lexer.h:348
std::shared_ptr< std::string > file_
Definition: lexer.h:346
size_t start() const
Definition: lexer.h:338
bool validIdent(size_t i, char n)
Definition: lexer.h:288
int kind
Definition: lexer.h:121
Definition: lexer.h:351
SourceRange(const std::shared_ptr< std::string > &file_, size_t start_, size_t end_)
Definition: lexer.h:304
TokenTrieRef head
Definition: lexer.h:291
#define TC_FORALL_TOKEN_KINDS(_)
Definition: lexer.h:37
Definition: lexer.h:127
Token expect(int kind)
Definition: lexer.h:405
Token cur_
Definition: lexer.h:431
Definition: lexer.h:303
bool match(const std::string &str, size_t pos, int *kind, size_t *start, size_t *len)
Definition: lexer.h:189
Token & cur()
Definition: lexer.h:411
SharedParserData()
Definition: lexer.h:128
bool isScalarType(int kind)
Definition: lexer.h:268
Token next()
Definition: lexer.h:391
Lexer(const std::string &str)
Definition: lexer.h:372
#define ADD_CASE(tok, _, tokstring)
std::string text()
Definition: lexer.h:362
bool isNumber(const std::string &str, size_t start, size_t *len)
Definition: lexer.h:172
std::unordered_map< int, int > binary_prec
Definition: lexer.h:295
TokenTrie()
Definition: lexer.h:108
Definition: lexer.h:96
int kind
Definition: lexer.h:352
std::shared_ptr< std::string > file
Definition: lexer.h:371
size_t pos
Definition: lexer.h:430
Definition: lexer.h:370
Definition: lexer.h:107
void reportError(const std::string &what, const Token &t)
size_t start_
Definition: lexer.h:347
void reportError(const std::string &what)
Definition: lexer.h:402
TokenKind
Definition: lexer.h:92
std::string kindString() const
Definition: lexer.h:365
const std::string text() const
Definition: lexer.h:309
std::string kindToString(int kind)
bool nextIf(int kind)
Definition: lexer.h:379
bool isRightAssociative(int kind)
Definition: lexer.h:260
SourceRange range
Definition: lexer.h:353
double doubleValue()
Definition: lexer.h:355
SharedParserData & sharedParserData()
const std::string & file() const
Definition: lexer.h:332
Token(int kind, const SourceRange &range)
Definition: lexer.h:354
void insert(const char *str, int tok)
Definition: lexer.h:109
size_t end() const
Definition: lexer.h:341
const std::shared_ptr< std::string > & file_ptr() const
Definition: lexer.h:335
std::unique_ptr< TokenTrie > TokenTrieRef
Definition: lexer.h:106
std::unordered_map< int, int > unary_prec
Definition: lexer.h:293
SharedParserData & shared
Definition: lexer.h:433
#define DEFINE_TOKEN(tok, _, _2)
Definition: lexer.h:97