// Name of the TC function defined by CONVOLUTION2D_TC. It equals the
// `def convolution(...)` name inside that template; keep the two in sync.
constexpr static auto CONVOLUTION2D_TC_NAME = "convolution";
// Name of the TC function defined by CONVOLUTION2D_GRAD_TC. It equals the
// `def convolution2dGrad(...)` name inside that template; keep the two in sync.
constexpr static auto CONVOLUTION2D_GRAD_TC_NAME = "convolution2dGrad";
// Tensor Comprehensions (TC) source for a 2-D convolution with bias:
//   O(n, m, h, w) = sum over c, kh, kw of
//                   I(n, c, sh * h + kh, sw * w + kw) * W1(m, c, kh, kw)
//                   + B(m)
// `+=!` zero-initialises O before accumulating the reduction. The input index
// has no padding offset.
// `${sh}` / `${sw}` are textual placeholders for the vertical / horizontal
// stride. They must be replaced with integer literals (makeConvolution2DTc does
// this via replaceString) before the string is handed to the TC compiler.
constexpr static auto CONVOLUTION2D_TC = R"TC(
def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(M) B) -> (O) {
O(n, m, h, w) +=! I(n, c, ${sh} * h + kh, ${sw} * w + kw) * W1(m, c, kh, kw)
O(n, m, h, w) = O(n, m, h, w) + B(m)
}
)TC";
// Tensor Comprehensions (TC) source for the gradient of the 2-D convolution.
// Given the input I, the weights W1 and the upstream gradient O_grad, it
// produces three outputs:
//   I_grad  - gradient w.r.t. the input,
//   W1_grad - gradient w.r.t. the weights,
//   B_grad  - gradient w.r.t. the bias (O_grad summed over n, h, w).
// `${sh}` / `${sw}` are stride placeholders. Each appears twice (in the I_grad
// and W1_grad lines), so every occurrence must be substituted before
// compilation; makeConvolution2DGradTc does this via replaceString.
//
// NOTE(review): the forward template reads I at (${sh} * h + kh). The indices
// below use (${sh} * h - kh), which inverts that mapping only when the stride
// is 1. For stride > 1 an exact inverse would need a division by the stride.
// Also, O_grad is declared with the same size symbols H, W as I, although the
// unpadded forward output is generally smaller. Confirm both points before
// relying on this template for stride > 1 or KH/KW > 1.
constexpr static auto CONVOLUTION2D_GRAD_TC = R"TC(
def convolution2dGrad(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(N,M,H,W) O_grad) -> (I_grad, W1_grad, B_grad) {
I_grad(n, c, h, w) +=! O_grad(n, m, ${sh} * h - kh, ${sw} * w - kw) * W1(m, c, kh, kw)
W1_grad(m, c, kh, kw) +=! O_grad(n, m, ${sh} * h - kh, ${sw} * w - kw) * I(n, c, h, w)
B_grad(m) +=! O_grad(n, m, h, w)
}
)TC";
// Appears to be the body of makeConvolution2DTc(int strideH, int strideW).
// The function signature, the `tcStr` declaration and the return statement
// are not visible in this listing.
// NOTE(review): the leading numbers ("33", "35", ...) are line-number residue
// from a generated source listing, not code; strip them when restoring.
// CHECK fails with the message below unless both strides are positive, so
// invalid strides are never baked into the TC text.
33 CHECK(strideH > 0 && strideW > 0) <<
"Stride must be greater than 0";
// Start from the forward-convolution template ...
35 tcStr = CONVOLUTION2D_TC;
// ... and substitute the vertical-stride placeholder ...
36 tcStr =
replaceString(tcStr,
"${sh}", std::to_string(strideH));
// ... and the horizontal-stride placeholder.
// NOTE(review): presumably replaceString replaces every occurrence of
// `search`; confirm in common.h.
37 tcStr =
replaceString(tcStr,
"${sw}", std::to_string(strideW));
// Appears to be the body of makeConvolution2DGradTc(int strideH, int strideW).
// The function signature, the `tcStr` declaration and the return statement
// are not visible in this listing.
// NOTE(review): the leading numbers ("42", "44", ...) are line-number residue
// from a generated source listing, not code; strip them when restoring.
// CHECK fails with the message below unless both strides are positive.
42 CHECK(strideH > 0 && strideW > 0) <<
"Stride must be greater than 0";
// Start from the gradient template. It contains each stride placeholder twice
// (in the I_grad and W1_grad statements), so replaceString must substitute
// all occurrences for the result to be valid TC. TODO confirm in common.h.
44 tcStr = CONVOLUTION2D_GRAD_TC;
// Substitute the vertical stride ...
45 tcStr =
replaceString(tcStr,
"${sh}", std::to_string(strideH));
// ... and the horizontal stride.
46 tcStr =
replaceString(tcStr,
"${sw}", std::to_string(strideW));
50 std::string makeConvolution2DTc(int strideH, int strideW)
Definition: convolution.h:43
std::string replaceString(std::string str, const std::string &search, const std::string &replace)
Definition: common.h:20
std::string makeConvolution2DGradTc(int strideH, int strideW)
Definition: convolution.h:52