// Name of the forward TC entry point; it must match the `def group_convolution`
// signature inside GROUP_CONVOLUTION2D_TC.
constexpr static auto GROUP_CONVOLUTION2D_TC_NAME = "group_convolution";
// Name of the gradient TC entry point; it must match the
// `def group_convolution2dGrad` signature inside GROUP_CONVOLUTION2D_GRAD_TC.
constexpr static auto GROUP_CONVOLUTION2D_GRAD_TC_NAME =
    "group_convolution2dGrad";
// Template TC source for a forward 2-D group convolution with bias.
//
//   O(n, g, f, h, w) = sum_{c, kh, kw} I(n, g, c, sh*h + kh, sw*w + kw)
//                                      * W1(g, f, c, kh, kw)
//                      + B(g, f)
//
// The reduction runs over the indices that appear only on the right-hand side
// (c, kh, kw). The input is accessed at `stride * out + k` with no offset,
// i.e. no padding is applied.
// `<sh>` / `<sw>` are textual placeholders for the vertical / horizontal
// stride. They must be substituted (see makeGroupConvolution2DTc) before the
// string is valid TC.
constexpr static auto GROUP_CONVOLUTION2D_TC = R"TC(
def group_convolution(float(N,G,C,H,W) I, float(G,F,C,KH,KW) W1, float(G,F) B)
-> (O)
{
    O(n, g, f, h, w) +=!
        I(n, g, c, <sh> * h + kh, <sw> * w + kw) * W1(g, f, c, kh, kw)
    O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f)
}
)TC";
// Template TC source for the backward pass of the group convolution.
// It produces gradients w.r.t. the input (I_grad), the weights (W1_grad) and
// the bias (B_grad) from the output gradient O_grad.
//
// O_grad is indexed with five indices (n, g, f, h, w), the same layout as the
// forward output O, so it is declared 5-D. The previous `float(N,F,H,W)`
// declaration dropped the group dimension and disagreed with every use.
//
// `<sh>` / `<sw>` are stride placeholders substituted by
// makeGroupConvolution2DGradTc.
//
// NOTE(review): I_grad and W1_grad read O_grad at `<s> * h - k`. That is not
// obviously the adjoint of the forward access `I(.., <s> * h + k, ..)` when the
// stride is > 1. O_grad also shares H/W with I, whereas the forward output is
// smaller. Confirm these formulas before relying on strided gradients.
constexpr static auto GROUP_CONVOLUTION2D_GRAD_TC = R"TC(
def group_convolution2dGrad(float(N,G,C,H,W) I, float(G,F,C,KH,KW) W1,
                            float(N,G,F,H,W) O_grad)
-> (I_grad, W1_grad, B_grad)
{
    I_grad(n, g, c, h, w) +=!
        O_grad(n, g, f, <sh> * h - kh, <sw> * w - kw) * W1(g, f, c, kh, kw)
    W1_grad(g, f, c, kh, kw) +=!
        O_grad(n, g, f, <sh> * h - kh, <sw> * w - kw) * I(n, g, c, h, w)
    B_grad(g, f) +=! O_grad(n, g, f, h, w)
}
)TC";
33 CHECK(strideH > 0 && strideW > 0) <<
"Stride must be greater than 0";
35 tcStr = GROUP_CONVOLUTION2D_TC;
36 tcStr =
replaceString(tcStr,
"<sh>", std::to_string(strideH));
37 tcStr =
replaceString(tcStr,
"<sw>", std::to_string(strideW));
42 CHECK(strideH > 0 && strideW > 0) <<
"Stride must be greater than 0";
44 tcStr = GROUP_CONVOLUTION2D_GRAD_TC;
45 tcStr =
replaceString(tcStr,
"<sh>", std::to_string(strideH));
46 tcStr =
replaceString(tcStr,
"<sw>", std::to_string(strideW));
// Cross-reference (tooltip text from the generated source listing):
//   std::string makeGroupConvolution2DGradTc(int strideH, int strideW)
//     Definition: group_convolution.h:60
//   std::string replaceString(std::string str, const std::string &search, const std::string &replace)
//     Definition: common.h:20
//   std::string makeGroupConvolution2DTc(int strideH, int strideW)
//     Definition: group_convolution.h:51