Port `mean` kernel to structured kernels. (#61643)
authorYukio Siraichi <yukio.siraichi@gmail.com>
Fri, 13 Aug 2021 15:20:19 +0000 (08:20 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 13 Aug 2021 15:26:01 +0000 (08:26 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61643

Tracking issue: #55070

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D29783866

Pulled By: ezyang

fbshipit-source-id: dc95baf593096c03fb5f292ee6c36de3cc7f2b35

aten/src/ATen/native/ReduceOps.cpp
aten/src/ATen/native/native_functions.yaml
aten/src/ATen/native/quantized/cpu/qreduction.cpp

index 6c2db2a..c527e22 100644 (file)
@@ -192,6 +192,31 @@ TORCH_META_FUNC2(sum, dim_IntList)
   namedinference::propagate_names_for_reduction(result, self, dims, keepdim);
 }
 
+TORCH_META_FUNC2(mean, dim)
+(const Tensor& self, IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
+  auto self_dtype = self.scalar_type();
+  TORCH_CHECK(
+      at::isFloatingType(self_dtype) || at::isComplexType(self_dtype),
+      "Can only calculate the mean of floating types. Got ",
+      toString(self_dtype), " instead.");
+
+  ScalarType dtype;
+  const auto& result = maybe_get_output();
+
+  if (result.defined()) {
+    dtype = opt_dtype.value_or(result.scalar_type());
+  } else {
+    dtype = at::native::get_dtype_from_self(self, opt_dtype, true);
+  }
+
+  DimVector dims(dim);
+  maybe_wrap_dims(dims, self.dim());
+
+  DimVector shape = get_reduction_shape(self, dims, keepdim);
+  set_output(shape, self.options().dtype(dtype));
+  namedinference::propagate_names_for_reduction(result, self, dim, keepdim);
+}
+
 } // namespace meta
 
 namespace meta {
@@ -1056,15 +1081,13 @@ Tensor& prod_out(const Tensor& self, Dimname dim,
   return at::prod_out(result, self, dimname_to_position(self, dim), keepdim, opt_dtype);
 }
 
-Tensor &mean_out_cpu_gpu(const Tensor &self, IntArrayRef dim,
-                 bool keepdim, c10::optional<ScalarType> opt_dtype, Tensor &result) {
-  ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type();
-  TORCH_CHECK(
-      at::isFloatingType(scalarType) || at::isComplexType(scalarType),
-      "Can only calculate the mean of floating types. Got ",
-      toString(scalarType),
-      " instead.");
-  ScalarType dtype = get_dtype_from_result(result, opt_dtype);
+TORCH_IMPL_FUNC(mean_out)
+(const Tensor& self,
+ IntArrayRef dim,
+ bool keepdim,
+ c10::optional<ScalarType> opt_dtype,
+ const Tensor& result) {
+  ScalarType dtype = result.scalar_type();
   // TODO: the TensorIterator reduction implementation of mean
   // (mean_kernel_impl()) is unvectorized and leads to very poor performance
   // for production workloads. Once that's fixed, the following code can be used
@@ -1078,27 +1101,22 @@ Tensor &mean_out_cpu_gpu(const Tensor &self, IntArrayRef dim,
         dim_prod *= self.size(d);
       }
     }
-    at::sum_out(result, self, dim, keepdim, dtype).div_(dim_prod);
-    return result;
-  }
-
-  auto iter = make_reduction("mean", result, self, dim, keepdim, dtype);
-  if (iter.numel() == 0) {
-    result.fill_(std::numeric_limits<double>::quiet_NaN());
+    auto& result_mut = const_cast<Tensor&>(result);
+    at::sum_out(result_mut, self, dim, keepdim, dtype).div_(dim_prod);
   } else {
-    mean_stub(iter.device_type(), iter);
+    DimVector dims(dim);
+    auto iter = at::meta::make_reduction_from_out_ty(
+        self, result, dims, keepdim, dtype);
+    if (iter.numel() == 0) {
+      result.fill_(std::numeric_limits<double>::quiet_NaN());
+    } else {
+      mean_stub(iter.device_type(), iter);
+    }
   }
-  return result;
 }
 
 Tensor mean_cpu_gpu(const Tensor &self, optional<ScalarType> dtype) {
-  return at::native::mean_cpu_gpu(self, IntArrayRef{}, false, dtype);
-}
-
-Tensor mean_cpu_gpu(const Tensor& self, IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
-  ScalarType dtype = get_dtype_from_self(self, opt_dtype, true);
-  Tensor result = create_reduction_result(self, dim, keepdim, dtype);
-  return at::native::mean_out_cpu_gpu(self, dim, keepdim, dtype, result);
+  return at::mean(self, IntArrayRef{}, false, dtype);
 }
 
 Tensor mean(const Tensor& self, DimnameList dim, bool keepdim, optional<ScalarType> dtype) {
index b0d7131..92aac19 100644 (file)
     QuantizedCPU: mean_quantized_cpu
 
 - func: mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  structured_delegate: mean.out
   device_check: NoCheck   # TensorIterator
   variants: function, method
   dispatch:
-    CPU, CUDA: mean_cpu_gpu
     QuantizedCPU: mean_quantized_cpu
 
 - func: mean.out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  structured: True
   device_check: NoCheck   # TensorIterator
   dispatch:
-    CPU, CUDA: mean_out_cpu_gpu
+    CPU, CUDA: mean_out
     QuantizedCPU: mean_out_quantized_cpu
 
 - func: mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
index d8090d0..55b7c68 100644 (file)
@@ -97,8 +97,7 @@ Tensor& mean_out_quantized_cpu(
   }
 #endif
   auto self_dequantized = self.dequantize();
-  auto result_dequantized =
-      at::native::mean_cpu_gpu(self_dequantized, dim, keepdim, opt_dtype);
+  auto result_dequantized = at::mean(self_dequantized, dim, keepdim, opt_dtype);
   result = at::quantize_per_tensor(
       result_dequantized,
       self.q_scale(),