Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
aten_compiler.h
Go to the documentation of this file.
1 
16 #pragma once
17 
18 #include <chrono>
19 #include <string>
20 #include <vector>
21 
22 #include <ATen/ATen.h>
23 #include <ATen/DLConvertor.h>
24 
27 #include "tc/lang/parser.h"
28 
29 namespace tc {
30 
33 
34 class ATenCompilationUnit {
35  public:
36  explicit ATenCompilationUnit();
37 
41  void define(const std::string& language);
42 
44  // TODO: Pass struct to allow autotuning
45  size_t compile(
46  const std::string& name,
47  const std::vector<at::Tensor>& inputs,
48  const MappingOptions& options);
49 
51  std::vector<const DLTensor*> inferOutputTensorInfo(
52  const std::string& name,
53  const std::vector<at::Tensor>& inputs);
54 
58  Duration run(
59  const std::string& name,
60  const std::vector<at::Tensor>& inputs,
61  std::vector<at::Tensor>& outputs,
62  size_t handle,
63  bool profile = false);
64 
68  void uncheckedRun(
69  const std::vector<at::Tensor>& inputs,
70  std::vector<at::Tensor>& outputs,
71  size_t handle);
72 
73  private:
74  std::unique_ptr<ExecutionEngine> executionEngine_;
75 };
76 
77 std::pair<std::vector<DLTensor*>, std::vector<DLManagedTensor*>>
78 toDlpackTensors(const std::vector<at::Tensor>& tensors);
79 
80 std::pair<std::vector<const DLTensor*>, std::vector<DLManagedTensor*>>
81 toConstDlpackTensors(const std::vector<at::Tensor>& tensors);
82 
83 void deleteDlmTensors(std::vector<DLManagedTensor*>& tensors);
84 
85 } // namespace tc
std::pair< std::vector< const DLTensor * >, std::vector< DLManagedTensor * > > toConstDlpackTensors(const std::vector< at::Tensor > &tensors)
std::unique_ptr< ExecutionEngine > executionEngine_
Definition: aten_compiler.h:74
tc::ATenCompilationUnit
Definition: aten_compiler.h:34
std::vector< const DLTensor * > inferOutputTensorInfo(const std::string &name, const std::vector< at::Tensor > &inputs)
Get the output Tensor info.
tc::MappingOptions
Definition: mapping_options.h:336
void uncheckedRun(const std::vector< at::Tensor > &inputs, std::vector< at::Tensor > &outputs, size_t handle)
std::chrono::high_resolution_clock::duration Duration
Definition: rtc.h:31
Duration run(const std::string &name, const std::vector< at::Tensor > &inputs, std::vector< at::Tensor > &outputs, size_t handle, bool profile=false)
size_t compile(const std::string &name, const std::vector< at::Tensor > &inputs, const MappingOptions &options)
Given a TC name, compile the TC.
void define(const std::string &language)
void deleteDlmTensors(std::vector< DLManagedTensor * > &tensors)
std::pair< std::vector< DLTensor * >, std::vector< DLManagedTensor * > > toDlpackTensors(const std::vector< at::Tensor > &tensors)