Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
rtc.h
Go to the documentation of this file.
1 
16 #pragma once
17 
18 #include <chrono>
19 #include <memory>
20 #include <mutex>
21 #include <string>
22 #include <unordered_map>
23 #include <vector>
24 
25 #include <cuda.h>
26 #include <driver_types.h> // cuda driver types
27 
28 namespace tc {
29 
30 extern std::mutex nvrtc_mutex; // declared here, defined in another translation unit; presumably serializes NVRTC compilation calls (NOTE(review): confirm against rtc.cc)
31 using Duration = std::chrono::high_resolution_clock::duration; // time-span type of high_resolution_clock; CudaRTCFunction::Launch returns it (the kernel runtime when profiling is requested)
32 
33 //
34 // Basic interface to expose NVRTC JIT compilation and module
35 // loading/unloading + API kernel launches.
36 //
39 
40  public:
42 
43  static std::shared_ptr<CudaRTCFunction> Compile(
44  const std::string& name,
45  const std::string& source);
46 
47  // if profile is set it returns the kernel runtime
49  const std::array<size_t, 3>& grid,
50  const std::array<size_t, 3>& block,
51  unsigned int shared_mem,
52  cudaStream_t stream,
53  // by copy because we take an address to element when calling the kernel
54  // TODO: check the overhead of double indirection on kernel calls, this
55  // does not look ideal for low-latency
56  std::vector<int> params,
57  std::vector<void*> outputs,
58  std::vector<const void*> inputs,
59  bool profile = false) const;
60 
61  void clear();
62 
63  private:
64  mutable std::unordered_map<size_t, CUmodule> perGpuModule_;
65  mutable std::unordered_map<size_t, CUfunction> perGpuKernel_;
66  std::string specializedName;
67  std::vector<char> nvrtc_ptx;
68  bool cleared_;
69 };
70 
71 } // namespace tc
Duration Launch(const std::array< size_t, 3 > &grid, const std::array< size_t, 3 > &block, unsigned int shared_mem, cudaStream_t stream, std::vector< int > params, std::vector< void * > outputs, std::vector< const void * > inputs, bool profile=false) const
std::vector< char > nvrtc_ptx
Definition: rtc.h:67
std::mutex nvrtc_mutex
std::string specializedName
Definition: rtc.h:66
std::unordered_map< size_t, CUmodule > perGpuModule_
Definition: rtc.h:64
std::chrono::high_resolution_clock::duration Duration
Definition: rtc.h:31
std::unordered_map< size_t, CUfunction > perGpuKernel_
Definition: rtc.h:65
static std::shared_ptr< CudaRTCFunction > Compile(const std::string &name, const std::string &source)
bool cleared_
Definition: rtc.h:68
Definition: rtc.h:37