26 template <
typename T,
class Context,
class Engine = caffe2::DefaultEngine>
31 static constexpr
auto description = tc::GROUP_CONVOLUTION2D_TC;
34 const caffe2::OperatorDef& operator_def,
35 caffe2::Workspace* ws)
36 :
TcOp<T, Context, Engine>(operator_def, ws),
37 group_(OperatorBase::GetSingleArgument<int>(
"group", -1)) {
39 <<
"Caffe2 implements group convolution as a dilated convolution. "
40 <<
"Someone (not us) needs to reshape.";
44 if (OperatorBase::HasArgument(
"stride")) {
45 strideH = OperatorBase::GetSingleArgument<int>(
"stride", 1);
46 strideW = OperatorBase::GetSingleArgument<int>(
"stride", 1);
48 strideH = OperatorBase::GetSingleArgument<int>(
"stride_h", 1);
49 strideW = OperatorBase::GetSingleArgument<int>(
"stride_w", 1);
56 if (OperatorBase::HasArgument(
"pad")) {
57 padT = OperatorBase::GetSingleArgument<int>(
"pad", 0);
58 padL = OperatorBase::GetSingleArgument<int>(
"pad", 0);
59 padB = OperatorBase::GetSingleArgument<int>(
"pad", 0);
60 padR = OperatorBase::GetSingleArgument<int>(
"pad", 0);
62 padT = OperatorBase::GetSingleArgument<int>(
"pad_t", 0);
63 padL = OperatorBase::GetSingleArgument<int>(
"pad_l", 0);
64 padB = OperatorBase::GetSingleArgument<int>(
"pad_b", 0);
65 padR = OperatorBase::GetSingleArgument<int>(
"pad_r", 0);
68 CHECK(padT == 0 && padL == 0 && padB == 0 && padR == 0)
69 <<
"NYI: padding larger than 0";
72 this->
tcName_ = tc::GROUP_CONVOLUTION2D_TC_NAME;
74 this->
gradTcName_ = tc::GROUP_CONVOLUTION2D_GRAD_TC_NAME;
tc::MappingOptions mappingOptions_
Definition: tc_op.h:133
static MappingOptions makeGroupConvolutionMappingOptions()
std::string gradTc_
Definition: tc_op.h:129
std::string makeGroupConvolution2DGradTc(int strideH, int strideW)
Definition: group_convolution.h:60
std::string gradTcName_
Definition: tc_op.h:131
std::string tc_
Definition: tc_op.h:128
std::string tcName_
Definition: tc_op.h:130
~TcGroupConvolutionOp() override
Definition: group_convolution_op.h:77
void setupNaiveMappingOptions() override
Definition: group_convolution_op.h:80
TcGroupConvolutionOp(const caffe2::OperatorDef &operator_def, caffe2::Workspace *ws)
Definition: group_convolution_op.h:33
static constexpr auto description
Definition: group_convolution_op.h:31
Definition: group_convolution_op.h:27
std::string makeGroupConvolution2DTc(int strideH, int strideW)
Definition: group_convolution.h:51
int group_
Definition: group_convolution_op.h:28
tc::MappingOptions gradMappingOptions_
Definition: tc_op.h:134