Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
tc_op.h
Go to the documentation of this file.
1 
16 #pragma once
17 
18 #include <functional>
19 #include <sstream>
20 #include <string>
21 #include <vector>
22 
25 #include "tc/core/utils/dlpack.h"
26 
27 #include "tc/c2/context.h"
28 #include "tc/c2/dlpack_c2.h"
29 
30 #include "caffe2/core/common.h"
31 #include "caffe2/utils/math.h"
32 
33 namespace caffe2 {
34 
35 template <typename T, class Context, class Engine = DefaultEngine>
36 class TcOp : public Operator<Context> {
37  public:
38  TcOp(const OperatorDef& operator_def, Workspace* ws)
39  : caffe2::Operator<Context>(operator_def, ws),
40  tc_(OperatorBase::GetSingleArgument<std::string>("tcDef", "ERROR")),
41  tcName_(
42  OperatorBase::GetSingleArgument<std::string>("tcName", "ERROR")),
43  mappingOptions_(tc::MappingOptions::makeNaiveMappingOptions()),
44  gradMappingOptions_(tc::MappingOptions::makeNaiveMappingOptions()) {
45  gradTc_ =
46  OperatorBase::GetSingleArgument<std::string>("tcGradDef", "ERROR");
47  gradTcName_ =
48  OperatorBase::GetSingleArgument<std::string>("tcGradName", "ERROR");
49  profile_ = OperatorBase::GetSingleArgument<bool>("profile", false);
50  ArgumentHelper args(operator_def);
51  if (args.HasArgument("mappingOptions")) {
53  args.GetSingleArgument<std::string>("mappingOptions", "ERROR"));
54  } else {
56  }
57 
58  if (args.HasArgument("gradMappingOptions")) {
60  args.GetSingleArgument<std::string>("gradMappingOptions", "ERROR"));
61  } else {
63  }
65  std::unique_ptr<tc::ExecutionEngine>(new tc::ExecutionEngine());
66  }
67 
69 
70  ~TcOp() override {}
71 
72  protected:
76  virtual void setupNaiveMappingOptions() {}
77 
82 
83  void prepareOutputs(const std::vector<const DLTensor*> tensorInfo) {
84  for (int i = 0; i < tensorInfo.size(); ++i) {
85  auto info = tensorInfo[i];
86  std::vector<int64_t> shape(info->shape, info->shape + info->ndim);
87  Output(i)->Resize(shape);
88  // Note: this mutable_data() call actually creates the data storage.
89  Output(i)->template mutable_data<T>();
90  }
91  }
92 
93  virtual bool RunOnDevice() override {
94  // first, given the TC, define it in the executionEngine_
95  executionEngine_->define(tc_);
96 
97  // now, given the input tensors, convert them to dlpack tensors so that
98  // we can call the compile command
99  std::vector<::tc::dlutils::DLTensorUPtr> inTensorUPtrs;
100  std::vector<const DLTensor*> inputDLTensors;
101  for (int idx = 0; idx < this->InputSize(); ++idx) {
102  auto dims = this->Input(idx).dims();
103  inTensorUPtrs.emplace_back(
104  dlpack::makeConstDLTensor(this->Input(idx), dims));
105  inputDLTensors.push_back(inTensorUPtrs.back().get());
106  }
107 
108  auto outTensorInfo =
109  executionEngine_->inferOutputTensorInfo(tcName_, inputDLTensors);
110  prepareOutputs(outTensorInfo);
111 
112  // now create the outputDLTensors
113  std::vector<::tc::dlutils::DLTensorUPtr> outTensorUPtrs;
114  std::vector<DLTensor*> outputDLTensors;
115  for (int i = 0; i < OutputSize(); ++i) {
116  outTensorUPtrs.emplace_back(dlpack::makeDLTensor(Output(i)));
117  outputDLTensors.push_back(outTensorUPtrs.back().get());
118  }
119 
120  // compile and run
121  auto handle =
122  executionEngine_->compile(tcName_, inputDLTensors, mappingOptions_);
123  executionEngine_->run(handle, inputDLTensors, outputDLTensors, profile_);
124  return true;
125  }
126 
127  protected:
128  std::string tc_;
129  std::string gradTc_;
130  std::string tcName_;
131  std::string gradTcName_;
132  bool profile_;
135 
136  private:
137  std::unique_ptr<tc::ExecutionEngine> executionEngine_;
138 };
139 
140 class GetTcOpGradient : public GradientMakerBase {
141  public:
142  using GradientMakerBase::GradientMakerBase;
143 
144  std::vector<OperatorDef> GetGradientDefs() override {
145  ArgumentHelper args(Def());
146  CHECK(false) << "NYI gradient";
147  return {};
148  }
149 };
150 } // namespace caffe2
tc::MappingOptions mappingOptions_
Definition: tc_op.h:133
GetTcOpGradient
Definition: tc_op.h:140
Definition: execution_engine.h:34
bool profile_
Definition: tc_op.h:132
USE_OPERATOR_CONTEXT_FUNCTIONS
Definition: tc_op.h:68
std::string gradTc_
Definition: tc_op.h:129
TcOp
Definition: tc_op.h:36
std::string gradTcName_
Definition: tc_op.h:131
virtual void setupNaiveMappingOptions()
Definition: tc_op.h:76
std::string tc_
Definition: tc_op.h:128
~TcOp() override
Definition: tc_op.h:70
std::string tcName_
Definition: tc_op.h:130
void prepareOutputs(const std::vector< const DLTensor * > tensorInfo)
Definition: tc_op.h:83
std::unique_ptr< tc::ExecutionEngine > executionEngine_
Definition: tc_op.h:137
std::vector< OperatorDef > GetGradientDefs() override
Definition: tc_op.h:144
tc::dlutils::DLTensorUPtr makeConstDLTensor(const caffe2::Tensor< C2Context > &tensor, const vector< TIndex > &shapeOverride={})
Definition: dlpack_c2.h:61
Definition: mapping_options.h:336
virtual bool RunOnDevice() override
Definition: tc_op.h:93
tc::dlutils::DLTensorUPtr makeDLTensor(caffe2::Tensor< C2Context > *tensor)
Definition: dlpack_c2.h:88
TcOp(const OperatorDef &operator_def, Workspace *ws)
Definition: tc_op.h:38
virtual void setupDefaultGradMappingOptions()
Definition: tc_op.h:81
tc::MappingOptions gradMappingOptions_
Definition: tc_op.h:134