26 #include "caffe2/core/common.h"
31 template <
typename C2Context>
47 if (meta.Match<
float>()) {
48 res.code = DLDataTypeCode::kDLFloat;
49 }
else if (meta.Match<
int>()) {
50 res.code = DLDataTypeCode::kDLInt;
52 CHECK(
false) <<
"NYI: getDLDataType(caffe2::Meta::Make<" << meta.name()
60 template <
typename C2Context>
62 const caffe2::Tensor<C2Context>& tensor,
63 const vector<TIndex>& shapeOverride = {}) {
64 const auto& dims = shapeOverride.empty() ? tensor.dims() : shapeOverride;
65 if (!shapeOverride.empty()) {
66 auto overrideSize = std::accumulate(
69 static_cast<TIndex
>(1),
70 std::multiplies<TIndex>());
71 CAFFE_ENFORCE_EQ(overrideSize, tensor.size());
74 res->data =
const_cast<void*
>(tensor.raw_data());
75 res->ctx = getDLContext<C2Context>();
76 auto ndim = dims.size();
79 res->shape =
new int64_t[ndim];
81 res->strides =
new int64_t[ndim];
87 template <
typename C2Context>
90 res->data = tensor->raw_mutable_data();
void SetSizes(DLTensor &t, const std::vector< int64_t > &sizes)
Definition: dlpack-inl.h:96
DLContext getDLContext< CUDAContext >()
Definition: dlpack_c2.h:41
tc::dlutils::DLTensorUPtr makeConstDLTensor(const caffe2::Tensor< C2Context > &tensor, const vector< TIndex > &shapeOverride={})
Definition: dlpack_c2.h:61
DLContext getCPUDLContext()
Definition: dlpack-inl.h:59
std::unique_ptr< DLTensor, DLTensorDeleter > DLTensorUPtr
Definition: dlpack.h:52
void SetStridesFromSizes(DLTensor &t, const std::vector< int64_t > &)
Definition: dlpack-inl.h:110
tc::dlutils::DLTensorUPtr makeDLTensor(caffe2::Tensor< C2Context > *tensor)
Definition: dlpack_c2.h:88
DLContext getGPUDLContext(int device_id)
Definition: dlpack-inl.h:68
DLDataType getDLDataType(const TypeMeta &meta)
Definition: dlpack_c2.h:45
DLContext getDLContext< CPUContext >()
Definition: dlpack_c2.h:35