Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
tc_executor.h
Go to the documentation of this file.
1 
16 #pragma once
17 
18 #include "tc/core/halide_utils.h"
21 #include "tc/core/utils/dlpack.h"
22 #include "tc/lang/parser.h"
23 
24 #include <dlpack/dlpack.h>
25 
26 namespace tc {
27 
28 class TcExecutor {
29  public:
30  TcExecutor(
31  const std::string& TCDefinition,
32  const std::vector<const DLTensor*>& inputsInfo);
33  TcExecutor(
34  lang::TreeRef TCDefinition,
35  const std::vector<const DLTensor*>& inputsInfo);
36  ~TcExecutor();
37 
38  TcExecutor(TcExecutor&&) = delete;
39  TcExecutor& operator=(TcExecutor&&) = delete;
40  TcExecutor(const TcExecutor&) = delete;
41  TcExecutor& operator=(const TcExecutor&) = delete;
42 
43  // Given a Tc and a list of input tensors that match the definition in the
44  // Tc in positional order, this generates the output tensor infos issued
45  // from forward inference.
46  // The typical flow is to infer output sizes, allocate/resize them within
47  // your favorite ML framework / tensor library and then call compile.
48  std::vector<const DLTensor*> inferOutputTensorInfo();
49 
50  // Can only be called once with specific kernel options. Input sizes are
51  // set up as constructor argument and output sizes are inferred.
52  //
53  // If you need another kernel for another Tc or for different inputs,
54  // outputs, or options, then just instantiate another TcExecutor.
55  // This is because for the time being we fully specialize all the sizes and
56  // strides at runtime.
57  void compile(const tc::MappingOptions& options);
58 
59  // Run can be called multiple times given a compilation; inputs are allowed
60  // to change in that their data pointer is allowed to change.
61  // Sizes and strides must remain constant; otherwise this is an error.
62  // The only thing that is allowed to change across runs is the input
63  // and output pointers base address.
64  // It is the caller's responsibility to ensure proper non-aliasing (or
65  // advanced aliasing) properties of the input and output tensors.
66  // If profile is set, the kernel runtime (nanoseconds) is returned.
67  Duration run(
68  const std::vector<const DLTensor*>& inputs,
69  const std::vector<DLTensor*>& outputs,
70  bool profile = false) const;
71 
72  // This is the "low-latency" mode in which we just propagate raw pointers to
73  // data in GPU address space.
74  // No tensor-related information can be checked so it is the user's
75  // responsibility to ensure that shapes and strides match. If the user
76  // doesn't, then a segfault will likely occur.
77  void uncheckedRun(
78  const std::vector<const void*>& inputs,
79  const std::vector<void*>& outputs) const;
80 
81  std::string getCudaSource() {
82  return execInfo_.cudaSource;
83  }
84 
85  bool hasRTCFun() {
86  return execInfo_.rtcFun.get() != nullptr;
87  }
88 
89  // It is necessary to clear the RTC manually because it can throw and we
90  // can't have that in the destructor.
91  void clearRTC() {
92  if (!hasRTCFun()) {
93  return;
94  }
95  execInfo_.rtcFun->clear();
96  }
97 
98  std::string kernelName() const {
99  return execInfo_.kernelName;
100  }
101 
102  private:
103  void compileWithTcMapper();
104 
106  std::string kernelName;
107  std::vector<dlutils::DLTensorUPtr> inputsInfo;
108  std::vector<dlutils::DLTensorUPtr> outputsInfo;
109  std::vector<int> kernelParams;
111  std::unique_ptr<tc::MappingOptions> options;
112  std::string cudaSource;
113  Grid grid{{0, 0, 0}};
114  Block block{{0, 0, 0}};
115  std::shared_ptr<CudaRTCFunction> rtcFun;
116  };
117 
118  public:
119  Grid grid() const {
120  return execInfo_.grid;
121  }
122  Block block() const {
123  return execInfo_.block;
124  }
125 
126  const static size_t InvalidHandle = std::numeric_limits<size_t>::max();
127 
128  private:
130  const std::vector<const DLTensor*>& inputsInfo) const;
134  mutable isl::ctx ctx_;
135 };
136 
137 } // namespace tc
Block block() const
Definition: tc_executor.h:122
std::string getCudaSource()
Definition: tc_executor.h:81
std::shared_ptr< CudaRTCFunction > rtcFun
Definition: tc_executor.h:115
Specializing CudaDim to differentiate between Block and Grid sizes.
Definition: mapping_options.h:208
TcExecutor(const std::string &TCDefinition, const std::vector< const DLTensor * > &inputsInfo)
std::unique_ptr< tc::MappingOptions > options
Definition: tc_executor.h:111
std::string kernelSpecializedName
Definition: tc_executor.h:110
std::vector< dlutils::DLTensorUPtr > inputsInfo
Definition: tc_executor.h:107
Definition: tc2halide.h:29
std::vector< const DLTensor * > inferOutputTensorInfo()
void uncheckedRun(const std::vector< const void * > &inputs, const std::vector< void * > &outputs) const
isl::ctx ctx_
Definition: tc_executor.h:134
std::vector< int > kernelParams
Definition: tc_executor.h:109
void compileWithTcMapper()
void clearRTC()
Definition: tc_executor.h:91
lang::TreeRef tcTree_
Definition: tc_executor.h:133
void compile(const tc::MappingOptions &options)
std::vector< dlutils::DLTensorUPtr > outputsInfo
Definition: tc_executor.h:108
std::string kernelName
Definition: tc_executor.h:106
std::string cudaSource
Definition: tc_executor.h:112
Definition: mapping_options.h:336
std::chrono::high_resolution_clock::duration Duration
Definition: rtc.h:31
bool hasRTCFun()
Definition: tc_executor.h:85
void checkInputsCompliant(const std::vector< const DLTensor * > &inputsInfo) const
Duration run(const std::vector< const DLTensor * > &inputs, const std::vector< DLTensor * > &outputs, bool profile=false) const
static const size_t InvalidHandle
Definition: tc_executor.h:126
tc2halide::HalideComponents halideComponents_
Definition: tc_executor.h:131
TcExecutionInfo execInfo_
Definition: tc_executor.h:132
TcExecutor & operator=(TcExecutor &&)=delete
Block block
Definition: tc_executor.h:114
Specializing CudaDim to differentiate between Block and Grid sizes.
Definition: mapping_options.h:196
Grid grid() const
Definition: tc_executor.h:119
std::string kernelName() const
Definition: tc_executor.h:98
Grid grid
Definition: tc_executor.h:113
std::shared_ptr< Tree > TreeRef
Definition: tree.h:44
Definition: tc_executor.h:105
Definition: tc_executor.h:28