Remove tensorFromBlob() from Type (#18779)
authorRoy Li <royboy@fb.com>
Sun, 7 Apr 2019 08:35:11 +0000 (01:35 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sun, 7 Apr 2019 08:37:43 +0000 (01:37 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18779
ghimport-source-id: e7453b74fcce0e4f4a9cbce0324992a85272a426

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18780 Remove tensorWithAllocator() from Type
* **#18779 Remove tensorFromBlob() from Type**

Differential Revision: D14739335

fbshipit-source-id: 8a0619a5b412332efa3b2d60c1edebd53d089d50

19 files changed:
aten/src/ATen/DLConvertor.cpp
aten/src/ATen/SparseTensorUtils.h
aten/src/ATen/TensorUtils.cpp
aten/src/ATen/TensorUtils.h
aten/src/ATen/UndefinedType.cpp
aten/src/ATen/UndefinedType.h
aten/src/ATen/core/Type.h
aten/src/ATen/native/Resize.h
aten/src/ATen/native/cuda/TensorTransformations.cu
aten/src/ATen/templates/Functions.h
aten/src/ATen/templates/NativeFunctions.h
aten/src/ATen/templates/Type.h
aten/src/ATen/templates/TypeDefault.cpp
aten/src/ATen/templates/TypeDefault.h
aten/src/ATen/test/atest.cpp
caffe2/contrib/aten/aten_op_template.h
tools/autograd/templates/VariableType.h
torch/csrc/autograd/VariableTypeManual.cpp
torch/csrc/utils/tensor_numpy.cpp

index cf9daa5..8604ec5 100644 (file)
@@ -67,21 +67,20 @@ static DLContext getDLContext(const Tensor& tensor, const int64_t& device_id) {
   return ctx;
 }
 
-static DeviceType getATenDeviceType(const DLContext& ctx) {
+static Device getATenDevice(const DLContext& ctx) {
   switch (ctx.device_type) {
     case DLDeviceType::kDLCPU:
-      return DeviceType::CPU;
+      return at::Device(DeviceType::CPU);
     case DLDeviceType::kDLGPU:
-      return DeviceType::CUDA;
+      return at::Device(DeviceType::CUDA, ctx.device_id);
     case DLDeviceType::kDLOpenCL:
-      return DeviceType::OPENCL;
+      return at::Device(DeviceType::OPENCL, ctx.device_id);
     case DLDeviceType::kDLROCM:
-      return DeviceType::HIP;
+      return at::Device(DeviceType::HIP, ctx.device_id);
     default:
       throw std::logic_error(
           "Unsupported device_type: " + std::to_string(ctx.device_type));
   }
-  return DeviceType::CPU; // impossible
 }
 
 ScalarType toScalarType(const DLDataType& dtype) {
@@ -173,7 +172,7 @@ DLManagedTensor* toDLPack(const Tensor& src) {
 }
 
 Tensor fromDLPack(const DLManagedTensor* src) {
-  DeviceType device_type = getATenDeviceType(src->dl_tensor.ctx);
+  Device device = getATenDevice(src->dl_tensor.ctx);
   ScalarType stype = toScalarType(src->dl_tensor.dtype);
   auto deleter = [src](void* self) {
     src->deleter(const_cast<DLManagedTensor*>(src));
@@ -182,7 +181,7 @@ Tensor fromDLPack(const DLManagedTensor* src) {
     return at::from_blob(src->dl_tensor.data,
         IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
         deleter,
-        at::device(device_type).dtype(stype));
+        at::device(device).dtype(stype));
   }
 
   return at::from_blob(
@@ -190,6 +189,6 @@ Tensor fromDLPack(const DLManagedTensor* src) {
       IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
       IntArrayRef(src->dl_tensor.strides, src->dl_tensor.ndim),
       deleter,
-      at::device(device_type).dtype(stype));
+      at::device(device).dtype(stype));
 }
 } // namespace at
index 8113367..f602f09 100644 (file)
@@ -85,8 +85,10 @@ inline LongTensor flatten_indices(const Tensor& indices, IntArrayRef full_size,
       indices_mult_cpu_vec[i] = mult;
       mult *= full_size[i];
     }
-    auto indices_mult_cpu = indices.dispatch_type().cpu()
-                                   .tensorFromBlob(indices_mult_cpu_vec.data(), /*size=*/{sparse_dim, 1});
+    auto indices_mult_cpu = at::from_blob(
+        indices_mult_cpu_vec.data(),
+        /*size=*/{sparse_dim, 1},
+        indices.options().device(kCPU));
     // NB: must be blocking because this blob may be freed after this closure,
     //     and non_blocking copy will see garbage.
     auto indices_mult = indices_mult_cpu.to(indices.device(), /*non_blocking=*/false);
index d89243b..200bc23 100644 (file)
@@ -235,4 +235,29 @@ bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides) {
   return contig_if_nonempty;
 }
 
+namespace detail {
+
+std::vector<int64_t> defaultStrides(IntArrayRef sizes) {
+  std::vector<int64_t> strides(sizes.size());
+  int64_t stride = 1;
+  for(size_t i = sizes.size(); i > 0; --i) {
+    strides[i-1] = stride;
+    stride *= sizes[i-1];
+  }
+  return strides;
+}
+
+int64_t computeStorageSize(IntArrayRef sizes, IntArrayRef strides) {
+  // size of the underlying storage is 1 bigger than the offset
+  // of the last element according to stride
+  int64_t size = 1;
+  for(size_t i = 0; i < sizes.size(); i++) {
+    if(sizes[i] == 0) {
+      return 0;
+    }
+    size += strides[i]*(sizes[i]-1);
+  }
+  return size;
 }
+}  // namespace detail
+}  // namespace at
index 92bf15d..2b548ae 100644 (file)
@@ -125,4 +125,9 @@ CAFFE2_API void* maybe_data_ptr(const TensorArg& tensor);
 // constructing a tensor, e.g., when you want to choose a kernel strategy based
 // on whether a subgeometry is contiguous.
 CAFFE2_API bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides);
-}
+
+namespace detail {
+CAFFE2_API std::vector<int64_t> defaultStrides(IntArrayRef sizes);
+CAFFE2_API int64_t computeStorageSize(IntArrayRef sizes, IntArrayRef strides);
+} // namespace detail
+} // namespace at
index c608c43..2c780e7 100644 (file)
@@ -23,9 +23,6 @@ Device UndefinedType::getDeviceFromPtr(void*) const {
   AT_ERROR("getDeviceFromPtr not defined for UndefinedType");
 }
 
-Storage UndefinedType::storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const {
-  AT_ERROR("storageFromBlob not defined for UndefinedType");
-}
 Storage UndefinedType::unsafeStorageFromTH(void * th_pointer, bool retain) const {
   AT_ERROR("unsafeStorageFromTH not defined for UndefinedType");
 }
index 095b6fe..ff89b4a 100644 (file)
@@ -18,7 +18,6 @@ struct UndefinedType final : public TypeDefault {
   virtual Backend backend() const override;
   virtual Allocator* allocator() const override;
   virtual Device getDeviceFromPtr(void* data) const override;
-  virtual Storage storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const override;
   virtual Storage storageWithAllocator(int64_t size, Allocator* allocator) const override;
   virtual std::unique_ptr<Generator> generator() const override;
   virtual const char * toString() const override;
index 0eea7bb..2b1dd0a 100644 (file)
@@ -128,7 +128,6 @@ struct CAFFE2_API Type {
   bool is_undefined() const noexcept { return is_undefined_; }
   virtual Allocator * allocator() const = 0;
   virtual Device getDeviceFromPtr(void * data) const = 0;
-  virtual Storage storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
   virtual Storage storageWithAllocator(int64_t size, Allocator* allocator) const = 0;
   virtual std::unique_ptr<Generator> generator() const = 0;
   virtual Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const = 0;
@@ -176,8 +175,6 @@ struct CAFFE2_API Type {
       bool create_graph) const = 0;
   virtual void set_data(Tensor & self, Tensor new_data) const = 0;
 
-  virtual Tensor tensorFromBlob(void * data, IntArrayRef sizes, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
-  virtual Tensor tensorFromBlob(void * data, IntArrayRef sizes, IntArrayRef strides, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
   virtual Tensor tensorWithAllocator(IntArrayRef sizes, Allocator* allocator) const = 0;
   virtual Tensor tensorWithAllocator(IntArrayRef sizes, IntArrayRef strides, Allocator* allocator) const = 0;
 
index e39c28d..27ceb5f 100644 (file)
@@ -52,23 +52,12 @@ inline TensorImpl* resize_impl_cpu_(
   return self;
 }
 
-static inline int64_t computeStorageSize(IntArrayRef sizes, IntArrayRef strides) {
-  int64_t storage_size = 1;
-  for (size_t dim = 0; dim < sizes.size(); ++dim) {
-    if (sizes[dim] == 0) {
-      return 0;
-    }
-    storage_size += strides[dim] * (sizes[dim] - 1);
-  }
-  return storage_size;
-}
-
 static inline void checkInBoundsForStorage(
     IntArrayRef size,
     IntArrayRef stride,
     int64_t storage_offset,
     const Storage& new_storage) {
-  int64_t storage_size = computeStorageSize(size, stride);
+  int64_t storage_size = detail::computeStorageSize(size, stride);
   if (storage_size == 0) {
     // NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
     return;
index d0bf923..d078634 100644 (file)
@@ -99,13 +99,16 @@ Tensor flip_cuda(const Tensor& self, IntArrayRef dims) {
     return out_tensor;
   }
 
-  auto flip_dims_t = at::CPU(kLong).tensorFromBlob(flip_dims.data(), {static_cast<int64_t>(flip_dims.size())});
+  auto flip_dims_t = at::from_blob(
+      flip_dims.data(), {static_cast<int64_t>(flip_dims.size())}, at::device(kCPU).dtype(kLong));
 
   auto shape = in_tensor.sizes().vec();
-  auto shape_t = at::CPU(kLong).tensorFromBlob(shape.data(), {static_cast<int64_t>(shape.size())});
+  auto shape_t = at::from_blob(
+      shape.data(), {static_cast<int64_t>(shape.size())}, at::device(kCPU).dtype(kLong));
 
   auto strides = in_tensor.strides().vec();
-  auto strides_t = at::CPU(kLong).tensorFromBlob(strides.data(), {static_cast<int64_t>(strides.size())});
+  auto strides_t = at::from_blob(
+      strides.data(), {static_cast<int64_t>(strides.size())}, at::device(kCPU).dtype(kLong));
 
   // stride_contiguous is the stride of non-contiguous tensor after calling contiguous(),
   // it is used to compute indices for each element in non-contiguous tensor
index a053284..47477ba 100644 (file)
 #include <c10/core/TensorOptions.h>
 #include <ATen/core/Reduction.h>
 #include <c10/util/Optional.h>
+#include <ATen/TensorUtils.h>
 
 namespace at {
 
-using native::from_blob;
 using native::tensor;
 
 ${function_declarations}
 
+inline Tensor from_blob(
+    void* data,
+    IntArrayRef sizes,
+    IntArrayRef strides,
+    const std::function<void(void*)>& deleter,
+    const TensorOptions& options = {}) {
+  auto storage = Storage(
+      options.dtype(),
+      detail::computeStorageSize(sizes, strides),
+      InefficientStdFunctionContext::makeDataPtr(
+          data, deleter, options.device()),
+      /*allocator=*/nullptr,
+      /*resizable=*/false);
+  return empty({0}, options).set_(storage, 0, sizes, strides);
+}
+
+inline Tensor from_blob(
+    void* data,
+    IntArrayRef sizes,
+    const std::function<void(void*)>& deleter,
+    const TensorOptions& options = {}) {
+  return from_blob(data, sizes, detail::defaultStrides(sizes), deleter, options);
+}
+
+inline Tensor from_blob(
+    void* data,
+    IntArrayRef sizes,
+    IntArrayRef strides,
+    const TensorOptions& options = {}) {
+  return from_blob(data, sizes, strides, [](void*) {}, options);
+}
+
+inline Tensor from_blob(
+    void* data,
+    IntArrayRef sizes,
+    const TensorOptions& options = {}) {
+  return from_blob(data, sizes, detail::defaultStrides(sizes), [](void*) {}, options);
+}
+
 namespace detail {
 
 static inline TypeExtendedInterface & infer_type(const Tensor & t) {
index 5040465..e3fa315 100644 (file)
@@ -25,23 +25,6 @@ struct Type;
 namespace at {
 namespace native {
 
-inline Tensor from_blob(
-    void* data,
-    IntArrayRef sizes,
-    const std::function<void(void*)>& deleter,
-    const TensorOptions& options = {}) {
-  return at::getType(options).tensorFromBlob(data, sizes, deleter);
-}
-
-inline Tensor from_blob(
-    void* data,
-    IntArrayRef sizes,
-    IntArrayRef strides,
-    const std::function<void(void*)>& deleter,
-    const TensorOptions& options = {}) {
-  return at::getType(options).tensorFromBlob(data, sizes, strides, deleter);
-}
-
 // These functions are defined in native/TensorFactories.cpp.
 #define TENSOR(T, S, _1)                                                      \
   CAFFE2_API Tensor tensor(ArrayRef<T> values, const TensorOptions& options); \
index d1d2a28..87914fc 100644 (file)
@@ -71,7 +71,6 @@ struct CAFFE2_API Type {
   bool is_undefined() const noexcept { return is_undefined_; }
   virtual Allocator * allocator() const = 0;
   virtual Device getDeviceFromPtr(void * data) const = 0;
-  virtual Storage storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
   virtual Storage storageWithAllocator(int64_t size, Allocator* allocator) const = 0;
   virtual std::unique_ptr<Generator> generator() const = 0;
   virtual Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const = 0;
@@ -119,8 +118,6 @@ struct CAFFE2_API Type {
       bool create_graph) const = 0;
   virtual void set_data(Tensor & self, Tensor new_data) const = 0;
 
-  virtual Tensor tensorFromBlob(void * data, IntArrayRef sizes, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
-  virtual Tensor tensorFromBlob(void * data, IntArrayRef sizes, IntArrayRef strides, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
   virtual Tensor tensorWithAllocator(IntArrayRef sizes, Allocator* allocator) const = 0;
   virtual Tensor tensorWithAllocator(IntArrayRef sizes, IntArrayRef strides, Allocator* allocator) const = 0;
 
index e89b6f1..2711e5c 100644 (file)
@@ -57,51 +57,14 @@ Type & TypeDefault::toBackend(Backend b) const {
 Type & TypeDefault::toScalarType(ScalarType s) const {
   return at::globalContext().getNonVariableType(backend(),s);
 }
-static std::vector<int64_t> defaultStrides(IntArrayRef sizes) {
-  std::vector<int64_t> strides(sizes.size());
-  int64_t stride = 1;
-  for(size_t i = sizes.size(); i > 0; --i) {
-    strides[i-1] = stride;
-    stride *= sizes[i-1];
-  }
-  return strides;
-}
-static int64_t computeStorageSize(IntArrayRef sizes, IntArrayRef strides) {
-  // size of the underlying storage is 1 bigger than the offset
-  // of the last element according to stride
-  int64_t size = 1;
-  for(size_t i = 0; i < sizes.size(); i++) {
-    if(sizes[i] == 0) {
-      return 0;
-    }
-    size += strides[i]*(sizes[i]-1);
-  }
-  return size;
-}
-Tensor TypeDefault::tensorFromBlob(void * data, IntArrayRef sizes, const std::function<void(void*)> & deleter) const {
-  return tensorFromBlob(data, sizes, defaultStrides(sizes), deleter);
-}
-Tensor TypeDefault::tensorFromBlob(void * data, IntArrayRef sizes, IntArrayRef strides, const std::function<void(void*)> & deleter) const {
-  auto storage = storageFromBlob(data, computeStorageSize(sizes, strides), deleter);
-  return at::empty({0}, options()).set_(storage, 0, sizes, strides);
-}
 Tensor TypeDefault::tensorWithAllocator(IntArrayRef sizes, Allocator* allocator) const {
-  return tensorWithAllocator(sizes, defaultStrides(sizes), std::move(allocator));
+  return tensorWithAllocator(sizes, detail::defaultStrides(sizes), std::move(allocator));
 }
 Tensor TypeDefault::tensorWithAllocator(IntArrayRef sizes, IntArrayRef strides, Allocator* allocator) const {
-  auto storage = storageWithAllocator(computeStorageSize(sizes, strides), std::move(allocator));
+  auto storage = storageWithAllocator(detail::computeStorageSize(sizes, strides), std::move(allocator));
   return at::empty({0}, options()).set_(storage, 0, sizes, strides);
 }
 
-Storage TypeDefault::storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const {
-  return Storage(
-      typeMeta(),
-      size,
-      InefficientStdFunctionContext::makeDataPtr(
-          data, deleter, getDeviceFromPtr(data)),
-      /*allocator=*/nullptr,
-      /*resizable=*/false);
-}
 Storage TypeDefault::storageWithAllocator(int64_t size, Allocator* allocator) const {
   // Potentially the storage might be marked as resizable too here
   return Storage(typeMeta(), size, allocator, /*resizable=*/false);
index e8a1517..b597ace 100644 (file)
@@ -38,12 +38,9 @@ struct CAFFE2_API TypeDefault : public TypeExtendedInterface {
       bool create_graph) const override;
   void set_data(Tensor & self, Tensor new_data) const override;
 
-  Tensor tensorFromBlob(void * data, IntArrayRef sizes, const std::function<void(void*)> & deleter=noop_deleter) const override;
-  Tensor tensorFromBlob(void * data, IntArrayRef sizes, IntArrayRef strides, const std::function<void(void*)> & deleter=noop_deleter) const override;
   Tensor tensorWithAllocator(IntArrayRef sizes, Allocator* allocator) const override;
   Tensor tensorWithAllocator(IntArrayRef sizes, IntArrayRef strides, Allocator* allocator) const override;
 
-  Storage storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const override;
   Storage storageWithAllocator(int64_t size, Allocator* allocator) const override;
   Storage unsafeStorageFromTH(void * th_pointer, bool retain) const override;
   Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const override;
index 86d07d4..e681d78 100644 (file)
@@ -53,7 +53,7 @@ TEST(atest, atest) {
 
   float data[] = {1, 2, 3, 4, 5, 6};
 
-  auto f = CPU(kFloat).tensorFromBlob(data, {1, 2, 3});
+  auto f = from_blob(data, {1, 2, 3});
   auto f_a = f.accessor<float, 3>();
 
   ASSERT_EQ(f_a[0][0][0], 1.0);
@@ -72,7 +72,7 @@ TEST(atest, atest) {
     int isgone = 0;
     {
       auto f2 =
-          CPU(kFloat).tensorFromBlob(data, {1, 2, 3}, [&](void*) { isgone++; });
+          from_blob(data, {1, 2, 3}, [&](void*) { isgone++; });
     }
     ASSERT_EQ(isgone, 1);
   }
@@ -81,7 +81,7 @@ TEST(atest, atest) {
     Tensor a_view;
     {
       auto f2 =
-          CPU(kFloat).tensorFromBlob(data, {1, 2, 3}, [&](void*) { isgone++; });
+          from_blob(data, {1, 2, 3}, [&](void*) { isgone++; });
       a_view = f2.view({3, 2, 1});
     }
     ASSERT_EQ(isgone, 0);
@@ -93,8 +93,7 @@ TEST(atest, atest) {
     int isgone = 0;
     {
       auto base = at::empty({1,2,3}, TensorOptions(kCUDA));
-      auto f2 = CUDA(kFloat).tensorFromBlob(
-          base.data_ptr(), {1, 2, 3}, [&](void*) { isgone++; });
+      auto f2 = from_blob(base.data_ptr(), {1, 2, 3}, [&](void*) { isgone++; });
     }
     ASSERT_EQ(isgone, 1);
   }
index b597084..8c5f02e 100644 (file)
@@ -54,18 +54,22 @@ private:
     #undef DEFINE_CASE
   }
 
-  at::Type& typeFor(const Tensor& ten) {
-    at::Backend b = backend();
+  at::TensorOptions optionsFor(const Tensor& ten) {
+    at::Device device = ten.GetDevice();
 #ifdef __HIP_PLATFORM_HCC__
-    if (b == at::Backend::HIP) {
-      b = at::Backend::CUDA;
+    if (backend() == at::Backend::HIP) {
+      device = at::Device(kCUDA, device.index());
     }
 #endif
-    return at::getNonVariableType(b, typeMetaToScalarType(ten.meta()));
+    return at::TensorOptions(device).dtype(ten.dtype());
   }
+
   at::Tensor tensorWrapping(const Tensor& ten_) {
     auto& ten = const_cast<Tensor&>(ten_);
-    return typeFor(ten).tensorFromBlob(ten.raw_mutable_data(), ten.sizes());
+    return at::from_blob(
+        ten.raw_mutable_data(),
+        ten.sizes(),
+        optionsFor(ten));
   }
 
   at::Tensor peek(size_t i, size_t N) {
index 1b1ded3..0183c27 100644 (file)
@@ -38,7 +38,6 @@ struct TORCH_API VariableType final : public at::TypeDefault {
   at::Backend backend() const override;
   at::Allocator* allocator() const override;
   at::Device getDeviceFromPtr(void * data) const override;
-  Storage storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const override;
   Storage storageWithAllocator(int64_t size, at::Allocator* allocator) const override;
   std::unique_ptr<at::Generator> generator() const override;
   const char * toString() const override;
index 27d6fe4..e012df2 100644 (file)
@@ -31,9 +31,6 @@ Allocator* VariableType::allocator() const {
 Device VariableType::getDeviceFromPtr(void * data) const {
   return baseType->getDeviceFromPtr(data);
 }
-Storage VariableType::storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const {
-  return baseType->storageFromBlob(data, size, deleter);
-}
 Storage VariableType::unsafeStorageFromTH(void * th_pointer, bool retain) const {
   return baseType->unsafeStorageFromTH(th_pointer, retain);
 }
index 134ab7c..675c41c 100644 (file)
@@ -133,17 +133,22 @@ at::Tensor tensor_from_numpy(PyObject* obj) {
   }
 
   void* data_ptr = PyArray_DATA(array);
-  auto& type = CPU(numpy_dtype_to_aten(PyArray_TYPE(array)));
   if (!PyArray_EquivByteorders(PyArray_DESCR(array)->byteorder, NPY_NATIVE)) {
     throw ValueError(
         "given numpy array has byte order different from the native byte order. "
         "Conversion between byte orders is currently not supported.");
   }
   Py_INCREF(obj);
-  return type.tensorFromBlob(data_ptr, sizes, strides, [obj](void* data) {
-    AutoGIL gil;
-    Py_DECREF(obj);
-  });
+  return at::from_blob(
+      data_ptr,
+      sizes,
+      strides,
+      [obj](void* data) {
+          AutoGIL gil;
+          Py_DECREF(obj);
+      },
+      at::device(kCPU).dtype(numpy_dtype_to_aten(PyArray_TYPE(array)))
+  );
 }
 
 static int aten_to_dtype(const ScalarType scalar_type) {