Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
convolution.h
Go to the documentation of this file.
1 
16 #pragma once
17 
18 #include "tc/library/common.h"
19 
20 namespace tc {
21 
22 constexpr static auto CONVOLUTION2D_TC_NAME = "convolution";
23 
24 constexpr static auto CONVOLUTION2D_GRAD_TC_NAME = "convolution2dGrad";
25 
26 namespace {
27 constexpr static auto CONVOLUTION2D_TC = R"TC( def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(M) B) -> (O) { O(n, m, h, w) +=! I(n, c, ${sh} * h + kh, ${sw} * w + kw) * W1(m, c, kh, kw) O(n, m, h, w) = O(n, m, h, w) + B(m) } )TC";
28 
29 constexpr static auto CONVOLUTION2D_GRAD_TC = R"TC( def convolution2dGrad(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(N,M,H,W) O_grad) -> (I_grad, W1_grad, B_grad) { I_grad(n, c, h, w) +=! O_grad(n, m, ${sh} * h - kh, ${sw} * w - kw) * W1(m, c, kh, kw) W1_grad(m, c, kh, kw) +=! O_grad(n, m, ${sh} * h - kh, ${sw} * w - kw) * I(n, c, h, w) B_grad(m) +=! O_grad(n, m, h, w) } )TC";
30 } // namespace
31 
32 std::string makeConvolution2DTc(int strideH, int strideW) {
33  CHECK(strideH > 0 && strideW > 0) << "Stride must be greater than 0";
34  std::string tcStr;
35  tcStr = CONVOLUTION2D_TC;
36  tcStr = replaceString(tcStr, "${sh}", std::to_string(strideH));
37  tcStr = replaceString(tcStr, "${sw}", std::to_string(strideW));
38  return tcStr;
39 }
40 
41 std::string makeConvolution2DGradTc(int strideH, int strideW) {
42  CHECK(strideH > 0 && strideW > 0) << "Stride must be greater than 0";
43  std::string tcStr;
44  tcStr = CONVOLUTION2D_GRAD_TC;
45  tcStr = replaceString(tcStr, "${sh}", std::to_string(strideH));
46  tcStr = replaceString(tcStr, "${sw}", std::to_string(strideW));
47  return tcStr;
48 }
49 } // namespace tc
50 
std::string makeConvolution2DTc(int strideH, int strideW)
Definition: convolution.h:43
std::string replaceString(std::string str, const std::string &search, const std::string &replace)
Definition: common.h:20
std::string makeConvolution2DGradTc(int strideH, int strideW)
Definition: convolution.h:52