23 #include <glog/logging.h>
28 template <
typename CC>
30 CC::getGlobalSharedCache() = std::make_shared<CC>();
33 template <
typename CC>
35 CC::getGlobalSharedCache() =
nullptr;
38 template <
typename CC>
40 if (not cacheEnabled()) {
41 throw std::runtime_error(
42 "EnableCache or LoadCacheFromProtobuf must be called before using the cache.");
44 return CC::getGlobalSharedCache();
47 template <
typename CC>
49 std::fstream serialized(
50 filename, std::ios::binary | std::ios::trunc | std::ios::out);
52 LOG(ERROR) <<
"Failed to open the output stream for dumping protobuf: "
55 getCache()->toProtobuf().SerializePartialToOstream(&serialized);
59 template <
typename CC>
61 typename CC::Protobuf buf;
62 struct stat buffer = {0};
63 if (stat(filename.c_str(), &buffer) == 0) {
64 std::ifstream serialized(filename, std::ios::binary);
65 buf.ParseFromIstream(&serialized);
67 loadCacheFromProtobuf(buf);
70 template <
typename CC>
71 template <
typename Protobuf>
74 std::is_same<Protobuf, typename CC::Protobuf>::value,
75 "LoadCacheFromProtobuf called with invalide protobuf type.");
76 CC::getGlobalSharedCache() = std::make_shared<CC>(buf);
79 template <
typename CC>
81 return CC::getGlobalSharedCache() !=
nullptr;
84 template <
typename CC>
86 std::lock_guard<std::mutex> lock(mtx_);
87 return static_cast<const CC*
>(
this)->entries_.size();
90 template <
typename CC>
92 std::lock_guard<std::mutex> lock(mtx_);
93 numberAttemptedRetrievals = numberSuccessfulRetrievals = numberCacheAttemps =
95 static_cast<CC*
>(
this)->entries_.clear();
98 template <
typename C,
typename InputTy>
102 const std::string&
id,
104 const std::vector<InputTy>& inputs,
105 const std::vector<InputTy>& outputs)
106 -> decltype(c.searchKernel(
id, options, inputs, outputs)) {
108 auto it = std::find_if(
109 c.entries_.begin(), c.entries_.end(), [&](
const CachedEntry& c) {
110 using tc::operator==;
111 return id == c.key.id && options == c.key.mappingOptions &&
112 inputs == c.key.inputs && outputs == c.key.outputs &&
113 gpuStr == c.key.deviceStr;
115 if (it != c.entries_.end()) {
116 if (it->key.gitVersion != tc::git_version) {
117 std::cerr <<
"Proto version doesn't match. TC git version is: "
119 <<
" and Proto version is: " << it->key.gitVersion
120 <<
" .This proto might be incompatible"
121 <<
" with your TC binary and can break. Please autotune"
122 <<
" against the correct TC version." << std::endl;
130 template <
typename C>
133 const std::string&
id,
134 const std::vector<const DLTensor*>& inputs,
135 const std::vector<const DLTensor*>& outputs)
136 -> decltype(c.searchKernel(
id, inputs, outputs)) {
138 auto it = std::find_if(
139 c.entries_.begin(), c.entries_.end(), [&](
const CachedEntry& c) {
140 using tc::operator==;
141 return id == c.key.id && inputs == c.key.inputs &&
142 outputs == c.key.outputs && gpuStr == c.key.deviceStr;
144 if (it != c.entries_.end()) {
145 if (it->key.gitVersion != tc::git_version) {
146 std::cerr <<
"Proto version doesn't match. TC git version is: "
148 <<
" and Proto version is: " << it->key.gitVersion
149 <<
" .This proto might be incompatible"
150 <<
" with your TC binary and can break. Please autotune"
151 <<
" against the correct TC version." << std::endl;
160 template <
typename C,
typename TensorTy>
163 const std::string&
id,
164 const std::vector<TensorTy>& inputs,
165 const std::vector<TensorTy>& outputs)
166 -> decltype(c.searchKernel(
id, inputs, outputs)) {
168 auto it = std::find_if(
169 c.entries_.begin(), c.entries_.end(), [&](
const CachedEntry& c) {
170 using tc::operator==;
171 return id == c.key.id && inputs == c.key.inputs &&
172 outputs == c.key.outputs && gpuStr == c.key.deviceStr;
174 if (it != c.entries_.end()) {
175 std::cout <<
"RETURNING IT: " << it->key.gitVersion << std::endl;
176 if (it->key.gitVersion != tc::git_version) {
177 std::cerr <<
"Proto version doesn't match. TC git version is: "
179 <<
" and Proto version is: " << it->key.gitVersion
180 <<
" .This proto might be incompatible"
181 <<
" with your TC binary and can break. Please autotune"
182 <<
" against the correct TC version." << std::endl;
static CudaGPUInfo & GPUInfo()
static bool cacheEnabled()
Definition: compilation_cache-inl.h:80
static void dumpCacheToProtobuf(const std::string &filename)
Definition: compilation_cache-inl.h:48
static auto searchKernelImpl(C &c, const std::string &id, const std::vector< InputTy > &inputs, const std::vector< InputTy > &outputs) -> decltype(c.searchKernel(id, inputs, outputs))
static auto searchKernelImpl(C &c, const std::string &id, const std::vector< const DLTensor * > &inputs, const std::vector< const DLTensor * > &outputs) -> decltype(c.searchKernel(id, inputs, outputs))
Definition: compilation_cache-inl.h:131
static void loadCacheFromProtobuf(const std::string &filename)
Definition: compilation_cache-inl.h:60
size_t size() const
Definition: compilation_cache-inl.h:85
static void disableCache()
Definition: compilation_cache-inl.h:34
static void enableCache()
Definition: compilation_cache-inl.h:29
static auto searchKernelImpl(C &c, const std::string &id, const MappingOptions &options, const std::vector< TensorTy > &inputs, const std::vector< TensorTy > &outputs) -> decltype(c.searchKernel(id, options, inputs, outputs))
Definition: mapping_options.h:336
Definition: compilation_cache.h:383
static std::shared_ptr< CC > getCache()
Definition: compilation_cache-inl.h:39
Definition: compilation_cache.h:123
std::string GetCudaDeviceStr() const
void clear()
Definition: compilation_cache-inl.h:91
Definition: compilation_cache.h:247