Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
codegen_cuda.h
Go to the documentation of this file.
1 
16 #pragma once
17 
18 #include <sstream>
19 #include <string>
20 #include <unordered_map>
21 
24 #include "tc/external/isl.h"
25 
26 namespace tc {
27 namespace polyhedral {
28 
29 struct CodegenContext;
30 struct CodegenStatementContext;
31 
32 namespace detail {
33 
34 void emitDirectSubscripts(
35  isl::pw_multi_aff subscripts,
36  const CodegenStatementContext& context);
37 
38 std::string toString(isl::pw_aff subscript);
39 
40 isl::pw_aff makeAffFromMappedExpr(
41  const Halide::Expr& expr,
42  const CodegenStatementContext& context);
43 
44 void emitHalideExpr(
45  const Halide::Expr& e,
46  const CodegenStatementContext& context);
47 
48 void emitHalideExpr(
49  const Halide::Expr& e,
50  const CodegenStatementContext& context,
51  const std::map<std::string, std::string>& substitutions);
52 
53 void emitMappedSubscripts(
54  const std::vector<Halide::Expr>& exprs,
55  const CodegenStatementContext& context);
56 
57 void emitMappedArguments(
58  const std::vector<Halide::Expr>& exprs,
59  const CodegenStatementContext& context);
60 
61 void emitMappedTensorAccess(
62  std::string name,
63  const Halide::Internal::IRNode* node,
64  const std::vector<Halide::Expr>& subscripts,
65  const CodegenStatementContext& context);
66 
67 } // namespace detail
68 
69 using IteratorMapsType =
70  std::unordered_map<isl::id, isl::pw_multi_aff, isl::IslIdIslHash>;
71 
72 struct CodegenContext {
73  CodegenContext(
74  std::stringstream& ss_,
75  const MappedScop& s,
76  const IteratorMapsType& i)
77  : ss(ss_), mappedScop(s), iteratorMaps(i) {}
78  CodegenContext(const CodegenContext& c)
79  : ss(c.ss), mappedScop(c.mappedScop), iteratorMaps(c.iteratorMaps) {}
80 
81  const Scop& scop() const {
82  return mappedScop.scop();
83  }
84 
85  std::stringstream& ss;
86  const MappedScop& mappedScop;
87  const IteratorMapsType& iteratorMaps;
88 };
89 
90 struct CodegenStatementContext : CodegenContext {
91  CodegenStatementContext(const CodegenContext& c, isl::id astId)
92  : CodegenContext(c), astNodeId(astId) {}
93  isl::pw_multi_aff iteratorMap() const {
94  return this->iteratorMaps.at(astNodeId);
95  }
96  isl::id statementId() const {
97  return this->iteratorMaps.at(astNodeId).get_tuple_id(isl::dim_type::out);
98  }
99  std::vector<Scop::PromotionInfo> activePromotions() const {
100  auto stmtId = statementId();
101  const auto& promotions = this->scop().activePromotions();
102  if (promotions.count(stmtId) == 0) {
103  return {};
104  }
105  return promotions.at(stmtId);
106  }
107 
108  isl::id astNodeId;
109 };
110 
111 std::string emitCudaKernel(
112  const std::string& specializedName,
113  const MappedScop& scop);
114 
115 } // namespace polyhedral
116 } // namespace tc