Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
convolution_op.h
Go to the documentation of this file.
1 
16 #pragma once
17 
18 #include <string>
19 #include <vector>
20 
21 #include "tc/c2/tc_op.h"
22 #include "tc/library/convolution.h"
23 
24 namespace caffe2 {
25 
26 template <typename T, class Context, class Engine = caffe2::DefaultEngine>
27 class TcConvolutionOp : public TcOp<T, Context, Engine> {
28  public:
29  static constexpr auto description = tc::CONVOLUTION2D_TC;
30 
32  const caffe2::OperatorDef& operator_def,
33  caffe2::Workspace* ws)
34  : TcOp<T, Context, Engine>(operator_def, ws) {
35  int strideH = 0;
36  int strideW = 0;
37  if (OperatorBase::HasArgument("stride")) {
38  strideH = OperatorBase::GetSingleArgument<int>("stride", 1);
39  strideW = OperatorBase::GetSingleArgument<int>("stride", 1);
40  } else {
41  strideH = OperatorBase::GetSingleArgument<int>("stride_h", 1);
42  strideW = OperatorBase::GetSingleArgument<int>("stride_w", 1);
43  }
44 
45  int padT = 0;
46  int padL = 0;
47  int padB = 0;
48  int padR = 0;
49  if (OperatorBase::HasArgument("pad")) {
50  padT = OperatorBase::GetSingleArgument<int>("pad", 0);
51  padL = OperatorBase::GetSingleArgument<int>("pad", 0);
52  padB = OperatorBase::GetSingleArgument<int>("pad", 0);
53  padR = OperatorBase::GetSingleArgument<int>("pad", 0);
54  } else {
55  padT = OperatorBase::GetSingleArgument<int>("pad_t", 0);
56  padL = OperatorBase::GetSingleArgument<int>("pad_l", 0);
57  padB = OperatorBase::GetSingleArgument<int>("pad_b", 0);
58  padR = OperatorBase::GetSingleArgument<int>("pad_r", 0);
59  }
60 
61  CHECK(padT == 0 && padL == 0 && padB == 0 && padR == 0)
62  << "NYI: padding larger than 0";
63 
64  this->tc_ = tc::makeConvolution2DTc(strideH, strideW);
65  this->tcName_ = tc::CONVOLUTION2D_TC_NAME;
66  this->gradTc_ = tc::makeConvolution2DGradTc(strideH, strideW);
67  this->gradTcName_ = tc::CONVOLUTION2D_GRAD_TC_NAME;
68  }
69 
70  ~TcConvolutionOp() override {}
71 
72  protected:
73  void setupNaiveMappingOptions() override {
75  this->gradMappingOptions_ =
77  }
78 };
79 } // namespace caffe2
tc::MappingOptions mappingOptions_
Definition: tc_op.h:133
std::string gradTc_
Definition: tc_op.h:129
caffe2::TcOp
Definition: tc_op.h:36
std::string gradTcName_
Definition: tc_op.h:131
std::string tc_
Definition: tc_op.h:128
std::string tcName_
Definition: tc_op.h:130
caffe2::TcConvolutionOp
Definition: convolution_op.h:27
~TcConvolutionOp() override
Definition: convolution_op.h:70
std::string makeConvolution2DTc(int strideH, int strideW)
Definition: convolution.h:43
static MappingOptions makeConvolutionMappingOptions()
void setupNaiveMappingOptions() override
Definition: convolution_op.h:73
std::string makeConvolution2DGradTc(int strideH, int strideW)
Definition: convolution.h:52
static constexpr auto description
Definition: convolution_op.h:29
tc::MappingOptions gradMappingOptions_
Definition: tc_op.h:134
TcConvolutionOp(const caffe2::OperatorDef &operator_def, caffe2::Workspace *ws)
Definition: convolution_op.h:31