Port empty_strided to ATen. (#15948)
authorGregory Chanan <gchanan@fb.com>
Fri, 11 Jan 2019 15:55:17 +0000 (07:55 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 11 Jan 2019 15:58:05 +0000 (07:58 -0800)
Summary:
Turns out this has basically been implemented already in Resize.h / Resize.cuh.
Also added some testing, basically just to check that empty_strided behaves equivalently to as_strided.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15948

Differential Revision: D13631098

Pulled By: gchanan

fbshipit-source-id: eb0e04eead45e4cff393ebde340f9d265779e185

aten/src/ATen/Declarations.cwrap
aten/src/ATen/native/TensorFactories.cpp
aten/src/ATen/native/cuda/TensorFactories.cu
aten/src/ATen/native/native_functions.yaml
test/test_torch.py

index 38cb05b..30a90c7 100644 (file)
         - THTensor* alpha
         - THTensor* total
 ]]
-[[
-  name: _th_tensor
-  return: THTensor*
-  cpu_half: True
-  variants: [function]
-  options:
-    - cname: newWithSize
-      arguments:
-        - IntListSize size
-        - IntList stride
-]]
 
 # In theory, this could be a part of the above declaration. But in
 # practice this leads to all sorts of problems with ambiguous overloads.
index 101d0b5..a2ca7bd 100644 (file)
@@ -13,6 +13,7 @@
 #include <ATen/LegacyTHDispatcher.h>
 #include <c10/core/ScalarType.h>
 #include <ATen/core/Deprecated.h>
+#include <ATen/native/Resize.h>
 #include <ATen/native/TensorFactories.h>
 #include <c10/core/TensorOptions.h>
 #include <TH/THRandom.h>
 #include <cmath>
 #include <cstddef>
 
-// Note [Native bindings for legacy TH factory functions]
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-// A number of factory functions are implemented in the following way:
-//
-//    return at::getType(options)._arange(start, end, step);
-//
-// That is to say, they grab a Type for TensorOptions, and then call some
-// internal method.  What's going on?
-//
-// The reason for the folderol is that these particular factory functions
-// are still implemented in a legacy way in TH.  The TH bindings don't
-// (and never will) understand TensorOptions, so we need to handle TensorOptions
-// inside native before batting over to TH.  The expectation is that when
-// these factories get ported to native, this is no longer necessary,
-// and we can eliminate the getType call.
-
 namespace at {
 namespace native {
 namespace {
@@ -125,6 +110,12 @@ Tensor empty_cpu(IntList size, const TensorOptions& options) {
   return tensor;
 }
 
+Tensor empty_strided_cpu(IntList size, IntList stride, const TensorOptions& options) {
+  auto t = at::native::empty_cpu({0}, options);
+  at::native::resize_impl_cpu_(t.unsafeGetTensorImpl(), size, stride);
+  return t;
+}
+
 Tensor& empty_out(Tensor& result, IntList size) {
   if (result.is_sparse()) {
     result.sparse_resize_and_clear_(size, size.size(), 0);
@@ -134,12 +125,6 @@ Tensor& empty_out(Tensor& result, IntList size) {
   return result;
 }
 
-Tensor empty_strided(IntList size, IntList stride, const TensorOptions& options) {
-  // Note [Native bindings for legacy TH factory functions]
-  return getFactoryType(options)._th_tensor(size, stride);
-}
-
-
 // Temporary type cast operators. These are needed to trace type-casts now since
 // Type's are not supported in the IR. Instead, we call down to these
 // specialized operators for each datatype.
index 744928f..6413277 100644 (file)
@@ -4,6 +4,7 @@
 #include <ATen/cuda/CUDAApplyUtils.cuh>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/native/TensorFactories.h>
+#include <ATen/native/cuda/Resize.cuh>
 #include <c10/util/Exception.h>
 
 #include <THC/THCGeneral.h>
@@ -64,6 +65,12 @@ Tensor empty_cuda(IntList size, const TensorOptions& options) {
   return tensor;
 }
 
+Tensor empty_strided_cuda(IntList size, IntList stride, const TensorOptions& options) {
+  auto t = at::native::empty_cuda({0}, options);
+  at::native::resize_impl_cuda_(t.unsafeGetTensorImpl(), size, stride);
+  return t;
+}
+
 Tensor& randperm_out_cuda(Tensor& result, int64_t n, Generator* generator) {
   AT_CHECK(n >= 0, "n must be non-negative, got", n);
   AT_CHECK(at::scalar_tensor(n, result.options()).defined(),
index baef657..ef3ec16 100644 (file)
   device_guard: False
 
 - func: empty_strided(IntList size, IntList stride, *, TensorOptions options={}) -> Tensor
+  dispatch:
+    CPU: empty_strided_cpu
+    CUDA: empty_strided_cuda
 
 - func: erf(Tensor self) -> Tensor
   variants: function, method
index f35cda5..4d51c0f 100644 (file)
@@ -2580,6 +2580,7 @@ class _TestTorchMixin(object):
                 self.assertEqual(shape, torch.zeros_like(torch.zeros(shape, device=device)).shape)
                 self.assertEqual(shape, torch.empty(shape, device=device).shape)
                 self.assertEqual(shape, torch.empty_like(torch.zeros(shape, device=device)).shape)
+                self.assertEqual(shape, torch.empty_strided(shape, (0,) * len(shape), device=device).shape)
                 self.assertEqual(shape, torch.full(shape, 3, device=device).shape)
                 self.assertEqual(shape, torch.full_like(torch.zeros(shape, device=device), 3).shape)
                 self.assertEqual(shape, torch.ones(shape, device=device).shape)
@@ -8810,6 +8811,22 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             self.assertEqual(torch.empty_like(a).shape, a.shape)
             self.assertEqual(torch.empty_like(a).type(), a.type())
 
+    def test_empty_strided(self):
+        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
+        for device in devices:
+            for shape in [(2, 3, 4), (0, 2, 0)]:
+                # some of these cases are pretty strange, just verifying that if as_strided
+                # allows them then empty_strided can as well.
+                for strides in [(12, 4, 1), (2, 4, 6), (0, 0, 0)]:
+                    empty_strided = torch.empty_strided(shape, strides, device=device)
+                    # as_strided checks the storage size is big enough to support such a strided tensor;
+                    # instead of repeating this calculation, we just use empty_strided which does the same
+                    # calculation when setting the storage size.
+                    as_strided = torch.empty(empty_strided.storage().size(),
+                                             device=device).as_strided(shape, strides)
+                    self.assertEqual(empty_strided.shape, as_strided.shape)
+                    self.assertEqual(empty_strided.stride(), as_strided.stride())
+
     @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
     @skipIfRocm
     def test_pin_memory(self):