`var` for multiple dimensions (#15892)
authorBrennan Vincent <btv@fb.com>
Tue, 15 Jan 2019 04:14:04 +0000 (20:14 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 15 Jan 2019 04:17:42 +0000 (20:17 -0800)
Summary:
Timings are the same as for `std` .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15892

Differential Revision: D13651173

Pulled By: umanwizard

fbshipit-source-id: a26bf1021dd972aa9e3e60fb901cd4983bfa190f

12 files changed:
aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/core/Type.h
aten/src/ATen/native/ReduceOps.cpp
aten/src/ATen/native/ReduceOps.h
aten/src/ATen/native/SharedReduceOps.h
aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
aten/src/ATen/native/cuda/ReduceOpsKernel.cu
aten/src/ATen/native/native_functions.yaml
test/test_torch.py
tools/autograd/derivatives.yaml
torch/csrc/jit/passes/shape_analysis.cpp

index 762789f..2f02f92 100644 (file)
@@ -493,7 +493,7 @@ public:
   Tensor unsqueeze(int64_t dim) const;
   Tensor & unsqueeze_(int64_t dim);
   Tensor var(bool unbiased=true) const;
-  Tensor var(int64_t dim, bool unbiased=true, bool keepdim=false) const;
+  Tensor var(IntList dim, bool unbiased=true, bool keepdim=false) const;
   Tensor view_as(const Tensor & other) const;
   Tensor where(const Tensor & condition, const Tensor & other) const;
   Tensor norm(Scalar p=2) const;
index b44ca0f..33f3dc2 100644 (file)
@@ -664,7 +664,7 @@ inline Tensor & Tensor::unsqueeze_(int64_t dim) {
 inline Tensor Tensor::var(bool unbiased) const {
     return type().var(*this, unbiased);
 }
-inline Tensor Tensor::var(int64_t dim, bool unbiased, bool keepdim) const {
+inline Tensor Tensor::var(IntList dim, bool unbiased, bool keepdim) const {
     return type().var(*this, dim, unbiased, keepdim);
 }
 inline Tensor Tensor::view_as(const Tensor & other) const {
index 73b3cfd..7de537f 100644 (file)
@@ -388,7 +388,7 @@ struct CAFFE2_API Type {
   virtual Tensor unsqueeze(const Tensor & self, int64_t dim) const = 0;
   virtual Tensor & unsqueeze_(Tensor & self, int64_t dim) const = 0;
   virtual Tensor var(const Tensor & self, bool unbiased) const = 0;
-  virtual Tensor var(const Tensor & self, int64_t dim, bool unbiased, bool keepdim) const = 0;
+  virtual Tensor var(const Tensor & self, IntList dim, bool unbiased, bool keepdim) const = 0;
   virtual Tensor view_as(const Tensor & self, const Tensor & other) const = 0;
   virtual Tensor where(const Tensor & condition, const Tensor & self, const Tensor & other) const = 0;
   virtual Tensor norm(const Tensor & self, Scalar p) const = 0;
index 6c57ecc..c1c4a75 100644 (file)
@@ -21,7 +21,7 @@ namespace at {
 namespace native {
 
 DEFINE_DISPATCH(sum_stub);
-DEFINE_DISPATCH(std_stub);
+DEFINE_DISPATCH(std_var_stub);
 DEFINE_DISPATCH(prod_stub);
 DEFINE_DISPATCH(mean_stub);
 
@@ -462,6 +462,20 @@ Tensor &any_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) {
   }
 }
 
+static Tensor &std_var_out(Tensor &result, const Tensor &self, IntList dim, bool unbiased, bool keepdim, bool take_sqrt) {
+  AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
+           "std and var only support CPU AND CUDA backend, got: ", toString(self.type().backend()));
+  AT_CHECK(at::isFloatingType(self.type().scalarType()), "std and var only support floating-point dtypes");
+  ScalarType dtype = get_dtype(result, self, {}, true);
+  auto iter = make_reduction("std or var", result, self, dim, keepdim, dtype);
+  if (iter->numel() == 0) {
+    result.fill_(NAN);
+  } else {
+    std_var_stub(iter->device_type(), *iter, unbiased, take_sqrt);
+  }
+  return result;
+}
+
 Tensor var(const Tensor& self, bool unbiased) {
   AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
            "var only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
@@ -470,21 +484,13 @@ Tensor var(const Tensor& self, bool unbiased) {
   return trivial_return.has_value() ? trivial_return.value() : at::legacy::th::_th_var(self, unbiased);
 }
 
-Tensor var(const Tensor& self, int64_t dim, bool unbiased, bool keepdim) {
+Tensor var(const Tensor& self, IntList dim, bool unbiased, bool keepdim) {
   Tensor result = at::empty({0}, self.options());
   return at::native::var_out(result, self, dim, unbiased, keepdim);
 }
 
-Tensor &var_out(Tensor &result, const Tensor &self, int64_t dim, bool unbiased, bool keepdim) {
-  AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
-           "var only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
-  AT_CHECK(at::isFloatingType(self.type().scalarType()), "var only supports floating-point dtypes");
-  dim = maybe_wrap_dim(dim, self.dim());
-  if (_dimreduce_return_trivial(result, self, std::numeric_limits<double>::quiet_NaN(), dim, keepdim)) {
-    return result;
-  } else {
-    return at::legacy::th::_th_var_out(result, self, dim, unbiased, keepdim);
-  }
+Tensor &var_out(Tensor &result, const Tensor &self, IntList dim, bool unbiased, bool keepdim) {
+  return std_var_out(result, self, dim, unbiased, keepdim, false);
 }
 
 Tensor std(const Tensor& self, bool unbiased) {
@@ -501,17 +507,7 @@ Tensor std(const Tensor& self, IntList dim, bool unbiased, bool keepdim) {
 }
 
 Tensor &std_out(Tensor &result, const Tensor &self, IntList dim, bool unbiased, bool keepdim) {
-  AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
-           "std only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
-  AT_CHECK(at::isFloatingType(self.type().scalarType()), "std only supports floating-point dtypes");
-  ScalarType dtype = get_dtype(result, self, {}, true);
-  auto iter = make_reduction("std", result, self, dim, keepdim, dtype);
-  if (iter->numel() == 0) {
-    result.fill_(NAN);
-  } else {
-    std_stub(iter->device_type(), *iter, unbiased);
-  }
-  return result;
+  return std_var_out(result, self, dim, unbiased, keepdim, true);
 }
 
 }} // namespace at::native
index 9bc83d2..2d8cd4e 100644 (file)
@@ -16,9 +16,9 @@ DECLARE_DISPATCH(reduce_fn, sum_stub);
 DECLARE_DISPATCH(reduce_fn, prod_stub);
 DECLARE_DISPATCH(reduce_fn, mean_stub);
 
-using reduce_std_function =
-  void (*)(TensorIterator&, bool unbiased);
-DECLARE_DISPATCH(reduce_std_function, std_stub);
+using reduce_std_var_function =
+  void (*)(TensorIterator&, bool unbiased, bool take_sqrt);
+DECLARE_DISPATCH(reduce_std_var_function, std_var_stub);
 
 using reduce_norm_fn =
     void (*)(Tensor&, const Tensor&, Scalar, c10::optional<int64_t>);
index 1c75eaa..30b9a07 100644 (file)
@@ -29,6 +29,7 @@ struct WelfordData {
 template <typename scalar_t, typename acc_scalar_t>
 struct WelfordOps {
   bool unbiased;
+  bool take_sqrt;
  public:
   using acc_t = WelfordData<acc_scalar_t>;
   inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data) const {
@@ -59,7 +60,10 @@ struct WelfordOps {
   }
   inline C10_DEVICE scalar_t project(acc_t acc) const {
     int64_t divisor = unbiased ? (acc.n - 1) : acc.n;
-    return (divisor > 0) ? (scalar_t)device_sqrt(acc.m2 / divisor) : (scalar_t)NAN;
+    auto ret = (divisor > 0) ?
+      (take_sqrt ? device_sqrt(acc.m2 / divisor) : (acc.m2 / divisor))
+      : NAN;
+    return (scalar_t) ret;
   }
 #if defined(__CUDACC__) || defined(__HIPCC__)
   inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const {
@@ -70,7 +74,8 @@ struct WelfordOps {
     };
   }
 #endif
-  WelfordOps(bool unbiased) : unbiased(unbiased) {
+  WelfordOps(bool unbiased, bool take_sqrt)
+    : unbiased(unbiased), take_sqrt(take_sqrt) {
   }
 };
 
index 5dfed3f..8b365fb 100644 (file)
@@ -34,11 +34,11 @@ static void mean_kernel_impl(TensorIterator& iter) {
   });
 }
 
-static void std_kernel_impl(TensorIterator &iter, bool unbiased) {
+static void std_var_kernel_impl(TensorIterator &iter, bool unbiased, bool take_sqrt) {
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.type(), "std", [&] {
     binary_kernel_reduce(
       iter,
-      WelfordOps<scalar_t, double> { unbiased },
+      WelfordOps<scalar_t, double> { unbiased, take_sqrt },
       WelfordData<double>()
     );
   });
@@ -57,7 +57,7 @@ static void prod_kernel_impl(TensorIterator& iter) {
 }  // anonymous namespace
 
 REGISTER_DISPATCH(sum_stub, &sum_kernel_impl);
-REGISTER_DISPATCH(std_stub, &std_kernel_impl);
+REGISTER_DISPATCH(std_var_stub, &std_var_kernel_impl);
 REGISTER_DISPATCH(prod_stub, &prod_kernel_impl);
 REGISTER_DISPATCH(mean_stub, &mean_kernel_impl);
 
index d1d9dc1..3110458 100644 (file)
@@ -33,13 +33,13 @@ void sum_kernel_impl(TensorIterator& iter) {
 }
 
 template <typename scalar_t>
-void std_kernel_impl(TensorIterator& iter, bool unbiased) {
-  gpu_reduce_kernel<scalar_t, scalar_t>(iter, WelfordOps<scalar_t, scalar_t> { unbiased }, WelfordData<scalar_t> {});
+void std_var_kernel_impl(TensorIterator& iter, bool unbiased, bool take_sqrt) {
+  gpu_reduce_kernel<scalar_t, scalar_t>(iter, WelfordOps<scalar_t, scalar_t> { unbiased, take_sqrt }, WelfordData<scalar_t> {});
 }
 
 template <>
-void std_kernel_impl<at::Half>(TensorIterator& iter, bool unbiased) {
-  gpu_reduce_kernel<at::Half, at::Half>(iter, WelfordOps<at::Half, float> { unbiased }, WelfordData<float> {});
+void std_var_kernel_impl<at::Half>(TensorIterator& iter, bool unbiased, bool take_sqrt) {
+  gpu_reduce_kernel<at::Half, at::Half>(iter, WelfordOps<at::Half, float> { unbiased, take_sqrt }, WelfordData<float> {});
 }
 
 #ifdef __HIPCC__
@@ -62,9 +62,9 @@ void prod_kernel_impl(TensorIterator& iter) {
   }), 1);
 }
 
-static void std_kernel_cuda(TensorIterator& iter, bool unbiased) {
+static void std_var_kernel_cuda(TensorIterator& iter, bool unbiased, bool take_sqrt) {
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.type(), "std", [&]() {
-    std_kernel_impl<scalar_t>(iter, unbiased);
+    std_var_kernel_impl<scalar_t>(iter, unbiased, take_sqrt);
   });
 }
 
@@ -119,7 +119,7 @@ static void mean_kernel_cuda(TensorIterator& iter) {
   });
 }
 
-REGISTER_DISPATCH(std_stub, &std_kernel_cuda);
+REGISTER_DISPATCH(std_var_stub, &std_var_kernel_cuda);
 REGISTER_DISPATCH(sum_stub, &sum_kernel_cuda);
 REGISTER_DISPATCH(prod_stub, &prod_kernel_cuda);
 REGISTER_DISPATCH(mean_stub, &mean_kernel_cuda);
index ef3ec16..2144a88 100644 (file)
 - func: var(Tensor self, bool unbiased=true) -> Tensor
   variants: function, method
 
-- func: var(Tensor self, int64_t dim, bool unbiased=true, bool keepdim=false) -> Tensor
+- func: var(Tensor self, IntList[1] dim, bool unbiased=true, bool keepdim=false) -> Tensor
   variants: function, method
 
-- func: var_out(Tensor result, Tensor self, int64_t dim, bool unbiased=true, bool keepdim=false) -> Tensor
+- func: var_out(Tensor result, Tensor self, IntList[1] dim, bool unbiased=true, bool keepdim=false) -> Tensor
 
 - func: view_as(Tensor self, Tensor other) -> Tensor
   variants: method
index ce88932..c1e388d 100644 (file)
@@ -2051,6 +2051,14 @@ class _TestTorchMixin(object):
                 lambda n, d: n.std(d, ddof=1 if unbiased else 0),
                 use_integral=False)
 
+    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
+    def test_var_dim(self):
+        for unbiased in [False, True]:
+            self._test_dim_ops(
+                lambda t, d: t.var(d, unbiased=unbiased),
+                lambda n, d: n.var(d, ddof=1 if unbiased else 0),
+                use_integral=False)
+
     def test_sum_out(self):
         x = torch.rand(100, 100)
         res1 = torch.sum(x, 1)
index edc88c6..e652b2e 100644 (file)
 - name: var(Tensor self, bool unbiased)
   self: var_backward(grad, self, unbiased)
 
-- name: var(Tensor self, int64_t dim, bool unbiased, bool keepdim)
+- name: var(Tensor self, IntList dim, bool unbiased, bool keepdim)
   self: var_backward(grad, self, dim, unbiased, keepdim)
 
 - name: view(Tensor self, IntList size)
index 4a0fa77..1265465 100644 (file)
@@ -903,7 +903,6 @@ class ShapePropagator {
             "aten::max_values(Tensor self, int dim, bool keepdim) -> Tensor",
             "aten::min_values(Tensor self, int dim, bool keepdim) -> Tensor",
             "aten::norm(Tensor self, Scalar? p, int dim, bool keepdim) -> Tensor",
-            "aten::var(Tensor self, int dim, bool unbiased, bool keepdim) -> Tensor",
             "aten::logsumexp(Tensor self, int dim, bool keepdim) -> Tensor",
             "aten::all(Tensor self, int dim, bool keepdim) -> Tensor",
             "aten::any(Tensor self, int dim, bool keepdim) -> Tensor",
@@ -957,6 +956,7 @@ class ShapePropagator {
         {
             "aten::mean(Tensor self, int[] dim, bool keepdim) -> Tensor",
             "aten::std(Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor",
+            "aten::var(Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor",
         },
         [](Node* node) -> type_vec_t {
           if (auto dim = node->get<std::vector<int64_t>>(attr::dim)) {