Make `mean` function work across multiple dimensions. (#14252)
authorBrennan Vincent <btv@fb.com>
Wed, 28 Nov 2018 14:50:49 +0000 (06:50 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 28 Nov 2018 14:53:09 +0000 (06:53 -0800)
Summary:
Multi-dimensional `sum` is already implemented, and it's trivial to implement `mean` in terms of `sum`, so just do it.

Bonus: Fix incomplete language in the `torch.sum` documentation which doesn't take into account multiple dimensions when describing `unsqueeze` (at the same time as introducing similar language in `torch.mean`).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14252

Differential Revision: D13161157

Pulled By: umanwizard

fbshipit-source-id: c45da692ba83c0ec80815200c5543302128da75c

aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/core/Type.h
aten/src/ATen/native/ReduceOps.cpp
aten/src/ATen/native/native_functions.yaml
test/test_torch.py
tools/autograd/derivatives.yaml
tools/autograd/templates/Functions.cpp
torch/_torch_docs.py
torch/csrc/jit/passes/shape_analysis.cpp

index 16c2699..e7311a7 100644 (file)
@@ -385,9 +385,9 @@ public:
   Tensor max_values(int64_t dim, bool keepdim=false) const;
   Tensor mean(ScalarType dtype) const;
   Tensor mean() const;
-  Tensor mean(int64_t dim, bool keepdim, ScalarType dtype) const;
-  Tensor mean(int64_t dim, bool keepdim=false) const;
-  Tensor mean(int64_t dim, ScalarType dtype) const;
+  Tensor mean(IntList dim, bool keepdim, ScalarType dtype) const;
+  Tensor mean(IntList dim, bool keepdim=false) const;
+  Tensor mean(IntList dim, ScalarType dtype) const;
   std::tuple<Tensor,Tensor> median(int64_t dim, bool keepdim=false) const;
   std::tuple<Tensor,Tensor> min(int64_t dim, bool keepdim=false) const;
   Tensor min_values(int64_t dim, bool keepdim=false) const;
index 270b9c4..e0aa7ac 100644 (file)
@@ -405,13 +405,13 @@ inline Tensor Tensor::mean(ScalarType dtype) const {
 inline Tensor Tensor::mean() const {
     return type().mean(*this);
 }
-inline Tensor Tensor::mean(int64_t dim, bool keepdim, ScalarType dtype) const {
+inline Tensor Tensor::mean(IntList dim, bool keepdim, ScalarType dtype) const {
     return type().mean(*this, dim, keepdim, dtype);
 }
-inline Tensor Tensor::mean(int64_t dim, bool keepdim) const {
+inline Tensor Tensor::mean(IntList dim, bool keepdim) const {
     return type().mean(*this, dim, keepdim);
 }
-inline Tensor Tensor::mean(int64_t dim, ScalarType dtype) const {
+inline Tensor Tensor::mean(IntList dim, ScalarType dtype) const {
     return type().mean(*this, dim, dtype);
 }
 inline std::tuple<Tensor,Tensor> Tensor::median(int64_t dim, bool keepdim) const {
index a937b05..0347bfc 100644 (file)
@@ -293,9 +293,9 @@ struct CAFFE2_API Type {
   virtual Tensor max_values(const Tensor & self, int64_t dim, bool keepdim) const = 0;
   virtual Tensor mean(const Tensor & self, ScalarType dtype) const = 0;
   virtual Tensor mean(const Tensor & self) const = 0;
-  virtual Tensor mean(const Tensor & self, int64_t dim, bool keepdim, ScalarType dtype) const = 0;
-  virtual Tensor mean(const Tensor & self, int64_t dim, bool keepdim) const = 0;
-  virtual Tensor mean(const Tensor & self, int64_t dim, ScalarType dtype) const = 0;
+  virtual Tensor mean(const Tensor & self, IntList dim, bool keepdim, ScalarType dtype) const = 0;
+  virtual Tensor mean(const Tensor & self, IntList dim, bool keepdim) const = 0;
+  virtual Tensor mean(const Tensor & self, IntList dim, ScalarType dtype) const = 0;
   virtual std::tuple<Tensor,Tensor> median(const Tensor & self, int64_t dim, bool keepdim) const = 0;
   virtual std::tuple<Tensor,Tensor> min(const Tensor & self, int64_t dim, bool keepdim) const = 0;
   virtual Tensor min_values(const Tensor & self, int64_t dim, bool keepdim) const = 0;
index 9a5be73..5e297c8 100644 (file)
@@ -101,6 +101,14 @@ static std::unique_ptr<TensorIterator> make_reduction(
   return TensorIterator::reduce_op(viewed_result, self);
 }
 
+static inline int64_t n_dim_size(const Tensor& self, IntList dim) {
+  int64_t numel = 1;
+  for (auto d : dim) {
+    numel *= self.size(d);
+  }
+  return numel;
+}
+
 static inline Tensor cumsum(const Tensor& self, int64_t dim, optional<ScalarType> dtype) {
   return at::_th_cumsum(integer_upcast(self, dtype), dim);
 }
@@ -260,7 +268,7 @@ Tensor prod(const Tensor &self) {
 
 // DIM REDUCE #################################################################
 
-static inline Tensor &mean_out(Tensor &result, const Tensor &self, int64_t dim,
+static inline Tensor &mean_out(Tensor &result, const Tensor &self, IntList dim,
                  bool keepdim, optional<ScalarType> dtype) {
   ScalarType scalarType = result.type().scalarType();
   AT_CHECK(
@@ -271,7 +279,7 @@ static inline Tensor &mean_out(Tensor &result, const Tensor &self, int64_t dim,
   at::native::sum_out(
       result, self.toType(result.type().scalarType()), dim, keepdim);
   if (result.numel() > 0 && self.ndimension() > 0) {
-    int64_t numel = self.size(dim);
+    int64_t numel = n_dim_size(self, dim);
     if (numel > 0) {
       result.div_(numel);
     } else {
@@ -282,15 +290,15 @@ static inline Tensor &mean_out(Tensor &result, const Tensor &self, int64_t dim,
   return result;
 }
 
-Tensor& mean_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
+Tensor& mean_out(Tensor& result, const Tensor& self, IntList dim, bool keepdim, ScalarType dtype) {
   return at::native::mean_out(
       result, self, dim, keepdim, c10::optional<ScalarType>(dtype));
 }
-Tensor& mean_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim) {
+Tensor& mean_out(Tensor& result, const Tensor& self, IntList dim, bool keepdim) {
   return at::native::mean_out(result, self, dim, keepdim, c10::nullopt);
 }
 
-Tensor& mean_out(Tensor& result, const Tensor& self, int64_t dim, ScalarType dtype) {
+Tensor& mean_out(Tensor& result, const Tensor& self, IntList dim, ScalarType dtype) {
   return at::native::mean_out(result, self, dim, false, dtype);
 }
 
@@ -320,7 +328,7 @@ Tensor& prod_out(Tensor& result, const Tensor& self, int64_t dim, ScalarType dty
   return at::native::prod_out(result, self, dim, false, dtype);
 }
 
-static inline Tensor mean(const Tensor &self, int64_t dim, bool keepdim, optional<ScalarType> dtype) {
+static inline Tensor mean(const Tensor &self, IntList dim, bool keepdim, optional<ScalarType> dtype) {
   ScalarType scalarType = self.type().scalarType();
   AT_CHECK(
       at::isFloatingType(scalarType),
@@ -329,7 +337,7 @@ static inline Tensor mean(const Tensor &self, int64_t dim, bool keepdim, optiona
       " instead.");
   Tensor result = at::native::sum(self, dim, keepdim);
   if (result.numel() > 0 && self.ndimension() > 0) {
-    int64_t numel = self.size(dim);
+    int64_t numel = n_dim_size(self, dim);
     if (numel > 0) {
       result.div_(numel);
     } else {
@@ -340,15 +348,15 @@ static inline Tensor mean(const Tensor &self, int64_t dim, bool keepdim, optiona
   return result;
 }
 
-Tensor mean(const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
+Tensor mean(const Tensor& self, IntList dim, bool keepdim, ScalarType dtype) {
   return at::native::mean(self, dim, keepdim, c10::optional<ScalarType>(dtype));
 }
 
-Tensor mean(const Tensor& self, int64_t dim, bool keepdim) {
+Tensor mean(const Tensor& self, IntList dim, bool keepdim) {
   return at::native::mean(self, dim, keepdim, c10::nullopt);
 }
 
-Tensor mean(const Tensor& self, int64_t dim, ScalarType dtype) {
+Tensor mean(const Tensor& self, IntList dim, ScalarType dtype) {
   return at::native::mean(self, dim, false, dtype);
 }
 
index 7e71f47..e6de5a5 100644 (file)
 - func: mean(Tensor self) -> Tensor
   variants: function, method
 
-- func: mean(Tensor self, int64_t dim, bool keepdim, *, ScalarType dtype) -> Tensor
+- func: mean(Tensor self, IntList[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
   variants: function, method
 
-- func: mean(Tensor self, int64_t dim, bool keepdim=False) -> Tensor
+- func: mean(Tensor self, IntList[1] dim, bool keepdim=False) -> Tensor
   variants: function, method
 
-- func: mean(Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor
+- func: mean(Tensor self, IntList[1] dim, *, ScalarType dtype) -> Tensor
   variants: function, method
 
-- func: mean_out(Tensor result, Tensor self, int64_t dim, bool keepdim, *, ScalarType dtype) -> Tensor
+- func: mean_out(Tensor result, Tensor self, IntList[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
 
-- func: mean_out(Tensor result, Tensor self, int64_t dim, bool keepdim=False) -> Tensor
+- func: mean_out(Tensor result, Tensor self, IntList[1] dim, bool keepdim=False) -> Tensor
 
-- func: mean_out(Tensor result, Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor
+- func: mean_out(Tensor result, Tensor self, IntList[1] dim, *, ScalarType dtype) -> Tensor
 
 - func: median(Tensor self, int64_t dim, bool keepdim=false) -> (Tensor, Tensor)
   variants: function, method
index 69bafcb..33e0fde 100644 (file)
@@ -927,6 +927,7 @@ class _TestTorchMixin(object):
             self.assertEqual(x.mean().item(), 16.0 / 6)
             self.assertEqual(x.mean(0), torch.FloatTensor([2.0, 2.5, 7.0 / 2]))
             self.assertEqual(x.mean(1), torch.FloatTensor([2.0 / 3, 14.0 / 3]))
+            self.assertEqual(x.mean(), x.mean((0, 1)))
 
         for dtype in types:
             x = cast(torch.tensor(example, dtype=dtype))
@@ -1923,6 +1924,13 @@ class _TestTorchMixin(object):
         check_sum_all(torch.randn(200000))
         check_sum_all(torch.randn(2000, 2)[:, 0])
 
+    def _assert_matches_numpy(self, t, n):
+        self.assertEqual(n.shape, t.shape)
+        if t.dtype == torch.float:
+            self.assertTrue(np.allclose(n, t.numpy(), rtol=1e-03, atol=1e-05))
+        else:
+            self.assertTrue(np.allclose(n, t.numpy()))
+
     @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
     def test_sum_dim(self):
         def check_sum_dim(tensors_dict, dim):
@@ -1932,11 +1940,7 @@ class _TestTorchMixin(object):
                 for tensor in tensors:
                     expected = tensor.numpy().sum(dim)
                     actual = tensor.sum(dim)
-                    self.assertEqual(expected.shape, actual.shape)
-                    if actual.dtype == torch.float:
-                        self.assertTrue(np.allclose(expected, actual.numpy(), rtol=1e-03, atol=1e-05))
-                    else:
-                        self.assertTrue(np.allclose(expected, actual.numpy()))
+                    self._assert_matches_numpy(actual, expected)
 
         float_types = [torch.double, torch.float]
         int_types = [torch.int64, torch.int32, torch.int16]
@@ -1952,6 +1956,28 @@ class _TestTorchMixin(object):
         check_sum_dim(self._make_tensors((50, 50, 50)), (1, 2))
         check_sum_dim(self._make_tensors((50, 50, 50)), (1, -1))
 
+    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
+    def test_mean_dim(self):
+        def check_mean_dim(tensors_dict, dim):
+            for category, tensors in tensors_dict.items():
+                if category == "slice":
+                    dim = 0
+                for tensor in tensors:
+                    expected = tensor.numpy().mean(dim)
+                    actual = tensor.mean(dim)
+                    self._assert_matches_numpy(actual, expected)
+
+        check_mean_dim(self._make_tensors((5, 400000), use_integral=False), 1)
+        check_mean_dim(self._make_tensors((3, 5, 7), use_integral=False), 0)
+        check_mean_dim(self._make_tensors((3, 5, 7), use_integral=False), 1)
+        check_mean_dim(self._make_tensors((3, 5, 7), use_integral=False), 2)
+        check_mean_dim(self._make_tensors((100000, ), use_integral=False), -1)
+        check_mean_dim(self._make_tensors((50, 50, 50), use_integral=False), 0)
+        check_mean_dim(self._make_tensors((50, 50, 50), use_integral=False), 1)
+        check_mean_dim(self._make_tensors((50, 50, 50), use_integral=False), 2)
+        check_mean_dim(self._make_tensors((50, 50, 50), use_integral=False), (1, 2))
+        check_mean_dim(self._make_tensors((50, 50, 50), use_integral=False), (1, -1))
+
     def test_sum_out(self):
         x = torch.rand(100, 100)
         res1 = torch.sum(x, 1)
index 30432a2..ad83012 100644 (file)
   self: grad.clone().masked_fill_(self <= other, 0)
   other: grad.clone().masked_fill_(self > other, 0)
 
-- name: mean(Tensor self, int64_t dim, bool keepdim)
+- name: mean(Tensor self, IntList dim, bool keepdim)
   self: sum_backward(grad, self.sizes(), dim, keepdim) / _safe_size(self.sizes(), dim)
 
 - name: mean(Tensor self)
index cb1bc3b..137f31d 100644 (file)
@@ -76,9 +76,16 @@ Tensor maybe_multiply(const Tensor & t, const Scalar & s) {
   }
 }
 
-int64_t _safe_size(IntList sizes, int64_t dim) {
-  dim = at::maybe_wrap_dim(dim, sizes.size());
-  return sizes.size() != 0 ? sizes[dim] : 1;
+int64_t _safe_size(IntList sizes, IntList dim) {
+  int64_t size = 1;
+  if (sizes.size() == 0) {
+    return 1;
+  }
+  for (auto d : dim) {
+    d = at::maybe_wrap_dim(d, sizes.size());
+    size *= sizes[d];
+  }
+  return size;
 }
 
 Tensor norm_backward(const Tensor & grad, const Tensor & self, const Scalar & p_, const Tensor & norm) {
index 270fba9..9d0371f 100644 (file)
@@ -2647,12 +2647,13 @@ Example::
 .. function:: mean(input, dim, keepdim=False, out=None) -> Tensor
 
 Returns the mean value of each row of the :attr:`input` tensor in the given
-dimension :attr:`dim`.
+dimension :attr:`dim`. If :attr:`dim` is a list of dimensions,
+reduce over all of them.
 
 If :attr:`keepdim` is ``True``, the output tensor is of the same size
-as :attr:`input` except in the dimension :attr:`dim` where it is of size 1.
+as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1.
 Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the
-output tensor having 1 fewer dimension.
+output tensor having 1 (or ``len(dim)``) fewer dimension(s).
 
 Args:
     input (Tensor): the input tensor
@@ -4491,9 +4492,9 @@ dimension :attr:`dim`. If :attr:`dim` is a list of dimensions,
 reduce over all of them.
 
 If :attr:`keepdim` is ``True``, the output tensor is of the same size
-as :attr:`input` except in the dimension :attr:`dim` where it is of size 1.
-Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in
-the output tensor having 1 fewer dimension than :attr:`input`.
+as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1.
+Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the
+output tensor having 1 (or ``len(dim)``) fewer dimension(s).
 
 Args:
     input (Tensor): the input tensor
index b1fa7f6..80b0641 100644 (file)
@@ -899,7 +899,6 @@ class ShapePropagator {
             "aten::argmin(Tensor self, int dim, bool keepdim) -> Tensor",
             "aten::max_values(Tensor self, int dim, bool keepdim) -> Tensor",
             "aten::min_values(Tensor self, int dim, bool keepdim) -> Tensor",
-            "aten::mean(Tensor self, int dim, bool keepdim) -> Tensor",
             "aten::norm(Tensor self, Scalar p, int dim, bool keepdim) -> Tensor",
             "aten::std(Tensor self, int dim, bool unbiased, bool keepdim) -> Tensor",
             "aten::var(Tensor self, int dim, bool unbiased, bool keepdim) -> Tensor",
@@ -945,6 +944,27 @@ class ShapePropagator {
               node, /*num_reduce_dim=*/1, /*integer_upcast=*/true);
         }};
 
+
+    // Requirements:
+    //   dims           : preserved if keepdim == false, dim->size() smaller otherwise
+    //   scalar type    : preserved
+    //   device         : preserved
+    //   tensor inputs  : 1
+    //   tensor outputs : 1
+    // Additionally:
+    //   - First input should be the only tensor input
+    //   - has a bool keepdim argument
+    static const register_formula_for multidim_reduce_ops {
+        {
+            "aten::mean(Tensor self, int[] dim, bool keepdim) -> Tensor",
+        },
+        [](Node * node) -> type_vec_t {
+          if (auto dim = node->get<std::vector<int64_t>>(attr::dim)) {
+            return multidim_reduce_with_postprocess(node, /*num_reduce_dim=*/dim->size(), /*integer_upcast=*/false);
+          }
+          return {};
+        }};
+
     // Requirements:
     //   dims           : preserved if keepdim == false, 1 smaller otherwise
     //   scalar type    : preserved if floating point, otherwise long/int64