Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
group_convolution.h
Go to the documentation of this file.
1 
16 #pragma once
17 
18 #include "tc/library/common.h"
19 
20 namespace tc {
21 
/// Name of the forward definition; identical to the identifier that follows
/// `def` in the forward TC source kept in this header.
static constexpr const char* GROUP_CONVOLUTION2D_TC_NAME = "group_convolution";

/// Name of the gradient definition; identical to the identifier that follows
/// `def` in the gradient TC source kept in this header.
static constexpr const char* GROUP_CONVOLUTION2D_GRAD_TC_NAME =
    "group_convolution2dGrad";
25 
namespace {
// TC source template for the forward grouped 2-D convolution with bias.
// `<sh>` and `<sw>` are textual placeholders for the vertical and horizontal
// stride; they are not valid TC on their own and must be substituted with
// integer literals (see makeGroupConvolution2DTc) before the text is used.
constexpr static auto GROUP_CONVOLUTION2D_TC = R"TC( def group_convolution(float(N,G,C,H,W) I, float(G,F,C,KH,KW) W1, float(G,F) B) -> (O) { O(n, g, f, h, w) +=! I(n, g, c, <sh> * h + kh, <sw> * w + kw) * W1(g, f, c, kh, kw) O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f) } )TC";

// TC source template for the gradients with respect to I, W1 and B. It uses
// the same `<sh>` / `<sw>` placeholders as the forward template (see
// makeGroupConvolution2DGradTc).
// O_grad is declared with five dimensions (N,G,F,H,W): every statement below
// indexes it as O_grad(n, g, f, ..., ...), and the forward output O is
// produced as O(n, g, f, h, w), so a four-dimensional declaration would not
// match its uses.
constexpr static auto GROUP_CONVOLUTION2D_GRAD_TC = R"TC( def group_convolution2dGrad(float(N,G,C,H,W) I, float(G,F,C,KH,KW) W1, float(N,G,F,H,W) O_grad) -> (I_grad, W1_grad, B_grad) { I_grad(n, g, c, h, w) +=! O_grad(n, g, f, <sh> * h - kh, <sw> * w - kw) * W1(g, f, c, kh, kw) W1_grad(g, f, c, kh, kw) +=! O_grad(n, g, f, <sh> * h - kh, <sw> * w - kw) * I(n, g, c, h, w) B_grad(g, f) +=! O_grad(n, g, f, h, w) } )TC";
} // namespace
31 
32 std::string makeGroupConvolution2DTc(int strideH, int strideW) {
33  CHECK(strideH > 0 && strideW > 0) << "Stride must be greater than 0";
34  std::string tcStr;
35  tcStr = GROUP_CONVOLUTION2D_TC;
36  tcStr = replaceString(tcStr, "<sh>", std::to_string(strideH));
37  tcStr = replaceString(tcStr, "<sw>", std::to_string(strideW));
38  return tcStr;
39 }
40 
41 std::string makeGroupConvolution2DGradTc(int strideH, int strideW) {
42  CHECK(strideH > 0 && strideW > 0) << "Stride must be greater than 0";
43  std::string tcStr;
44  tcStr = GROUP_CONVOLUTION2D_GRAD_TC;
45  tcStr = replaceString(tcStr, "<sh>", std::to_string(strideH));
46  tcStr = replaceString(tcStr, "<sw>", std::to_string(strideW));
47  return tcStr;
48 }
49 } // namespace tc
50 
std::string makeGroupConvolution2DGradTc(int strideH, int strideW)
Definition: group_convolution.h:60
std::string replaceString(std::string str, const std::string &search, const std::string &replace)
Definition: common.h:20
std::string makeGroupConvolution2DTc(int strideH, int strideW)
Definition: group_convolution.h:51