Implement `std` for multiple dimensions on CPU devices. (#14535)
authorBrennan Vincent <btv@fb.com>
Sat, 8 Dec 2018 04:13:31 +0000 (20:13 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 8 Dec 2018 04:16:04 +0000 (20:16 -0800)
Summary:
Tested on a tensor with 1 billion elements and 3 dimensions on a powerful, highly
multi-core Linux machine.

parallelized: All operations (e.g., `t.std(1)`) that could be done in the old code are now several times faster. All
new operations (e.g., `t.std((0,2))` are significantly faster than the NumPy equivalents.
`t.std((0, 1, 2))`, a new operation, is logically equivalent to the
old `t.std()`, but faster.

serial: The above comment about old operationos now being faster still
holds, but `t.std((t1, ..., tn))` is now a few
times slower than `t.std()`. If this turns out to be important, we can
special-case that to use the old algorithm.

The approach is to create a new method, `TensorIterator::foreach_reduced_elt`,
valid for `TensorIterator`s that represent a dimension reduction. This
method calls a supplied function for each element in the output,
supplying it with the input elements that correspond to that output.

Given that primitive, we can implement reductions like the following pseudocode:

If there is more than one output element:
```
PARALLEL FOR EACH element IN output:
    accumulator = identity
    SERIAL FOR EACH data_point IN element.corresponding_input:
        accumulator.update(data_point)
    element = accumulator.to_output()
```

If there is only one output element, we still want to parallelize, so we
do so along the *input* instead:

```
accumulators[n_threads]
PARALLEL FOR EACH input_chunk IN input.chunks():
    accumulators[thread_num()] = identity
    SERIAL FOR EACH data_point IN input_chunk:
        accumulators[thread_num()].update_with_data(data_point)
accumulator = identity
SERIAL FOR EACH acc in accumulators:
    accumulator.update_with_other_accumulator(acc)
output_element = accumulator.to_output()
```

Note that accumulators and data points do not have to be the same type
in general, since it might be necessary to track arbitrary amounts of
data at intermediate stages.

For example, for `std`, we use a parallel version of Welford's
algorithm, which requies us to track the mean, second moment, and number
of elements, so the accumulator type for `std` contains three pieces of
data.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14535

Differential Revision: D13283887

Pulled By: umanwizard

fbshipit-source-id: 8586b7bf00bf9f663c55d6f8323301e257f5ec3f

15 files changed:
aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/core/Type.h
aten/src/ATen/native/ReduceOps.cpp
aten/src/ATen/native/ReduceOps.h
aten/src/ATen/native/TensorIterator.cpp
aten/src/ATen/native/TensorIterator.h
aten/src/ATen/native/TensorIteratorReduce.cpp
aten/src/ATen/native/cpu/Reduce.h
aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
aten/src/ATen/native/native_functions.yaml
test/test_torch.py
tools/autograd/derivatives.yaml
tools/autograd/templates/Functions.cpp
torch/csrc/jit/passes/shape_analysis.cpp

index 10b6eaa..dd2e9a8 100644 (file)
@@ -455,7 +455,7 @@ public:
   Tensor sqrt() const;
   Tensor & sqrt_();
   Tensor std(bool unbiased=true) const;
-  Tensor std(int64_t dim, bool unbiased=true, bool keepdim=false) const;
+  Tensor std(IntList dim, bool unbiased=true, bool keepdim=false) const;
   Tensor prod(ScalarType dtype) const;
   Tensor prod() const;
   Tensor prod(int64_t dim, bool keepdim, ScalarType dtype) const;
index 440950c..f3eb453 100644 (file)
@@ -598,7 +598,7 @@ inline Tensor & Tensor::sqrt_() {
 inline Tensor Tensor::std(bool unbiased) const {
     return type().std(*this, unbiased);
 }
-inline Tensor Tensor::std(int64_t dim, bool unbiased, bool keepdim) const {
+inline Tensor Tensor::std(IntList dim, bool unbiased, bool keepdim) const {
     return type().std(*this, dim, unbiased, keepdim);
 }
 inline Tensor Tensor::prod(ScalarType dtype) const {
index e9a5411..003fade 100644 (file)
@@ -367,7 +367,7 @@ struct CAFFE2_API Type {
   virtual Tensor sqrt(const Tensor & self) const = 0;
   virtual Tensor & sqrt_(Tensor & self) const = 0;
   virtual Tensor std(const Tensor & self, bool unbiased) const = 0;
-  virtual Tensor std(const Tensor & self, int64_t dim, bool unbiased, bool keepdim) const = 0;
+  virtual Tensor std(const Tensor & self, IntList dim, bool unbiased, bool keepdim) const = 0;
   virtual Tensor prod(const Tensor & self, ScalarType dtype) const = 0;
   virtual Tensor prod(const Tensor & self) const = 0;
   virtual Tensor prod(const Tensor & self, int64_t dim, bool keepdim, ScalarType dtype) const = 0;
index e23feb6..c0eb9a3 100644 (file)
@@ -20,6 +20,7 @@ namespace at {
 namespace native {
 
 DEFINE_DISPATCH(sum_stub);
+DEFINE_DISPATCH(std_stub);
 DEFINE_DISPATCH(prod_stub);
 DEFINE_DISPATCH(norm_kernel);
 
@@ -541,21 +542,29 @@ Tensor std(const Tensor& self, bool unbiased) {
   return trivial_return.has_value() ? trivial_return.value() : at::_th_std(self, unbiased);
 }
 
-Tensor std(const Tensor& self, int64_t dim, bool unbiased, bool keepdim) {
+Tensor std(const Tensor& self, IntList dim, bool unbiased, bool keepdim) {
   Tensor result = at::empty({0}, self.options());
   return at::native::std_out(result, self, dim, unbiased, keepdim);
 }
 
-Tensor &std_out(Tensor &result, const Tensor &self, int64_t dim, bool unbiased, bool keepdim) {
+Tensor &std_out(Tensor &result, const Tensor &self, IntList dim, bool unbiased, bool keepdim) {
   AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
            "std only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
   AT_CHECK(at::isFloatingType(self.type().scalarType()), "std only supports floating-point dtypes");
-  dim = maybe_wrap_dim(dim, self.dim());
-  if (_dimreduce_return_trivial(result, self, std::numeric_limits<double>::quiet_NaN(), dim, keepdim)) {
-    return result;
-  } else {
-    return at::_th_std_out(result, self, dim, unbiased, keepdim);
+  if (self.type().backend() != Backend::CPU) {
+    // TODO(btv): implement multi-dim `std` and `var` on CUDA.
+    AT_CHECK(dim.size() == 1, "`std` across arbitrarily many dimensions is not yet supported for CUDA.")
+    int64_t one_dim = maybe_wrap_dim(dim[0], self.dim());
+    if (_dimreduce_return_trivial(result, self, std::numeric_limits<double>::quiet_NaN(), one_dim, keepdim)) {
+      return result;
+    } else {
+      return at::_th_std_out(result, self, one_dim, unbiased, keepdim);
+    }
   }
+  ScalarType dtype = get_dtype(result, self, {}, true);
+  auto iter = make_reduction("std", result, self, dim, keepdim, dtype);
+  std_stub(iter->device_type(), *iter, unbiased);
+  return result;
 }
 
 }} // namespace at::native
index c86494f..34f4752 100644 (file)
@@ -15,6 +15,10 @@ using reduce_fn = void(*)(TensorIterator &);
 DECLARE_DISPATCH(reduce_fn, sum_stub);
 DECLARE_DISPATCH(reduce_fn, prod_stub);
 
+using reduce_std_function =
+  void (*)(TensorIterator&, bool unbiased);
+DECLARE_DISPATCH(reduce_std_function, std_stub);
+
 using reduce_norm_fn =
     void (*)(Tensor&, const Tensor&, Scalar, c10::optional<int64_t>);
 DECLARE_DISPATCH(reduce_norm_fn, norm_kernel);
index 27a758f..ccf0bf2 100644 (file)
@@ -6,19 +6,6 @@
 
 namespace at {
 
-struct DimCounter {
-  DimCounter(IntList shape, Range range);
-
-  void increment(const std::array<int64_t, 2>& step);
-  bool is_done() const;
-  std::array<int64_t, 2> max_step() const;
-
-  IntList shape;
-  Range range;
-  DimVector values;
-  int64_t offset;
-};
-
 using DimMask = TensorIterator::DimMask;
 using PtrVector = TensorIterator::PtrVector;
 using loop_t = TensorIterator::loop_t;
@@ -405,7 +392,7 @@ void TensorIterator::serial_for_each(const loop2d_t& loop, Range range) const {
     auto counter = DimCounter(shape_, range);
     while (!counter.is_done()) {
       auto ptrs = get_data_ptrs(base_ptrs, counter.values);
-      auto step = counter.max_step();
+      auto step = counter.max_2d_step();
       loop(ntensors(), ptrs.data(), strides.data(), step[0], step[1]);
       counter.increment(step);
     }
@@ -463,6 +450,16 @@ void TensorIterator::narrow(int dim, int64_t start, int64_t size) {
   }
 }
 
+void TensorIterator::select_all_keeping_dim(int start_dim, IntList indices) {
+  AT_ASSERT(start_dim <= ndim());
+  for (int i = start_dim; i < ndim(); ++i) {
+    for (auto& op : operands_) {
+      op.data = ((char*)op.data) + op.stride_bytes[i] * indices[i - start_dim];
+    }
+    shape_[i] = 1;
+  }
+}
+
 std::unique_ptr<TensorIterator> TensorIterator::binary_op(Tensor& out, const Tensor& a, const Tensor& b) {
   auto builder = TensorIterator::Builder();
   if (a.device().is_cuda() && b.device().is_cuda()) {
@@ -721,7 +718,7 @@ void DimCounter::increment(const std::array<int64_t, 2>& step) {
   AT_ASSERT(overflow == 0 || overflow == 1);
 }
 
-std::array<int64_t, 2> DimCounter::max_step() const {
+std::array<int64_t, 2> DimCounter::max_2d_step() const {
   int64_t step0 = std::min(shape[0] - values[0], range.end - offset);
   int64_t step1 = 1;
   if (step0 == shape[0] && shape.size() >= 1) {
index 4c36ed9..848baee 100644 (file)
 
 namespace at {
 
+struct DimCounter {
+  DimCounter(IntList shape, Range range);
+
+  void increment(const std::array<int64_t, 2>& step);
+  bool is_done() const;
+  std::array<int64_t, 2> max_2d_step() const;
+
+  IntList shape;
+  Range range;
+  DimVector values;
+  int64_t offset;
+};
 struct CAFFE2_API OperandInfo {
   OperandInfo() {}
   OperandInfo(const Tensor& t, const Type* type=nullptr)
@@ -109,6 +121,10 @@ struct CAFFE2_API TensorIterator {
   using loop_t = std::function<void(int ntensors, char** data, const int64_t* strides, int64_t size)>;
   using loop2d_t = std::function<void(int ntensors, char** data, const int64_t* strides, int64_t size0, int64_t size1)>;
 
+  using loop_subiter_t = std::function<void(TensorIterator& subiter)>;
+
+  void foreach_reduced_elt(const loop_subiter_t& loop);
+
   static std::unique_ptr<TensorIterator> binary_op(Tensor& out, const Tensor& a, const Tensor& b);
   static std::unique_ptr<TensorIterator> reduce_op(Tensor& out, const Tensor& a);
 
@@ -155,6 +171,8 @@ struct CAFFE2_API TensorIterator {
   void remove_dimension(int dim);
   /// Shrinks an iterated dimension
   void narrow(int dim, int64_t start, int64_t size);
+  /// Narrows every dim after and including `start_dim` to size one.
+  void select_all_keeping_dim(int start_dim, IntList starts);
   /// Replaces the data pointer and strides for the operand at index `arg`
   void replace_operand(int arg, void* data, IntList stride);
 
index 30af03a..41c0aa0 100644 (file)
@@ -115,4 +115,47 @@ static void parallel_dim_reduction(TensorIterator& iter, const loop2d_t& loop) {
   });
 }
 
+void TensorIterator::foreach_reduced_elt(const loop_subiter_t &loop) {
+  AT_ASSERT(ntensors() == 2 && num_outputs_ == 1);
+
+  auto shape = this->shape();
+  if (tensor(0).numel() == 0) {
+    return;
+  }
+  if (tensor(0).numel() == 1) {
+    loop(*this);
+  }
+  else if (numel() < at::internal::GRAIN_SIZE || at::get_max_threads() == 1 || at::in_parallel_region()) {
+    auto reduce_dims = num_reduce_dims();
+
+    auto non_reduced_shape = shape.slice(reduce_dims, shape.size() - reduce_dims);
+
+    int64_t non_reduced_numel = 1;
+    for (int i = 0; i < non_reduced_shape.size(); ++i) {
+      non_reduced_numel *= non_reduced_shape[i];
+    }
+    DimCounter dims {non_reduced_shape, {0, non_reduced_numel}};
+    while (!dims.is_done()) {
+      TensorIterator reduced = *this;
+      reduced.select_all_keeping_dim(reduce_dims, dims.values);
+      loop(reduced);
+      dims.increment({1, 1});
+    }
+  }
+  else {
+    int dim = find_split_dim(*this);
+    int64_t cols = shape[dim];
+    at::parallel_for(0, cols, 1, [&](int64_t begin, int64_t end) {
+      if (begin == end) {
+        return;
+      }
+
+      auto sub_iter = *this;
+
+      sub_iter.narrow(dim, begin, end - begin);
+      sub_iter.foreach_reduced_elt(loop);
+    });
+  }
+}
+
 }  // namespace at
index bc48b6f..2b83585 100644 (file)
@@ -2,6 +2,7 @@
 
 #include <ATen/native/cpu/Loops.h>
 #include <ATen/Parallel.h>
+#include <c10/util/TypeList.h>
 
 #include <sstream>
 
@@ -24,14 +25,101 @@ static inline bool is_outer_reduction(const int64_t* strides) {
          strides[3] == sizeof(typename traits::arg2_t);
 }
 
+template <typename T, typename... Args>
+struct all_same : c10::guts::conjunction<
+  std::is_same<T, Args>...
+> {};
+
+// data_t is the input/output data type.
+// acc_t is a type that contains all the necessary data
+// to continue reducing.
+//
+// Then:
+// reduce: (acc_t, data_t) -> acc_t adds one data point to the accumulated value.
+// combine: (acc_t, acc_t) -> acc_t combines two accumulated values into one.
+// project: acc_t -> data_t finishes the reduction, getting the required output.
+//
+// Additionally, acc_t must be default-constructible:
+// acc_t {} is an identity for combine,
+// and project(acc_t {}) is the value of the operation on zero elements.
+//
+// The point of `combine` is to support parallelization -
+// the idea is to one sequence of `reduce` calls per thread of execution,
+// and then to combine them at the end with `combine`.
+//
+// If there is more than one output element,
+// our parallelization strategy is to use one thread for each of them,
+// which means that `combine` will never be called.
+//
+// If, on the other hand, there is only one, then we split the input into
+// into several pieces, reduce each separately, and then combine them.
+
+template <typename rf_t,
+          typename cf_t,
+          typename pf_t>
+void binary_kernel_reduce(TensorIterator& iter, rf_t const &reduce, cf_t const &combine, pf_t const &project) {
+  using r_traits = binary_function_traits<rf_t>;
+  using c_traits = binary_function_traits<cf_t>;
+  using p_traits = unary_function_traits<pf_t>;
+  using acc_t = typename p_traits::arg1_t;
+  using data_t = typename p_traits::result_type;
+  static_assert(
+    all_same<
+      acc_t,
+      typename r_traits::arg1_t,
+      typename r_traits::result_type,
+      typename c_traits::arg1_t,
+      typename c_traits::arg2_t,
+      typename c_traits::result_type>::value,
+    "all accumulate types must match");
+  static_assert(
+    std::is_same<data_t, typename r_traits::arg2_t>::value,
+    "all data types must match");
+  static_assert(
+    std::is_default_constructible<acc_t>::value,
+    "the accumulate type must be default-constructible"
+  );
+  iter.foreach_reduced_elt([&](TensorIterator &sub_iter) {
+    auto numel = sub_iter.numel();
+    bool serial = numel < at::internal::GRAIN_SIZE || at::get_max_threads() == 1 || at::in_parallel_region();
+    int max_threads = serial ? 1 : at::get_max_threads();
+    AT_ASSERT(max_threads > 0);
+    std::vector<optional<acc_t>> buffer{(unsigned)max_threads, optional<acc_t> {}};
+    at::parallel_for(0, numel, serial ? (1 + numel) : internal::GRAIN_SIZE,
+    [&](int64_t begin, int64_t end) {
+      auto &acc = buffer[at::get_thread_num()];
+      sub_iter.serial_for_each([&acc, &reduce](int ntensors, char** data, const int64_t* strides, int64_t size) {
+        AT_ASSERT(ntensors == 2);
+        char *in = data[1];
+        int64_t stride = strides[1];
+        if (!acc && size > 0) {
+          acc = acc_t {};
+        }
+        for (int64_t i = 0; i < size; ++i) {
+          acc = reduce(*acc, *(data_t*)in);
+          in += stride;
+        }
+      }, {begin, end});
+    });
+    acc_t acc;
+    for (int i = 0; i < max_threads; ++i) {
+      if (buffer[i]) {
+        acc = combine(acc, *buffer[i]);
+      }
+    }
+    char *out = (char *)sub_iter.data_ptr(0);
+    *(data_t*)out = project(acc);
+  });
+}
+
 template <typename func_t, typename vec_func_t>
 void binary_kernel_reduce_vec(TensorIterator& iter, func_t op, vec_func_t vop, double ident=0) {
   using traits = binary_function_traits<func_t>;
   static_assert(
-    std::is_same<typename traits::result_type, typename traits::arg1_t>::value,
-    "all types must match");
-  static_assert(
-    std::is_same<typename traits::result_type, typename traits::arg2_t>::value,
+    all_same<
+      typename traits::result_type,
+      typename traits::arg1_t,
+      typename traits::arg2_t>::value,
     "all types must match");
 
   iter.output().fill_(ident);
index 29fb0d0..0796097 100644 (file)
@@ -22,6 +22,52 @@ static void sum_kernel_impl(TensorIterator& iter) {
   });
 }
 
+struct WelfordData {
+  double mean;
+  double m2;
+  int64_t n;
+  WelfordData() : mean(0), m2(0), n(0)  {}
+  WelfordData(double mean, double m2, int64_t n) : mean(mean), m2(m2), n(n) {}
+};
+
+static void std_kernel_impl(TensorIterator &iter, bool unbiased) {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.type(), "std", [&] {
+    binary_kernel_reduce(
+      iter,
+      [](WelfordData acc, scalar_t data) -> WelfordData {
+        double delta = data - acc.mean;
+        double new_mean = acc.mean + delta / (acc.n + 1);
+        double new_delta = data - new_mean;
+        return {
+          new_mean,
+          acc.m2 + delta * new_delta,
+          acc.n + 1
+        };
+      },
+      [](WelfordData a, WelfordData b) -> WelfordData {
+        if (a.n == 0) {
+          return b;
+        }
+        if (b.n == 0) {
+          return a;
+        }
+        double delta = b.mean - a.mean;
+        int64_t new_count = a.n + b.n;
+        double nb_over_n = (double)b.n / new_count;
+        return {
+          a.mean + delta * nb_over_n,
+          a.m2 + b.m2 + delta * delta * a.n * nb_over_n,
+          new_count
+        };
+      },
+      [unbiased](WelfordData acc) -> scalar_t {
+        int64_t divisor = unbiased ? (acc.n - 1) : acc.n;
+        return (divisor > 0) ? std::sqrt(acc.m2 / divisor) : NAN;
+      }
+    );
+  });
+}
+
 static void prod_kernel_impl(TensorIterator& iter) {
   AT_DISPATCH_ALL_TYPES(iter.type(), "prod", [&] {
     binary_kernel_reduce_vec(
@@ -204,7 +250,9 @@ static void norm_kernel_impl(
 }  // anonymous namespace
 
 REGISTER_DISPATCH(sum_stub, &sum_kernel_impl);
+REGISTER_DISPATCH(std_stub, &std_kernel_impl);
 REGISTER_DISPATCH(prod_stub, &prod_kernel_impl);
 REGISTER_DISPATCH(norm_kernel, &norm_kernel_impl);
 
 }}  // namespace at::native
+
index 33767d3..93328f2 100644 (file)
 - func: std(Tensor self, bool unbiased=true) -> Tensor
   variants: function, method
 
-- func: std(Tensor self, int64_t dim, bool unbiased=true, bool keepdim=false) -> Tensor
+- func: std(Tensor self, IntList[1] dim, bool unbiased=true, bool keepdim=false) -> Tensor
   variants: function, method
 
-- func: std_out(Tensor result, Tensor self, int64_t dim, bool unbiased=true, bool keepdim=false) -> Tensor
+- func: std_out(Tensor result, Tensor self, IntList[1] dim, bool unbiased=true, bool keepdim=false) -> Tensor
 
 # FIXME: These could be combined as optional<ScalarType> but for https://github.com/pytorch/pytorch/issues/6593.
 - func: prod(Tensor self, *, ScalarType dtype) -> Tensor
index fb6d5ed..6808203 100644 (file)
@@ -93,6 +93,9 @@ class BytesIOContext(io.BytesIO):
     def __exit__(self, *args):
         pass
 
+DIM_TEST_SCENARIOS = [
+]
+
 
 # This is intentionally prefixed by an underscore. Otherwise pytest will try to
 # run its methods as test cases.
@@ -1931,56 +1934,72 @@ class _TestTorchMixin(object):
     def _assert_matches_numpy(self, t, n):
         self.assertEqual(n.shape, t.shape)
         if t.dtype == torch.float:
-            self.assertTrue(np.allclose(n, t.numpy(), rtol=1e-03, atol=1e-05))
+            self.assertTrue(np.allclose(n, t.numpy(), rtol=1e-03, atol=1e-05,
+                            equal_nan=True))
         else:
-            self.assertTrue(np.allclose(n, t.numpy()))
+            self.assertTrue(np.allclose(n, t.numpy(), equal_nan=True))
 
-    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
-    def test_sum_dim(self):
-        def check_sum_dim(tensors_dict, dim):
+    def _test_dim_ops(self, pytorch_op, numpy_op,
+                      use_floating=True, use_integral=True):
+        def do_one(tensors_dict, dim):
             for category, tensors in tensors_dict.items():
                 if category == "slice":
                     dim = 0
                 for tensor in tensors:
-                    expected = tensor.numpy().sum(dim)
-                    actual = tensor.sum(dim)
+                    # we have no control over NumPy warnings...
+                    with warnings.catch_warnings():
+                        warnings.simplefilter("ignore")
+                        expected = numpy_op(tensor.numpy(), dim)
+                    actual = pytorch_op(tensor, dim)
                     self._assert_matches_numpy(actual, expected)
+        do_one(self._make_tensors((5, 400000), use_floating=use_floating,
+               use_integral=use_integral), 1)
+        do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
+               use_integral=use_integral), 0)
+        do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
+               use_integral=use_integral), 1)
+        do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
+               use_integral=use_integral), 2)
+        do_one(self._make_tensors((100000, ), use_floating=use_floating,
+               use_integral=use_integral), -1)
+        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
+               use_integral=use_integral), 0)
+        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
+               use_integral=use_integral), 1)
+        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
+               use_integral=use_integral), 2)
+        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
+               use_integral=use_integral), (1, 2))
+        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
+               use_integral=use_integral), (1, -1))
+        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
+               use_integral=use_integral), (0, 2))
+        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
+               use_integral=use_integral), (0, 2, 1))
 
-        float_types = [torch.double, torch.float]
-        int_types = [torch.int64, torch.int32, torch.int16]
-
-        check_sum_dim(self._make_tensors((5, 400000)), 1)
-        check_sum_dim(self._make_tensors((3, 5, 7)), 0)
-        check_sum_dim(self._make_tensors((3, 5, 7)), 1)
-        check_sum_dim(self._make_tensors((3, 5, 7)), 2)
-        check_sum_dim(self._make_tensors((100000, )), -1)
-        check_sum_dim(self._make_tensors((50, 50, 50)), 0)
-        check_sum_dim(self._make_tensors((50, 50, 50)), 1)
-        check_sum_dim(self._make_tensors((50, 50, 50)), 2)
-        check_sum_dim(self._make_tensors((50, 50, 50)), (1, 2))
-        check_sum_dim(self._make_tensors((50, 50, 50)), (1, -1))
+    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
+    def test_sum_dim(self):
+        for sizes, dim in DIM_TEST_SCENARIOS:
+            self._test_dim_ops(
+                lambda t, d: t.sum(d),
+                lambda n, d: n.sum(d))
 
     @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
     def test_mean_dim(self):
-        def check_mean_dim(tensors_dict, dim):
-            for category, tensors in tensors_dict.items():
-                if category == "slice":
-                    dim = 0
-                for tensor in tensors:
-                    expected = tensor.numpy().mean(dim)
-                    actual = tensor.mean(dim)
-                    self._assert_matches_numpy(actual, expected)
+        for sizes, dim in DIM_TEST_SCENARIOS:
+            self._test_dim_ops(
+                lambda t, d: t.mean(d),
+                lambda n, d: n.mean(d),
+                use_integral=False)
 
-        check_mean_dim(self._make_tensors((5, 400000), use_integral=False), 1)
-        check_mean_dim(self._make_tensors((3, 5, 7), use_integral=False), 0)
-        check_mean_dim(self._make_tensors((3, 5, 7), use_integral=False), 1)
-        check_mean_dim(self._make_tensors((3, 5, 7), use_integral=False), 2)
-        check_mean_dim(self._make_tensors((100000, ), use_integral=False), -1)
-        check_mean_dim(self._make_tensors((50, 50, 50), use_integral=False), 0)
-        check_mean_dim(self._make_tensors((50, 50, 50), use_integral=False), 1)
-        check_mean_dim(self._make_tensors((50, 50, 50), use_integral=False), 2)
-        check_mean_dim(self._make_tensors((50, 50, 50), use_integral=False), (1, 2))
-        check_mean_dim(self._make_tensors((50, 50, 50), use_integral=False), (1, -1))
+    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
+    def test_std_dim(self):
+        for unbiased in [False, True]:
+            for sizes, dim in DIM_TEST_SCENARIOS:
+                self._test_dim_ops(
+                    lambda t, d: t.std(d, unbiased=unbiased),
+                    lambda n, d: n.std(d, ddof=1 if unbiased else 0),
+                    use_integral=False)
 
     def test_sum_out(self):
         x = torch.rand(100, 100)
index f907a98..be5a0c9 100644 (file)
 - name: std(Tensor self, bool unbiased)
   self: var_backward(grad / (result * 2), self, unbiased)
 
-- name: std(Tensor self, int64_t dim, bool unbiased, bool keepdim)
+- name: std(Tensor self, IntList dim, bool unbiased, bool keepdim)
   self: var_backward(grad / (result * 2), self, dim, unbiased, keepdim)
 
 - name: sub(Tensor self, Tensor other, *, Scalar alpha)
index 8cc45d9..8d71ba6 100644 (file)
@@ -159,18 +159,23 @@ Tensor permute_backwards(const Tensor & grad, IntList fwd_dims) {
   return grad.permute(dims);
 }
 
+Tensor unsqueeze_multiple(const Tensor & t, IntList dim, size_t n_dims) {
+    auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims);
+    Tensor res = t;
+    for (size_t i = 0; i < n_dims; i++){
+      if (dims_to_unsqueeze[i]) {
+        res = res.unsqueeze(i);
+      }
+    }
+    return res;
+}
+
 Tensor sum_backward(const Tensor & grad, IntList sizes, IntList dims, bool keepdim) {
   if (!keepdim && sizes.size() > 0) {
     if (dims.size()==1) {
       return grad.unsqueeze(dims[0]).expand(sizes);
     } else {
-      auto dims_to_unsqueeze = at::dim_list_to_bitset(dims, sizes.size());
-      Tensor res = grad;
-      for (size_t i = 0; i < sizes.size(); i++){
-        if (dims_to_unsqueeze[i]) {
-          res = res.unsqueeze(i);
-        }
-      }
+      Tensor res = unsqueeze_multiple(grad, dims, sizes.size());
       return res.expand(sizes);
     }
   } else {
@@ -636,14 +641,14 @@ Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) {
   return (2.0 / (self.numel() - unbiased)) * grad * (self - self.mean());
 }
 
-Tensor var_backward(Tensor grad, const Tensor & self, int64_t dim, bool unbiased, bool keepdim) {
+Tensor var_backward(Tensor grad, const Tensor & self, IntList dim, bool unbiased, bool keepdim) {
   if (self.dim() == 0) {
     return var_backward(grad, self, unbiased);
   }
   if (!keepdim && self.dim() > 1) {
-    grad = grad.unsqueeze(dim);
+    grad = unsqueeze_multiple(grad, dim, self.sizes().size());
   }
-  return (2.0 / (self.size(dim) - unbiased)) * grad * (self - self.mean(dim, true));
+  return (2.0 / (_safe_size(self.sizes(), dim) - unbiased)) * grad * (self - self.mean(dim, true));
 }
 
 Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntList sizes) {
index 9d14ef6..8b7531b 100644 (file)
@@ -899,7 +899,6 @@ class ShapePropagator {
             "aten::max_values(Tensor self, int dim, bool keepdim) -> Tensor",
             "aten::min_values(Tensor self, int dim, bool keepdim) -> Tensor",
             "aten::norm(Tensor self, Scalar p, int dim, bool keepdim) -> Tensor",
-            "aten::std(Tensor self, int dim, bool unbiased, bool keepdim) -> Tensor",
             "aten::var(Tensor self, int dim, bool unbiased, bool keepdim) -> Tensor",
             "aten::logsumexp(Tensor self, int dim, bool keepdim) -> Tensor",
             "aten::all(Tensor self, int dim, bool keepdim) -> Tensor",
@@ -956,6 +955,7 @@ class ShapePropagator {
     static const register_formula_for multidim_reduce_ops {
         {
             "aten::mean(Tensor self, int[] dim, bool keepdim) -> Tensor",
+            "aten::std(Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor",
         },
         [](Node * node) -> type_vec_t {
           if (auto dim = node->get<std::vector<int64_t>>(attr::dim)) {