Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
matmul.h
Go to the documentation of this file.
1 
16 #pragma once
17 
18 #include "tc/library/common.h"
19 
20 namespace tc {
21 
22 static constexpr auto TC_MATMUL_NAME = "matmul";
23 
24 namespace {
25 static constexpr auto TC_MATMUL = R"TC( def matmul(float(${szA0}, ${szA1}) A, float(${szB0}, ${szB1}) B) -> (O) { O(i, j) +=! A(${itA0}, ${itA1}) * B(${itB0}, ${itB1}) } )TC";
26 } // namespace
27 
28 std::string makeMatmulTc(
29  bool transposeFirst = false,
30  bool transposeSecond = false) {
31  std::string tc(TC_MATMUL);
32  tc = replaceString(tc, "${szA0}", (transposeFirst ? "K" : "N"));
33  tc = replaceString(tc, "${szA1}", (transposeFirst ? "N" : "K"));
34  tc = replaceString(tc, "${szB0}", (transposeSecond ? "M" : "K"));
35  tc = replaceString(tc, "${szB1}", (transposeSecond ? "K" : "M"));
36  tc = replaceString(tc, "${itA0}", (transposeFirst ? "k" : "i"));
37  tc = replaceString(tc, "${itA1}", (transposeFirst ? "i" : "k"));
38  tc = replaceString(tc, "${itB0}", (transposeSecond ? "j" : "k"));
39  tc = replaceString(tc, "${itB1}", (transposeSecond ? "k" : "j"));
40  return tc;
41 }
42 } // namespace tc
43 
std::string makeMatmulTc(bool transposeFirst=false, bool transposeSecond=false)
Definition: matmul.h:32
std::string replaceString(std::string str, const std::string &search, const std::string &replace)
Definition: common.h:20