From f6af76ead7f03b1e75a920d93c3d2d387f5eaef7 Mon Sep 17 00:00:00 2001
From: Roy Li
Date: Sun, 7 Apr 2019 01:35:11 -0700
Subject: [PATCH] Remove tensorFromBlob() from Type (#18779)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18779
ghimport-source-id: e7453b74fcce0e4f4a9cbce0324992a85272a426

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18780 Remove tensorWithAllocator() from Type
* **#18779 Remove tensorFromBlob() from Type**

Differential Revision: D14739335

fbshipit-source-id: 8a0619a5b412332efa3b2d60c1edebd53d089d50
---
 aten/src/ATen/DLConvertor.cpp                      | 17 +++++----
 aten/src/ATen/SparseTensorUtils.h                  |  6 ++--
 aten/src/ATen/TensorUtils.cpp                      | 25 +++++++++++++
 aten/src/ATen/TensorUtils.h                        |  7 +++-
 aten/src/ATen/UndefinedType.cpp                    |  3 --
 aten/src/ATen/UndefinedType.h                      |  1 -
 aten/src/ATen/core/Type.h                          |  3 --
 aten/src/ATen/native/Resize.h                      | 13 +------
 aten/src/ATen/native/cuda/TensorTransformations.cu |  9 +++--
 aten/src/ATen/templates/Functions.h                | 41 +++++++++++++++++++++-
 aten/src/ATen/templates/NativeFunctions.h          | 17 ---------
 aten/src/ATen/templates/Type.h                     |  3 --
 aten/src/ATen/templates/TypeDefault.cpp            | 41 ++--------------------
 aten/src/ATen/templates/TypeDefault.h              |  3 --
 aten/src/ATen/test/atest.cpp                       |  9 +++--
 caffe2/contrib/aten/aten_op_template.h             | 16 +++++----
 tools/autograd/templates/VariableType.h            |  1 -
 torch/csrc/autograd/VariableTypeManual.cpp         |  3 --
 torch/csrc/utils/tensor_numpy.cpp                  | 15 +++++---
 19 files changed, 116 insertions(+), 117 deletions(-)

diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp
index cf9daa5..8604ec5 100644
--- a/aten/src/ATen/DLConvertor.cpp
+++ b/aten/src/ATen/DLConvertor.cpp
@@ -67,21 +67,20 @@ static DLContext getDLContext(const Tensor& tensor, const int64_t& device_id) {
   return ctx;
 }
 
-static DeviceType getATenDeviceType(const DLContext& ctx) {
+static Device getATenDevice(const DLContext& ctx) {
   switch (ctx.device_type) {
     case DLDeviceType::kDLCPU:
-      return DeviceType::CPU;
+      return at::Device(DeviceType::CPU);
     case DLDeviceType::kDLGPU:
-      return DeviceType::CUDA;
+      return at::Device(DeviceType::CUDA, ctx.device_id);
     case DLDeviceType::kDLOpenCL:
-      return DeviceType::OPENCL;
+      return at::Device(DeviceType::OPENCL, ctx.device_id);
     case DLDeviceType::kDLROCM:
-      return DeviceType::HIP;
+      return at::Device(DeviceType::HIP, ctx.device_id);
     default:
       throw std::logic_error(
           "Unsupported device_type: " + std::to_string(ctx.device_type));
   }
-  return DeviceType::CPU; // impossible
 }
 
 ScalarType toScalarType(const DLDataType& dtype) {
@@ -173,7 +172,7 @@ DLManagedTensor* toDLPack(const Tensor& src) {
 }
 
 Tensor fromDLPack(const DLManagedTensor* src) {
-  DeviceType device_type = getATenDeviceType(src->dl_tensor.ctx);
+  Device device = getATenDevice(src->dl_tensor.ctx);
   ScalarType stype = toScalarType(src->dl_tensor.dtype);
   auto deleter = [src](void* self) {
     src->deleter(const_cast<DLManagedTensor*>(src));
@@ -182,7 +181,7 @@ Tensor fromDLPack(const DLManagedTensor* src) {
     return at::from_blob(src->dl_tensor.data,
         IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
         deleter,
-        at::device(device_type).dtype(stype));
+        at::device(device).dtype(stype));
   }
 
   return at::from_blob(
@@ -190,6 +189,6 @@ Tensor fromDLPack(const DLManagedTensor* src) {
       IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
       IntArrayRef(src->dl_tensor.strides, src->dl_tensor.ndim),
       deleter,
-      at::device(device_type).dtype(stype));
+      at::device(device).dtype(stype));
 }
 } // namespace at
diff --git a/aten/src/ATen/SparseTensorUtils.h b/aten/src/ATen/SparseTensorUtils.h
index 8113367..f602f09 100644
--- a/aten/src/ATen/SparseTensorUtils.h
+++ b/aten/src/ATen/SparseTensorUtils.h
@@ -85,8 +85,10 @@ inline LongTensor flatten_indices(const Tensor& indices, IntArrayRef full_size,
     indices_mult_cpu_vec[i] = mult;
     mult *= full_size[i];
   }
-  auto indices_mult_cpu = indices.dispatch_type().cpu()
-    .tensorFromBlob(indices_mult_cpu_vec.data(), /*size=*/{sparse_dim, 1});
+  auto indices_mult_cpu = at::from_blob(
+    indices_mult_cpu_vec.data(),
+    /*size=*/{sparse_dim, 1},
+    indices.options().device(kCPU));
   // NB: must be blocking because this blob may be freed after this closure,
   // and non_blocking copy will see garbage.
   auto indices_mult = indices_mult_cpu.to(indices.device(), /*non_blocking=*/false);
diff --git a/aten/src/ATen/TensorUtils.cpp b/aten/src/ATen/TensorUtils.cpp
index d89243b..200bc23 100644
--- a/aten/src/ATen/TensorUtils.cpp
+++ b/aten/src/ATen/TensorUtils.cpp
@@ -235,4 +235,29 @@ bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides) {
   return contig_if_nonempty;
 }
 
+namespace detail {
+
+std::vector<int64_t> defaultStrides(IntArrayRef sizes) {
+  std::vector<int64_t> strides(sizes.size());
+  int64_t stride = 1;
+  for(size_t i = sizes.size(); i > 0; --i) {
+    strides[i-1] = stride;
+    stride *= sizes[i-1];
+  }
+  return strides;
+}
+
+int64_t computeStorageSize(IntArrayRef sizes, IntArrayRef strides) {
+  // size of the underlying storage is 1 bigger than the offset
+  // of the last element according to stride
+  int64_t size = 1;
+  for(size_t i = 0; i < sizes.size(); i++) {
+    if(sizes[i] == 0) {
+      return 0;
+    }
+    size += strides[i]*(sizes[i]-1);
+  }
+  return size;
 }
+} // namespace detail
+} // namespace at
diff --git a/aten/src/ATen/TensorUtils.h b/aten/src/ATen/TensorUtils.h
index 92bf15d..2b548ae 100644
--- a/aten/src/ATen/TensorUtils.h
+++ b/aten/src/ATen/TensorUtils.h
@@ -125,4 +125,9 @@ CAFFE2_API void* maybe_data_ptr(const TensorArg& tensor);
 // constructing a tensor, e.g., when you want to choose a kernel strategy based
 // on whether a subgeometry is contiguous.
 CAFFE2_API bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides);
-}
+
+namespace detail {
+CAFFE2_API std::vector<int64_t> defaultStrides(IntArrayRef sizes);
+CAFFE2_API int64_t computeStorageSize(IntArrayRef sizes, IntArrayRef strides);
+} // namespace detail
+} // namespace at
diff --git a/aten/src/ATen/UndefinedType.cpp b/aten/src/ATen/UndefinedType.cpp
index c608c43..2c780e7 100644
--- a/aten/src/ATen/UndefinedType.cpp
+++ b/aten/src/ATen/UndefinedType.cpp
@@ -23,9 +23,6 @@ Device UndefinedType::getDeviceFromPtr(void*) const {
   AT_ERROR("getDeviceFromPtr not defined for UndefinedType");
 }
 
-Storage UndefinedType::storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const {
-  AT_ERROR("storageFromBlob not defined for UndefinedType");
-}
 Storage UndefinedType::unsafeStorageFromTH(void * th_pointer, bool retain) const {
   AT_ERROR("unsafeStorageFromTH not defined for UndefinedType");
 }
diff --git a/aten/src/ATen/UndefinedType.h b/aten/src/ATen/UndefinedType.h
index 095b6fe..ff89b4a 100644
--- a/aten/src/ATen/UndefinedType.h
+++ b/aten/src/ATen/UndefinedType.h
@@ -18,7 +18,6 @@ struct UndefinedType final : public TypeDefault {
   virtual Backend backend() const override;
   virtual Allocator* allocator() const override;
   virtual Device getDeviceFromPtr(void* data) const override;
-  virtual Storage storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const override;
   virtual Storage storageWithAllocator(int64_t size, Allocator* allocator) const override;
   virtual std::unique_ptr<Generator> generator() const override;
   virtual const char * toString() const override;
diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h
index 0eea7bb..2b1dd0a 100644
--- a/aten/src/ATen/core/Type.h
+++ b/aten/src/ATen/core/Type.h
@@ -128,7 +128,6 @@ struct CAFFE2_API Type {
   bool is_undefined() const noexcept { return is_undefined_; }
   virtual Allocator * allocator() const = 0;
   virtual Device getDeviceFromPtr(void * data) const = 0;
-  virtual Storage storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
   virtual Storage storageWithAllocator(int64_t size, Allocator* allocator) const = 0;
   virtual std::unique_ptr<Generator> generator() const = 0;
   virtual Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const = 0;
@@ -176,8 +175,6 @@ struct CAFFE2_API Type {
       bool create_graph) const = 0;
   virtual void set_data(Tensor & self, Tensor new_data) const = 0;
 
-  virtual Tensor tensorFromBlob(void * data, IntArrayRef sizes, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
-  virtual Tensor tensorFromBlob(void * data, IntArrayRef sizes, IntArrayRef strides, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
   virtual Tensor tensorWithAllocator(IntArrayRef sizes, Allocator* allocator) const = 0;
   virtual Tensor tensorWithAllocator(IntArrayRef sizes, IntArrayRef strides, Allocator* allocator) const = 0;
 
diff --git a/aten/src/ATen/native/Resize.h b/aten/src/ATen/native/Resize.h
index e39c28d..27ceb5f 100644
--- a/aten/src/ATen/native/Resize.h
+++ b/aten/src/ATen/native/Resize.h
@@ -52,23 +52,12 @@ inline TensorImpl* resize_impl_cpu_(
   return self;
 }
 
-static inline int64_t computeStorageSize(IntArrayRef sizes, IntArrayRef strides) {
-  int64_t storage_size = 1;
-  for (size_t dim = 0; dim < sizes.size(); ++dim) {
-    if (sizes[dim] == 0) {
-      return 0;
-    }
-    storage_size += strides[dim] * (sizes[dim] - 1);
-  }
-  return storage_size;
-}
-
 static inline void checkInBoundsForStorage(
     IntArrayRef size,
     IntArrayRef stride,
     int64_t storage_offset,
     const Storage& new_storage) {
-  int64_t storage_size = computeStorageSize(size, stride);
+  int64_t storage_size = detail::computeStorageSize(size, stride);
   if (storage_size == 0) {
     // NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
     return;
diff --git a/aten/src/ATen/native/cuda/TensorTransformations.cu b/aten/src/ATen/native/cuda/TensorTransformations.cu
index d0bf923..d078634 100644
--- a/aten/src/ATen/native/cuda/TensorTransformations.cu
+++ b/aten/src/ATen/native/cuda/TensorTransformations.cu
@@ -99,13 +99,16 @@ Tensor flip_cuda(const Tensor& self, IntArrayRef dims) {
     return out_tensor;
   }
 
-  auto flip_dims_t = at::CPU(kLong).tensorFromBlob(flip_dims.data(), {static_cast<int64_t>(flip_dims.size())});
+  auto flip_dims_t = at::from_blob(
+    flip_dims.data(), {static_cast<int64_t>(flip_dims.size())}, at::device(kCPU).dtype(kLong));
 
   auto shape = in_tensor.sizes().vec();
-  auto shape_t = at::CPU(kLong).tensorFromBlob(shape.data(), {static_cast<int64_t>(shape.size())});
+  auto shape_t = at::from_blob(
+    shape.data(), {static_cast<int64_t>(shape.size())}, at::device(kCPU).dtype(kLong));
 
   auto strides = in_tensor.strides().vec();
-  auto strides_t = at::CPU(kLong).tensorFromBlob(strides.data(), {static_cast<int64_t>(strides.size())});
+  auto strides_t = at::from_blob(
+    strides.data(), {static_cast<int64_t>(strides.size())}, at::device(kCPU).dtype(kLong));
 
   // stride_contiguous is the stride of non-contiguous tensor after calling contiguous(),
   // it is used to compute indices for each element in non-contiguous tensor
diff --git a/aten/src/ATen/templates/Functions.h b/aten/src/ATen/templates/Functions.h
index a053284..47477ba 100644
--- a/aten/src/ATen/templates/Functions.h
+++ b/aten/src/ATen/templates/Functions.h
@@ -14,14 +14,53 @@
 #include
 #include
 #include
+#include
 
 namespace at {
 
-using native::from_blob;
 using native::tensor;
 
 ${function_declarations}
 
+inline Tensor from_blob(
+    void* data,
+    IntArrayRef sizes,
+    IntArrayRef strides,
+    const std::function<void(void*)>& deleter,
+    const TensorOptions& options = {}) {
+  auto storage = Storage(
+      options.dtype(),
+      detail::computeStorageSize(sizes, strides),
+      InefficientStdFunctionContext::makeDataPtr(
+          data, deleter, options.device()),
+      /*allocator=*/nullptr,
+      /*resizable=*/false);
+  return empty({0}, options).set_(storage, 0, sizes, strides);
+}
+
+inline Tensor from_blob(
+    void* data,
+    IntArrayRef sizes,
+    const std::function<void(void*)>& deleter,
+    const TensorOptions& options = {}) {
+  return from_blob(data, sizes, detail::defaultStrides(sizes), deleter, options);
+}
+
+inline Tensor from_blob(
+    void* data,
+    IntArrayRef sizes,
+    IntArrayRef strides,
+    const TensorOptions& options = {}) {
+  return from_blob(data, sizes, strides, [](void*) {}, options);
+}
+
+inline Tensor from_blob(
+    void* data,
+    IntArrayRef sizes,
+    const TensorOptions& options = {}) {
+  return from_blob(data, sizes, detail::defaultStrides(sizes), [](void*) {}, options);
+}
+
 namespace detail {
 
 static inline TypeExtendedInterface & infer_type(const Tensor & t) {
diff --git a/aten/src/ATen/templates/NativeFunctions.h b/aten/src/ATen/templates/NativeFunctions.h
index 5040465..e3fa315 100644
--- a/aten/src/ATen/templates/NativeFunctions.h
+++ b/aten/src/ATen/templates/NativeFunctions.h
@@ -25,23 +25,6 @@ struct Type;
 namespace at {
 namespace native {
 
-inline Tensor from_blob(
-    void* data,
-    IntArrayRef sizes,
-    const std::function<void(void*)>& deleter,
-    const TensorOptions& options = {}) {
-  return at::getType(options).tensorFromBlob(data, sizes, deleter);
-}
-
-inline Tensor from_blob(
-    void* data,
-    IntArrayRef sizes,
-    IntArrayRef strides,
-    const std::function<void(void*)>& deleter,
-    const TensorOptions& options = {}) {
-  return at::getType(options).tensorFromBlob(data, sizes, strides, deleter);
-}
-
 // These functions are defined in native/TensorFactories.cpp.
 #define TENSOR(T, S, _1) \
   CAFFE2_API Tensor tensor(ArrayRef<T> values, const TensorOptions& options); \
diff --git a/aten/src/ATen/templates/Type.h b/aten/src/ATen/templates/Type.h
index d1d2a28..87914fc 100644
--- a/aten/src/ATen/templates/Type.h
+++ b/aten/src/ATen/templates/Type.h
@@ -71,7 +71,6 @@ struct CAFFE2_API Type {
   bool is_undefined() const noexcept { return is_undefined_; }
   virtual Allocator * allocator() const = 0;
   virtual Device getDeviceFromPtr(void * data) const = 0;
-  virtual Storage storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
   virtual Storage storageWithAllocator(int64_t size, Allocator* allocator) const = 0;
   virtual std::unique_ptr<Generator> generator() const = 0;
   virtual Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const = 0;
@@ -119,8 +118,6 @@ struct CAFFE2_API Type {
       bool create_graph) const = 0;
   virtual void set_data(Tensor & self, Tensor new_data) const = 0;
 
-  virtual Tensor tensorFromBlob(void * data, IntArrayRef sizes, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
-  virtual Tensor tensorFromBlob(void * data, IntArrayRef sizes, IntArrayRef strides, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
   virtual Tensor tensorWithAllocator(IntArrayRef sizes, Allocator* allocator) const = 0;
   virtual Tensor tensorWithAllocator(IntArrayRef sizes, IntArrayRef strides, Allocator* allocator) const = 0;
 
diff --git a/aten/src/ATen/templates/TypeDefault.cpp b/aten/src/ATen/templates/TypeDefault.cpp
index e89b6f1..2711e5c 100644
--- a/aten/src/ATen/templates/TypeDefault.cpp
+++ b/aten/src/ATen/templates/TypeDefault.cpp
@@ -57,51 +57,14 @@ Type & TypeDefault::toBackend(Backend b) const {
 Type & TypeDefault::toScalarType(ScalarType s) const {
   return at::globalContext().getNonVariableType(backend(),s);
 }
-static std::vector<int64_t> defaultStrides(IntArrayRef sizes) {
-  std::vector<int64_t> strides(sizes.size());
-  int64_t stride = 1;
-  for(size_t i = sizes.size(); i > 0; --i) {
-    strides[i-1] = stride;
-    stride *= sizes[i-1];
-  }
-  return strides;
-}
-static int64_t computeStorageSize(IntArrayRef sizes, IntArrayRef strides) {
-  // size of the underlying storage is 1 bigger than the offset
-  // of the last element according to stride
-  int64_t size = 1;
-  for(size_t i = 0; i < sizes.size(); i++) {
-    if(sizes[i] == 0) {
-      return 0;
-    }
-    size += strides[i]*(sizes[i]-1);
-  }
-  return size;
-}
-Tensor TypeDefault::tensorFromBlob(void * data, IntArrayRef sizes, const std::function<void(void*)> & deleter) const {
-  return tensorFromBlob(data, sizes, defaultStrides(sizes), deleter);
-}
-Tensor TypeDefault::tensorFromBlob(void * data, IntArrayRef sizes, IntArrayRef strides, const std::function<void(void*)> & deleter) const {
-  auto storage = storageFromBlob(data, computeStorageSize(sizes, strides), deleter);
-  return at::empty({0}, options()).set_(storage, 0, sizes, strides);
-}
 Tensor TypeDefault::tensorWithAllocator(IntArrayRef sizes, Allocator* allocator) const {
-  return tensorWithAllocator(sizes, defaultStrides(sizes), std::move(allocator));
+  return tensorWithAllocator(sizes, detail::defaultStrides(sizes), std::move(allocator));
 }
 Tensor TypeDefault::tensorWithAllocator(IntArrayRef sizes, IntArrayRef strides, Allocator* allocator) const {
-  auto storage = storageWithAllocator(computeStorageSize(sizes, strides), std::move(allocator));
+  auto storage = storageWithAllocator(detail::computeStorageSize(sizes, strides), std::move(allocator));
   return at::empty({0}, options()).set_(storage, 0, sizes, strides);
 }
-Storage TypeDefault::storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const {
-  return Storage(
-      typeMeta(),
-      size,
-      InefficientStdFunctionContext::makeDataPtr(
-          data, deleter, getDeviceFromPtr(data)),
-      /*allocator=*/nullptr,
-      /*resizable=*/false);
-}
 Storage TypeDefault::storageWithAllocator(int64_t size, Allocator* allocator) const {
   // Potentially the storage might be marked as resizable too here
   return Storage(typeMeta(), size, allocator, /*resizable=*/false);
diff --git a/aten/src/ATen/templates/TypeDefault.h b/aten/src/ATen/templates/TypeDefault.h
index e8a1517..b597ace 100644
--- a/aten/src/ATen/templates/TypeDefault.h
+++ b/aten/src/ATen/templates/TypeDefault.h
@@ -38,12 +38,9 @@ struct CAFFE2_API TypeDefault : public TypeExtendedInterface {
       bool create_graph) const override;
   void set_data(Tensor & self, Tensor new_data) const override;
 
-  Tensor tensorFromBlob(void * data, IntArrayRef sizes, const std::function<void(void*)> & deleter=noop_deleter) const override;
-  Tensor tensorFromBlob(void * data, IntArrayRef sizes, IntArrayRef strides, const std::function<void(void*)> & deleter=noop_deleter) const override;
   Tensor tensorWithAllocator(IntArrayRef sizes, Allocator* allocator) const override;
   Tensor tensorWithAllocator(IntArrayRef sizes, IntArrayRef strides, Allocator* allocator) const override;
 
-  Storage storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const override;
   Storage storageWithAllocator(int64_t size, Allocator* allocator) const override;
   Storage unsafeStorageFromTH(void * th_pointer, bool retain) const override;
   Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const override;
diff --git a/aten/src/ATen/test/atest.cpp b/aten/src/ATen/test/atest.cpp
index 86d07d4..e681d78 100644
--- a/aten/src/ATen/test/atest.cpp
+++ b/aten/src/ATen/test/atest.cpp
@@ -53,7 +53,7 @@ TEST(atest, atest) {
 
   float data[] = {1, 2, 3, 4, 5, 6};
 
-  auto f = CPU(kFloat).tensorFromBlob(data, {1, 2, 3});
+  auto f = from_blob(data, {1, 2, 3});
   auto f_a = f.accessor<float, 3>();
 
   ASSERT_EQ(f_a[0][0][0], 1.0);
@@ -72,7 +72,7 @@ TEST(atest, atest) {
     int isgone = 0;
     {
       auto f2 =
-          CPU(kFloat).tensorFromBlob(data, {1, 2, 3}, [&](void*) { isgone++; });
+          from_blob(data, {1, 2, 3}, [&](void*) { isgone++; });
     }
     ASSERT_EQ(isgone, 1);
   }
@@ -81,7 +81,7 @@ TEST(atest, atest) {
     Tensor a_view;
     {
       auto f2 =
-          CPU(kFloat).tensorFromBlob(data, {1, 2, 3}, [&](void*) { isgone++; });
+          from_blob(data, {1, 2, 3}, [&](void*) { isgone++; });
       a_view = f2.view({3, 2, 1});
     }
     ASSERT_EQ(isgone, 0);
@@ -93,8 +93,7 @@ TEST(atest, atest) {
     int isgone = 0;
     {
       auto base = at::empty({1,2,3}, TensorOptions(kCUDA));
-      auto f2 = CUDA(kFloat).tensorFromBlob(
-          base.data_ptr(), {1, 2, 3}, [&](void*) { isgone++; });
+      auto f2 = from_blob(base.data_ptr(), {1, 2, 3}, [&](void*) { isgone++; });
     }
     ASSERT_EQ(isgone, 1);
   }
diff --git a/caffe2/contrib/aten/aten_op_template.h b/caffe2/contrib/aten/aten_op_template.h
index b597084..8c5f02e 100644
--- a/caffe2/contrib/aten/aten_op_template.h
+++ b/caffe2/contrib/aten/aten_op_template.h
@@ -54,18 +54,22 @@ private:
 #undef DEFINE_CASE
   }
 
-  at::Type& typeFor(const Tensor& ten) {
-    at::Backend b = backend();
+  at::TensorOptions optionsFor(const Tensor& ten) {
+    at::Device device = ten.GetDevice();
 #ifdef __HIP_PLATFORM_HCC__
-    if (b == at::Backend::HIP) {
-      b = at::Backend::CUDA;
+    if (backend() == at::Backend::HIP) {
+      device = at::Device(kCUDA, device.index());
     }
 #endif
-    return at::getNonVariableType(b, typeMetaToScalarType(ten.meta()));
+    return at::TensorOptions(device).dtype(ten.dtype());
   }
+
   at::Tensor tensorWrapping(const Tensor& ten_) {
     auto& ten = const_cast<Tensor&>(ten_);
-    return typeFor(ten).tensorFromBlob(ten.raw_mutable_data(), ten.sizes());
+    return at::from_blob(
+        ten.raw_mutable_data(),
+        ten.sizes(),
+        optionsFor(ten));
   }
 
   at::Tensor peek(size_t i, size_t N) {
diff --git a/tools/autograd/templates/VariableType.h b/tools/autograd/templates/VariableType.h
index 1b1ded3..0183c27 100644
--- a/tools/autograd/templates/VariableType.h
+++ b/tools/autograd/templates/VariableType.h
@@ -38,7 +38,6 @@ struct TORCH_API VariableType final : public at::TypeDefault {
   at::Backend backend() const override;
   at::Allocator* allocator() const override;
   at::Device getDeviceFromPtr(void * data) const override;
-  Storage storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const override;
   Storage storageWithAllocator(int64_t size, at::Allocator* allocator) const override;
   std::unique_ptr<Generator> generator() const override;
   const char * toString() const override;
diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp
index 27d6fe4..e012df2 100644
--- a/torch/csrc/autograd/VariableTypeManual.cpp
+++ b/torch/csrc/autograd/VariableTypeManual.cpp
@@ -31,9 +31,6 @@ Allocator* VariableType::allocator() const {
 Device VariableType::getDeviceFromPtr(void * data) const {
   return baseType->getDeviceFromPtr(data);
 }
-Storage VariableType::storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const {
-  return baseType->storageFromBlob(data, size, deleter);
-}
 Storage VariableType::unsafeStorageFromTH(void * th_pointer, bool retain) const {
   return baseType->unsafeStorageFromTH(th_pointer, retain);
 }
diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp
index 134ab7c..675c41c 100644
--- a/torch/csrc/utils/tensor_numpy.cpp
+++ b/torch/csrc/utils/tensor_numpy.cpp
@@ -133,17 +133,22 @@ at::Tensor tensor_from_numpy(PyObject* obj) {
   }
 
   void* data_ptr = PyArray_DATA(array);
-  auto& type = CPU(numpy_dtype_to_aten(PyArray_TYPE(array)));
   if (!PyArray_EquivByteorders(PyArray_DESCR(array)->byteorder, NPY_NATIVE)) {
    throw ValueError(
        "given numpy array has byte order different from the native byte order. "
        "Conversion between byte orders is currently not supported.");
  }
  Py_INCREF(obj);
-  return type.tensorFromBlob(data_ptr, sizes, strides, [obj](void* data) {
-    AutoGIL gil;
-    Py_DECREF(obj);
-  });
+  return at::from_blob(
+      data_ptr,
+      sizes,
+      strides,
+      [obj](void* data) {
+        AutoGIL gil;
+        Py_DECREF(obj);
+      },
+      at::device(kCPU).dtype(numpy_dtype_to_aten(PyArray_TYPE(array)))
+  );
 }
 
 static int aten_to_dtype(const ScalarType scalar_type) {
-- 
2.7.4
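
Usage sketch (not part of the patch itself): the free-function `at::from_blob` overloads declared in `aten/src/ATen/templates/Functions.h` above take a `TensorOptions` argument, which carries the device and dtype that the removed `Type::tensorFromBlob` encoded in the `Type` object (e.g. `CPU(kFloat)`). A minimal, self-contained example; the buffer, shapes, and strides below are illustrative only:

```cpp
#include <ATen/ATen.h>

int main() {
  // Caller-owned buffer; at::from_blob wraps it without copying.
  float data[6] = {1, 2, 3, 4, 5, 6};

  // Previously: at::CPU(at::kFloat).tensorFromBlob(data, {2, 3});
  // Now: a free function parameterized by TensorOptions.
  auto t = at::from_blob(data, {2, 3}, at::device(at::kCPU).dtype(at::kFloat));

  // Strided overload with a custom deleter, mirroring the new overload set.
  bool released = false;
  auto u = at::from_blob(
      data, /*sizes=*/{3, 2}, /*strides=*/{2, 1},
      [&released](void*) { released = true; },
      at::device(at::kCPU).dtype(at::kFloat));

  return (t.size(0) == 2 && u.stride(0) == 2) ? 0 : 1;
}
```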