Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
group_convolution_op.h
Go to the documentation of this file.
1 
16 #pragma once
17 
18 #include <string>
19 #include <vector>
20 
21 #include "tc/c2/tc_op.h"
23 
24 namespace caffe2 {
25 
26 template <typename T, class Context, class Engine = caffe2::DefaultEngine>
27 class TcGroupConvolutionOp : public TcOp<T, Context, Engine> {
28  int group_;
29 
30  public:
31  static constexpr auto description = tc::GROUP_CONVOLUTION2D_TC;
32 
34  const caffe2::OperatorDef& operator_def,
35  caffe2::Workspace* ws)
36  : TcOp<T, Context, Engine>(operator_def, ws),
37  group_(OperatorBase::GetSingleArgument<int>("group", -1)) {
38  CHECK_EQ(-1, group_)
39  << "Caffe2 implements group convolution as a dilated convolution. "
40  << "Someone (not us) needs to reshape.";
41 
42  int strideH = 0;
43  int strideW = 0;
44  if (OperatorBase::HasArgument("stride")) {
45  strideH = OperatorBase::GetSingleArgument<int>("stride", 1);
46  strideW = OperatorBase::GetSingleArgument<int>("stride", 1);
47  } else {
48  strideH = OperatorBase::GetSingleArgument<int>("stride_h", 1);
49  strideW = OperatorBase::GetSingleArgument<int>("stride_w", 1);
50  }
51 
52  int padT = 0;
53  int padL = 0;
54  int padB = 0;
55  int padR = 0;
56  if (OperatorBase::HasArgument("pad")) {
57  padT = OperatorBase::GetSingleArgument<int>("pad", 0);
58  padL = OperatorBase::GetSingleArgument<int>("pad", 0);
59  padB = OperatorBase::GetSingleArgument<int>("pad", 0);
60  padR = OperatorBase::GetSingleArgument<int>("pad", 0);
61  } else {
62  padT = OperatorBase::GetSingleArgument<int>("pad_t", 0);
63  padL = OperatorBase::GetSingleArgument<int>("pad_l", 0);
64  padB = OperatorBase::GetSingleArgument<int>("pad_b", 0);
65  padR = OperatorBase::GetSingleArgument<int>("pad_r", 0);
66  }
67 
68  CHECK(padT == 0 && padL == 0 && padB == 0 && padR == 0)
69  << "NYI: padding larger than 0";
70 
71  this->tc_ = tc::makeGroupConvolution2DTc(strideH, strideW);
72  this->tcName_ = tc::GROUP_CONVOLUTION2D_TC_NAME;
73  this->gradTc_ = tc::makeGroupConvolution2DGradTc(strideH, strideW);
74  this->gradTcName_ = tc::GROUP_CONVOLUTION2D_GRAD_TC_NAME;
75  }
76 
77  ~TcGroupConvolutionOp() override {}
78 
79  protected:
80  void setupNaiveMappingOptions() override {
81  this->mappingOptions_ =
83  this->gradMappingOptions_ =
85  }
86 };
87 } // namespace caffe2
tc::MappingOptions mappingOptions_
Definition: tc_op.h:133
static MappingOptions makeGroupConvolutionMappingOptions()
std::string gradTc_
Definition: tc_op.h:129
Definition: tc_op.h:36
std::string makeGroupConvolution2DGradTc(int strideH, int strideW)
Definition: group_convolution.h:60
std::string gradTcName_
Definition: tc_op.h:131
std::string tc_
Definition: tc_op.h:128
std::string tcName_
Definition: tc_op.h:130
~TcGroupConvolutionOp() override
Definition: group_convolution_op.h:77
void setupNaiveMappingOptions() override
Definition: group_convolution_op.h:80
TcGroupConvolutionOp(const caffe2::OperatorDef &operator_def, caffe2::Workspace *ws)
Definition: group_convolution_op.h:33
static constexpr auto description
Definition: group_convolution_op.h:31
Definition: group_convolution_op.h:27
std::string makeGroupConvolution2DTc(int strideH, int strideW)
Definition: group_convolution.h:51
int group_
Definition: group_convolution_op.h:28
tc::MappingOptions gradMappingOptions_
Definition: tc_op.h:134