Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
copy_op.h
Go to the documentation of this file.
1 
16 #pragma once
17 
18 #include <string>
19 #include <vector>
20 
21 #include "tc/c2/tc_op.h"
22 #include "tc/library/copy.h"
23 
24 namespace caffe2 {
25 
26 template <typename T, class Context, class Engine = caffe2::DefaultEngine>
27 class TcCopyOp : public TcOp<T, Context, Engine> {
28  public:
29  static constexpr auto description = tc::COPY_DOC;
30 
31  TcCopyOp(const caffe2::OperatorDef& operator_def, caffe2::Workspace* ws)
32  : TcOp<T, Context, Engine>(operator_def, ws) {}
33 
34  ~TcCopyOp() override {}
35 
36  bool RunOnDevice() override {
37  this->tc_ = tc::makeCopyTc(this->Input(0).dims().size());
38  this->tcName_ = tc::COPY_TC_NAME;
39  this->gradTc_ = tc::makeCopyGradTc(this->Input(0).dims().size());
40  this->gradTcName_ = tc::COPY_GRAD_TC_NAME;
42  }
43 
44  protected:
45  void setupNaiveMappingOptions() override {
47  .tile({4, 8, 8})
48  .mapToThreads({32, 4, 4})
49  .mapToBlocks({100, 100, 100})
50  .unroll({128});
51  this->gradMappingOptions_ =
53  }
54 };
55 } // namespace caffe2
tc::MappingOptions mappingOptions_
Definition: tc_op.h:133
void setupNaiveMappingOptions() override
Definition: copy_op.h:45
std::string gradTc_
Definition: tc_op.h:129
class TcOp
Definition: tc_op.h:36
std::string makeCopyTc(int numDims)
Definition: copy.h:58
std::string gradTcName_
Definition: tc_op.h:131
TcCopyOp(const caffe2::OperatorDef &operator_def, caffe2::Workspace *ws)
Definition: copy_op.h:31
std::string tc_
Definition: tc_op.h:128
std::string tcName_
Definition: tc_op.h:130
static constexpr auto description
Definition: copy_op.h:29
virtual bool RunOnDevice() override
Definition: tc_op.h:93
MappingOptions & tile(const std::vector< uint64_t > &sizes)
Definition: mapping_options-inl.h:251
std::string makeCopyGradTc(int numDims)
Definition: copy.h:62
static MappingOptions makePointwiseMappingOptions()
bool RunOnDevice() override
Definition: copy_op.h:36
class TcCopyOp
Definition: copy_op.h:27
~TcCopyOp() override
Definition: copy_op.h:34
tc::MappingOptions gradMappingOptions_
Definition: tc_op.h:134