23 #include <glog/logging.h> 
   28 template <
typename CC>
 
   30   CC::getGlobalSharedCache() = std::make_shared<CC>();
 
   33 template <
typename CC>
 
   35   CC::getGlobalSharedCache() = 
nullptr;
 
   38 template <
typename CC>
 
   40   if (not cacheEnabled()) {
 
   41     throw std::runtime_error(
 
   42         "EnableCache or LoadCacheFromProtobuf must be called before using the cache.");
 
   44   return CC::getGlobalSharedCache();
 
   47 template <
typename CC>
 
   49   std::fstream serialized(
 
   50       filename, std::ios::binary | std::ios::trunc | std::ios::out);
 
   52     LOG(ERROR) << 
"Failed to open the output stream for dumping protobuf: " 
   55     getCache()->toProtobuf().SerializePartialToOstream(&serialized);
 
   59 template <
typename CC>
 
   61   typename CC::Protobuf buf;
 
   62   struct stat buffer = {0};
 
   63   if (stat(filename.c_str(), &buffer) == 0) {
 
   64     std::ifstream serialized(filename, std::ios::binary);
 
   65     buf.ParseFromIstream(&serialized);
 
   67   loadCacheFromProtobuf(buf);
 
   70 template <
typename CC>
 
   71 template <
typename Protobuf>
 
   74       std::is_same<Protobuf, typename CC::Protobuf>::value,
 
   75       "LoadCacheFromProtobuf called with invalide protobuf type.");
 
   76   CC::getGlobalSharedCache() = std::make_shared<CC>(buf);
 
   79 template <
typename CC>
 
   81   return CC::getGlobalSharedCache() != 
nullptr;
 
   84 template <
typename CC>
 
   86   std::lock_guard<std::mutex> lock(mtx_);
 
   87   return static_cast<const CC*
>(
this)->entries_.size();
 
   90 template <
typename CC>
 
   92   std::lock_guard<std::mutex> lock(mtx_);
 
   93   numberAttemptedRetrievals = numberSuccessfulRetrievals = numberCacheAttemps =
 
   95   static_cast<CC*
>(
this)->entries_.clear();
 
   98 template <
typename C, 
typename InputTy> 
 
  102     const std::string& 
id,
 
  104     const std::vector<InputTy>& inputs,
 
  105     const std::vector<InputTy>& outputs)
 
  106     -> decltype(c.searchKernel(
id, options, inputs, outputs)) {
 
  108   auto it = std::find_if(
 
  109       c.entries_.begin(), c.entries_.end(), [&](
const CachedEntry& c) {
 
  110         using tc::operator==;
 
  111         return id == c.key.id && options == c.key.mappingOptions &&
 
  112             inputs == c.key.inputs && outputs == c.key.outputs &&
 
  113             gpuStr == c.key.deviceStr;
 
  115   if (it != c.entries_.end()) {
 
  116     if (it->key.gitVersion != tc::git_version) {
 
  117       std::cerr << 
"Proto version doesn't match. TC git version is: " 
  119                 << 
" and Proto version is: " << it->key.gitVersion
 
  120                 << 
" .This proto might be incompatible" 
  121                 << 
" with your TC binary and can break. Please autotune" 
  122                 << 
" against the correct TC version." << std::endl;
 
  130 template <
typename C>
 
  133     const std::string& 
id,
 
  134     const std::vector<const DLTensor*>& inputs,
 
  135     const std::vector<const DLTensor*>& outputs)
 
  136     -> decltype(c.searchKernel(
id, inputs, outputs)) {
 
  138   auto it = std::find_if(
 
  139       c.entries_.begin(), c.entries_.end(), [&](
const CachedEntry& c) {
 
  140         using tc::operator==;
 
  141         return id == c.key.id && inputs == c.key.inputs &&
 
  142             outputs == c.key.outputs && gpuStr == c.key.deviceStr;
 
  144   if (it != c.entries_.end()) {
 
  145     if (it->key.gitVersion != tc::git_version) {
 
  146       std::cerr << 
"Proto version doesn't match. TC git version is: " 
  148                 << 
" and Proto version is: " << it->key.gitVersion
 
  149                 << 
" .This proto might be incompatible" 
  150                 << 
" with your TC binary and can break. Please autotune" 
  151                 << 
" against the correct TC version." << std::endl;
 
  160 template <
typename C, 
typename TensorTy>
 
  163     const std::string& 
id,
 
  164     const std::vector<TensorTy>& inputs,
 
  165     const std::vector<TensorTy>& outputs)
 
  166     -> decltype(c.searchKernel(
id, inputs, outputs)) {
 
  168   auto it = std::find_if(
 
  169       c.entries_.begin(), c.entries_.end(), [&](
const CachedEntry& c) {
 
  170         using tc::operator==;
 
  171         return id == c.key.id && inputs == c.key.inputs &&
 
  172             outputs == c.key.outputs && gpuStr == c.key.deviceStr;
 
  174   if (it != c.entries_.end()) {
 
  175     std::cout << 
"RETURNING IT: " << it->key.gitVersion << std::endl;
 
  176     if (it->key.gitVersion != tc::git_version) {
 
  177       std::cerr << 
"Proto version doesn't match. TC git version is: " 
  179                 << 
" and Proto version is: " << it->key.gitVersion
 
  180                 << 
" .This proto might be incompatible" 
  181                 << 
" with your TC binary and can break. Please autotune" 
  182                 << 
" against the correct TC version." << std::endl;
 
static CudaGPUInfo & GPUInfo()
 
static bool cacheEnabled()
Definition: compilation_cache-inl.h:80
 
static void dumpCacheToProtobuf(const std::string &filename)
Definition: compilation_cache-inl.h:48
 
static auto searchKernelImpl(C &c, const std::string &id, const std::vector< InputTy > &inputs, const std::vector< InputTy > &outputs) -> decltype(c.searchKernel(id, inputs, outputs))
 
static auto searchKernelImpl(C &c, const std::string &id, const std::vector< const DLTensor * > &inputs, const std::vector< const DLTensor * > &outputs) -> decltype(c.searchKernel(id, inputs, outputs))
Definition: compilation_cache-inl.h:131
 
static void loadCacheFromProtobuf(const std::string &filename)
Definition: compilation_cache-inl.h:60
 
size_t size() const 
Definition: compilation_cache-inl.h:85
 
static void disableCache()
Definition: compilation_cache-inl.h:34
 
static void enableCache()
Definition: compilation_cache-inl.h:29
 
static auto searchKernelImpl(C &c, const std::string &id, const MappingOptions &options, const std::vector< TensorTy > &inputs, const std::vector< TensorTy > &outputs) -> decltype(c.searchKernel(id, options, inputs, outputs))
 
Definition: mapping_options.h:336
 
Definition: compilation_cache.h:383
 
static std::shared_ptr< CC > getCache()
Definition: compilation_cache-inl.h:39
 
Definition: compilation_cache.h:123
 
std::string GetCudaDeviceStr() const 
 
void clear()
Definition: compilation_cache-inl.h:91
 
Definition: compilation_cache.h:247