Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
compilation_cache-inl.h
Go to the documentation of this file.
1 
16 #pragma once
17 
18 #include <sys/stat.h>
19 #include <algorithm>
20 #include <fstream>
21 #include <string>
22 
23 #include <glog/logging.h>
24 #include <version.h>
25 
26 namespace tc {
27 
28 template <typename CC>
30  CC::getGlobalSharedCache() = std::make_shared<CC>();
31 }
32 
33 template <typename CC>
35  CC::getGlobalSharedCache() = nullptr;
36 }
37 
38 template <typename CC>
39 std::shared_ptr<CC> Cache<CC>::getCache() {
40  if (not cacheEnabled()) {
41  throw std::runtime_error(
42  "EnableCache or LoadCacheFromProtobuf must be called before using the cache.");
43  }
44  return CC::getGlobalSharedCache();
45 }
46 
47 template <typename CC>
48 void Cache<CC>::dumpCacheToProtobuf(const std::string& filename) {
49  std::fstream serialized(
50  filename, std::ios::binary | std::ios::trunc | std::ios::out);
51  if (!serialized) {
52  LOG(ERROR) << "Failed to open the output stream for dumping protobuf: "
53  << filename;
54  } else {
55  getCache()->toProtobuf().SerializePartialToOstream(&serialized);
56  }
57 }
58 
59 template <typename CC>
60 void Cache<CC>::loadCacheFromProtobuf(const std::string& filename) {
61  typename CC::Protobuf buf;
62  struct stat buffer = {0};
63  if (stat(filename.c_str(), &buffer) == 0) {
64  std::ifstream serialized(filename, std::ios::binary);
65  buf.ParseFromIstream(&serialized);
66  }
67  loadCacheFromProtobuf(buf);
68 }
69 
70 template <typename CC>
71 template <typename Protobuf>
72 void Cache<CC>::loadCacheFromProtobuf(const Protobuf& buf) {
73  static_assert(
74  std::is_same<Protobuf, typename CC::Protobuf>::value,
75  "LoadCacheFromProtobuf called with invalide protobuf type.");
76  CC::getGlobalSharedCache() = std::make_shared<CC>(buf);
77 }
78 
79 template <typename CC>
81  return CC::getGlobalSharedCache() != nullptr;
82 }
83 
84 template <typename CC>
85 size_t Cache<CC>::size() const {
86  std::lock_guard<std::mutex> lock(mtx_);
87  return static_cast<const CC*>(this)->entries_.size();
88 }
89 
90 template <typename CC>
92  std::lock_guard<std::mutex> lock(mtx_);
93  numberAttemptedRetrievals = numberSuccessfulRetrievals = numberCacheAttemps =
94  0;
95  static_cast<CC*>(this)->entries_.clear();
96 }
97 
// Linear scan (std::find_if) of c.entries_ for the entry whose key matches
// `id`, `options`, `inputs`, `outputs` and the current CUDA device string
// (CudaGPUInfo::GPUInfo().GetCudaDeviceStr()).
// Returns a pointer to the matching entry, or nullptr when nothing matches.
// If the match was stored under a different tc::git_version, a warning is
// written to std::cerr but the entry is still returned.
// The trailing return type mirrors c.searchKernel(id, options, inputs,
// outputs), so it follows the constness deduced for C.
// NOTE(review): listing line 100 (the searchKernelImpl declarator) is missing
// from this extracted listing; the enclosing class is not visible here.
98 template <typename C, typename InputTy> // deduces whether C is const or
99 // non-const
101  C& c,
102  const std::string& id,
103  const MappingOptions& options,
104  const std::vector<InputTy>& inputs,
105  const std::vector<InputTy>& outputs)
106  -> decltype(c.searchKernel(id, options, inputs, outputs)) {
107  auto gpuStr = CudaGPUInfo::GPUInfo().GetCudaDeviceStr();
108  auto it = std::find_if(
// NOTE(review): the lambda parameter `c` shadows the cache parameter `c`.
109  c.entries_.begin(), c.entries_.end(), [&](const CachedEntry& c) {
// Makes the tc::operator== overloads visible for the key comparisons below.
110  using tc::operator==;
111  return id == c.key.id && options == c.key.mappingOptions &&
112  inputs == c.key.inputs && outputs == c.key.outputs &&
113  gpuStr == c.key.deviceStr;
114  });
115  if (it != c.entries_.end()) {
// A version mismatch only warns; the stale entry is still handed back.
116  if (it->key.gitVersion != tc::git_version) {
117  std::cerr << "Proto version doesn't match. TC git version is: "
118  << tc::git_version
119  << " and Proto version is: " << it->key.gitVersion
120  << " .This proto might be incompatible"
121  << " with your TC binary and can break. Please autotune"
122  << " against the correct TC version." << std::endl;
123  }
124  return &*it;
125  }
126  return nullptr;
127 }
128 
// Linear scan (std::find_if) of c.entries_ for the entry whose key matches
// `id`, the DLTensor `inputs`/`outputs` and the current CUDA device string
// (CudaGPUInfo::GPUInfo().GetCudaDeviceStr()); no mapping options are compared.
// Returns a pointer to the matching entry, or nullptr when nothing matches.
// If the match was stored under a different tc::git_version, a warning is
// written to std::cerr but the entry is still returned.
// NOTE(review): listing line 131 (the searchKernelImpl declarator) is missing
// from this extracted listing; the enclosing class is not visible here.
129 // deduces whether C is const or non-const
130 template <typename C>
132  C& c,
133  const std::string& id,
134  const std::vector<const DLTensor*>& inputs,
135  const std::vector<const DLTensor*>& outputs)
136  -> decltype(c.searchKernel(id, inputs, outputs)) {
137  auto gpuStr = CudaGPUInfo::GPUInfo().GetCudaDeviceStr();
138  auto it = std::find_if(
// NOTE(review): the lambda parameter `c` shadows the cache parameter `c`.
139  c.entries_.begin(), c.entries_.end(), [&](const CachedEntry& c) {
// Makes the tc::operator== overloads visible for the key comparisons below.
140  using tc::operator==;
141  return id == c.key.id && inputs == c.key.inputs &&
142  outputs == c.key.outputs && gpuStr == c.key.deviceStr;
143  });
144  if (it != c.entries_.end()) {
// A version mismatch only warns; the stale entry is still handed back.
145  if (it->key.gitVersion != tc::git_version) {
146  std::cerr << "Proto version doesn't match. TC git version is: "
147  << tc::git_version
148  << " and Proto version is: " << it->key.gitVersion
149  << " .This proto might be incompatible"
150  << " with your TC binary and can break. Please autotune"
151  << " against the correct TC version." << std::endl;
152  ;
153  }
154  return &*it;
155  }
156  return nullptr;
157 }
158 
// Linear scan (std::find_if) of c.entries_ for the entry whose key matches
// `id`, `inputs`, `outputs` (vectors of TensorTy) and the current CUDA device
// string (CudaGPUInfo::GPUInfo().GetCudaDeviceStr()).
// Returns a pointer to the matching entry, or nullptr when nothing matches.
// If the match was stored under a different tc::git_version, a warning is
// written to std::cerr but the entry is still returned.
// Fix: the unconditional debug print to std::cout on every cache hit and a
// stray empty statement were removed.
// NOTE(review): listing line 161 (the searchKernelImpl declarator) is missing
// from this extracted listing; the enclosing class is not visible here.
159 // deduces whether C is const or non-const
160 template <typename C, typename TensorTy>
162  C& c,
163  const std::string& id,
164  const std::vector<TensorTy>& inputs,
165  const std::vector<TensorTy>& outputs)
166  -> decltype(c.searchKernel(id, inputs, outputs)) {
167  auto gpuStr = CudaGPUInfo::GPUInfo().GetCudaDeviceStr();
168  auto it = std::find_if(
// NOTE(review): the lambda parameter `c` shadows the cache parameter `c`.
169  c.entries_.begin(), c.entries_.end(), [&](const CachedEntry& c) {
// Makes the tc::operator== overloads visible for the key comparisons below.
170  using tc::operator==;
171  return id == c.key.id && inputs == c.key.inputs &&
172  outputs == c.key.outputs && gpuStr == c.key.deviceStr;
173  });
174  if (it != c.entries_.end()) {
// A version mismatch only warns; the stale entry is still handed back.
176  if (it->key.gitVersion != tc::git_version) {
177  std::cerr << "Proto version doesn't match. TC git version is: "
178  << tc::git_version
179  << " and Proto version is: " << it->key.gitVersion
180  << " .This proto might be incompatible"
181  << " with your TC binary and can break. Please autotune"
182  << " against the correct TC version." << std::endl;
184  }
185  return &*it;
186  }
187  return nullptr;
188 }
189 
190 } // namespace tc
static CudaGPUInfo & GPUInfo()
static bool cacheEnabled()
Definition: compilation_cache-inl.h:80
static void dumpCacheToProtobuf(const std::string &filename)
Definition: compilation_cache-inl.h:48
static auto searchKernelImpl(C &c, const std::string &id, const std::vector< InputTy > &inputs, const std::vector< InputTy > &outputs) -> decltype(c.searchKernel(id, inputs, outputs))
static auto searchKernelImpl(C &c, const std::string &id, const std::vector< const DLTensor * > &inputs, const std::vector< const DLTensor * > &outputs) -> decltype(c.searchKernel(id, inputs, outputs))
Definition: compilation_cache-inl.h:131
static void loadCacheFromProtobuf(const std::string &filename)
Definition: compilation_cache-inl.h:60
size_t size() const
Definition: compilation_cache-inl.h:85
static void disableCache()
Definition: compilation_cache-inl.h:34
static void enableCache()
Definition: compilation_cache-inl.h:29
static auto searchKernelImpl(C &c, const std::string &id, const MappingOptions &options, const std::vector< TensorTy > &inputs, const std::vector< TensorTy > &outputs) -> decltype(c.searchKernel(id, options, inputs, outputs))
Definition: mapping_options.h:336
Definition: compilation_cache.h:383
static std::shared_ptr< CC > getCache()
Definition: compilation_cache-inl.h:39
Definition: compilation_cache.h:123
std::string GetCudaDeviceStr() const
void clear()
Definition: compilation_cache-inl.h:91
Definition: compilation_cache.h:247