Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
dlpack_c2.h
Go to the documentation of this file.
1 
16 #pragma once
17 
#include <functional>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>
22 
23 #include "tc/c2/context.h"
24 #include "tc/core/utils/dlpack.h"
25 
26 #include "caffe2/core/common.h"
27 
28 namespace caffe2 {
29 namespace dlpack {
30 
31 template <typename C2Context>
32 DLContext getDLContext();
33 
34 template <>
35 inline DLContext getDLContext<CPUContext>() {
37 }
38 
39 // Can't have a CUDAContext object, how do we get the GPU id of a C2 Tensor?
40 template <>
41 inline DLContext getDLContext<CUDAContext>() {
42  return tc::dlutils::getGPUDLContext(0 /*ctx ? ctx->cuda_gpu_id() : 0*/);
43 }
44 
45 inline DLDataType getDLDataType(const TypeMeta& meta) {
46  DLDataType res;
47  if (meta.Match<float>()) {
48  res.code = DLDataTypeCode::kDLFloat;
49  } else if (meta.Match<int>()) {
50  res.code = DLDataTypeCode::kDLInt;
51  } else {
52  CHECK(false) << "NYI: getDLDataType(caffe2::Meta::Make<" << meta.name()
53  << ">))";
54  }
55  res.bits = 32;
56  res.lanes = 1;
57  return res;
58 }
59 
60 template <typename C2Context>
62  const caffe2::Tensor<C2Context>& tensor,
63  const vector<TIndex>& shapeOverride = {}) {
64  const auto& dims = shapeOverride.empty() ? tensor.dims() : shapeOverride;
65  if (!shapeOverride.empty()) {
66  auto overrideSize = std::accumulate(
67  dims.begin(),
68  dims.end(),
69  static_cast<TIndex>(1),
70  std::multiplies<TIndex>());
71  CAFFE_ENFORCE_EQ(overrideSize, tensor.size());
72  }
73  tc::dlutils::DLTensorUPtr res(new DLTensor);
74  res->data = const_cast<void*>(tensor.raw_data());
75  res->ctx = getDLContext<C2Context>();
76  auto ndim = dims.size();
77  res->ndim = ndim;
78  res->dtype = getDLDataType(tensor.meta());
79  res->shape = new int64_t[ndim];
80  tc::dlutils::SetSizes(*res, dims);
81  res->strides = new int64_t[ndim];
82  tc::dlutils::SetStridesFromSizes(*res, tensor.dims());
83  res->byte_offset = 0;
84  return res;
85 }
86 
87 template <typename C2Context>
88 tc::dlutils::DLTensorUPtr makeDLTensor(caffe2::Tensor<C2Context>* tensor) {
89  auto res = makeConstDLTensor(*tensor);
90  res->data = tensor->raw_mutable_data();
91  return res;
92 }
93 
94 } // namespace dlpack
95 } // namespace caffe2
void SetSizes(DLTensor &t, const std::vector< int64_t > &sizes)
Definition: dlpack-inl.h:96
DLContext getDLContext< CUDAContext >()
Definition: dlpack_c2.h:41
tc::dlutils::DLTensorUPtr makeConstDLTensor(const caffe2::Tensor< C2Context > &tensor, const vector< TIndex > &shapeOverride={})
Definition: dlpack_c2.h:61
DLContext getCPUDLContext()
Definition: dlpack-inl.h:59
std::unique_ptr< DLTensor, DLTensorDeleter > DLTensorUPtr
Definition: dlpack.h:52
void SetStridesFromSizes(DLTensor &t, const std::vector< int64_t > &)
Definition: dlpack-inl.h:110
tc::dlutils::DLTensorUPtr makeDLTensor(caffe2::Tensor< C2Context > *tensor)
Definition: dlpack_c2.h:88
DLContext getGPUDLContext(int device_id)
Definition: dlpack-inl.h:68
DLDataType getDLDataType(const TypeMeta &meta)
Definition: dlpack_c2.h:45
DLContext getDLContext()
DLContext getDLContext< CPUContext >()
Definition: dlpack_c2.h:35