Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
dlpack.h
Go to the documentation of this file.
1 
16 #pragma once
17 
18 #include <iostream>
19 #include <memory>
20 #include <sstream>
21 #include <string>
22 #include <vector>
23 
24 #include <glog/logging.h>
25 
26 #include <dlpack/dlpack.h>
27 
28 //
29 // Various utilities for DLPack, in particular DLTensor.
30 //
31 
32 namespace tc {
33 namespace dlutils {
34 
35 DLContext getCPUDLContext();
36 DLContext getGPUDLContext(int device_id = 0);
37 template <typename T>
38 DLDataType getDLDataType();
39 
40 struct DLTensorDeleter {
41  inline void operator()(DLTensor* t) {
42  if (t->shape) {
43  delete[] t->shape;
44  }
45  if (t->strides) {
46  delete[] t->strides;
47  }
48  delete t;
49  }
50 };
51 typedef std::shared_ptr<DLTensor> DLTensorSPtr;
52 typedef std::unique_ptr<DLTensor, DLTensorDeleter> DLTensorUPtr;
53 
54 void SetStridesFromSizes(DLTensor& t, const std::vector<int64_t>&);
55 void SetSizes(DLTensor& t, const std::vector<int64_t>& sizes);
56 void SetStrides(DLTensor& t, const std::vector<int64_t>& strides);
57 DLTensorUPtr makeDLTensorWithSizes(
58  DLContext ctx,
59  DLDataType dtype,
60  const std::vector<int64_t>& sizes);
61 
62 std::vector<const DLTensor*> extractRawPtrs(
63  const std::vector<DLTensorUPtr>& uptrs);
64 std::vector<const DLTensor*> constPtrs(const std::vector<DLTensor*>& ptrs);
65 
66 // Deep copies
67 DLTensorUPtr makeDLTensor(const DLTensor* ptr);
68 
69 template <typename T>
70 std::vector<DLTensorUPtr> makeDLTensorVector(const std::vector<T*>& ptrs);
71 
72 bool operator==(const DLDataType& t1, const DLDataType& t2);
73 std::string toString(const DLDataType& t);
74 std::ostream& operator<<(std::ostream& os, const DLTensor& t);
75 std::ostream& operator<<(std::ostream& os, const DLDataType& t);
76 
77 // Shape/stride/type-only comparisons
78 bool compareDLTensorMetadata(const DLTensor& t1, const DLTensor& t2);
79 template <typename T, typename TT>
80 bool compareDLTensorVectorMetadata(
81  const std::vector<T*>& v1,
82  const std::vector<TT*>& v2);
83 } // namespace dlutils
84 } // namespace tc
85 
void SetSizes(DLTensor &t, const std::vector< int64_t > &sizes)
Definition: dlpack-inl.h:96
bool operator==(const DLDataType &t1, const DLDataType &t2)
Definition: dlpack-inl.h:210
std::vector< const DLTensor * > extractRawPtrs(const std::vector< DLTensorUPtr > &uptrs)
Definition: dlpack-inl.h:136
bool compareDLTensorMetadata(const DLTensor &t1, const DLTensor &t2)
Definition: dlpack-inl.h:219
struct DLTensorDeleter
Definition: dlpack.h:40
std::shared_ptr< DLTensor > DLTensorSPtr
Definition: dlpack.h:51
std::ostream & operator<<(std::ostream &os, const DLDataType &t)
Definition: dlpack-inl.h:214
std::string toString(const DLDataType &t)
Definition: dlpack-inl.h:21
bool compareDLTensorVectorMetadata(const std::vector< T * > &v1, const std::vector< TT * > &v2)
Definition: dlpack-inl.h:242
std::vector< const DLTensor * > constPtrs(const std::vector< DLTensor * > &ptrs)
Definition: dlpack-inl.h:146
void operator()(DLTensor *t)
Definition: dlpack.h:41
DLTensorUPtr makeDLTensorWithSizes(DLContext ctx, DLDataType dtype, const std::vector< int64_t > &sizes)
Definition: dlpack-inl.h:118
DLContext getCPUDLContext()
Definition: dlpack-inl.h:59
DLDataType getDLDataType()
std::unique_ptr< DLTensor, DLTensorDeleter > DLTensorUPtr
Definition: dlpack.h:52
void SetStridesFromSizes(DLTensor &t, const std::vector< int64_t > &)
Definition: dlpack-inl.h:110
DLContext getGPUDLContext(int device_id)
Definition: dlpack-inl.h:68
std::vector< DLTensorUPtr > makeDLTensorVector(const std::vector< T * > &ptrs)
Definition: dlpack-inl.h:181
DLTensorUPtr makeDLTensor(const DLTensor *ptr)
Definition: dlpack-inl.h:156
void SetStrides(DLTensor &t, const std::vector< int64_t > &strides)
Definition: dlpack-inl.h:103