Initial implementation of nanmean (#62671)
author:    Heitor Schueroff <heitorschueroff@fb.com>
           Mon, 13 Sep 2021 12:50:27 +0000 (05:50 -0700)
committer: Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
           Mon, 13 Sep 2021 12:53:58 +0000 (05:53 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62671

A first, admittedly crude implementation of `torch.nanmean`. The current reduction kernels do not lend themselves to implementing nan* variants. Rather than write a new kernel for each nan* operator, I will work on new reduction kernels that support a `nan_policy` flag and then port `nanmean` to use them.
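
For reference, a minimal Python sketch of the composite (the helper name `nanmean_composite` is hypothetical; the real implementation is the C++ in `aten/src/ATen/native/ReduceOps.cpp` below): since `nansum` treats `NaN` as zero, dividing by the per-slice count of non-NaN elements gives the mean over the non-NaN elements.

```python
import torch

# Hypothetical helper mirroring the composite in ReduceOps.cpp below.
def nanmean_composite(t, dim, keepdim=False):
    # Divisor: count of non-NaN elements along the reduced dimensions.
    factor = t.isnan().logical_not().sum(dim, keepdim)
    # nansum treats NaN as zero, so this is the mean over non-NaN elements.
    return t.nansum(dim, keepdim).div(factor)

x = torch.tensor([[float('nan'), 1.0, 2.0], [1.0, 2.0, 3.0]])
print(nanmean_composite(x, dim=0))  # tensor([1.0000, 1.5000, 2.5000])
```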

**TODO**

- [x] Fix autograd issue
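
(A guess at what the fix was, based on the diff below: the function variant computes the divisor on `self.detach()`, so autograd treats the non-NaN count as a constant.)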

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D30515181

Pulled By: heitorschueroff

fbshipit-source-id: 303004ebd7ac9cf963dc4f8e2553eaded5f013f0

aten/src/ATen/core/aten_interned_strings.h
aten/src/ATen/native/ReduceOps.cpp
aten/src/ATen/native/native_functions.yaml
docs/source/tensors.rst
docs/source/torch.rst
torch/_tensor_docs.py
torch/_torch_docs.py
torch/overrides.py
torch/testing/_internal/common_methods_invocations.py

diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index fc8de31..d766c69 100644
@@ -470,6 +470,7 @@ _(aten, max_unpool3d_backward) \
 _(aten, max_unpool3d_forward) \
 _(aten, max_values) \
 _(aten, mean) \
+_(aten, nanmean) \
 _(aten, median) \
 _(aten, nanmedian) \
 _(aten, meshgrid) \
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 620908b..4bef219 100644
@@ -1169,6 +1169,36 @@ Tensor& mean_out(const Tensor& self, DimnameList dim,
   return at::mean_out(result, self, dimnames_to_positions(self, dim), keepdim, opt_dtype);
 }
 
+// TODO(@heitorschueroff) implement custom kernels for nanmean
+Tensor& nanmean_out(
+    const Tensor& self,
+    IntArrayRef dim,
+    bool keepdim,
+    c10::optional<ScalarType> opt_dtype,
+    Tensor& result) {
+  TORCH_CHECK(
+      self.is_floating_point(),
+      "nanmean(): expected input to have floating point dtype but got ",
+      self.scalar_type());
+  const auto factor = at::native::isnan(self).logical_not_().sum(dim, keepdim);
+  at::native::nansum_out(self, dim, keepdim, opt_dtype, result).div_(factor);
+  return result;
+}
+
+Tensor nanmean(
+    const Tensor& self,
+    IntArrayRef dim,
+    bool keepdim,
+    optional<ScalarType> opt_dtype) {
+  TORCH_CHECK(
+      self.is_floating_point(),
+      "nanmean(): expected input to have floating point dtype but got ",
+      self.scalar_type());
+  const auto factor =
+      at::native::isnan(self.detach()).logical_not_().sum(dim, keepdim);
+  return at::nansum(self, dim, keepdim, opt_dtype).div_(factor);
+}
+
 static Tensor squeeze_multiple(const Tensor& self, IntArrayRef dims) {
   int ndims = self.sizes().size();
   auto dims_to_squeeze = at::dim_list_to_bitset(dims, ndims);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 7401900..1ef7e8d 100644
 - func: mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck   # TensorIterator
 
+- func: nanmean(Tensor self, int[1] dim=[], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  device_check: NoCheck   # Composite
+  variants: function, method
+
+- func: nanmean.out(Tensor self, int[1] dim=[], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # Composite
+
 - func: median(Tensor self) -> Tensor
   variants: function, method
   dispatch:
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index 47edcb1..3ec5f74 100644
@@ -491,6 +491,7 @@ Tensor class reference
     Tensor.max
     Tensor.maximum
     Tensor.mean
+    Tensor.nanmean
     Tensor.median
     Tensor.nanmedian
     Tensor.min
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 5aa5dbc..084dad3 100644
@@ -397,6 +397,7 @@ Reduction Ops
     dist
     logsumexp
     mean
+    nanmean
     median
     nanmedian
     mode
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index eb8d4ec..cec7865 100644
@@ -2421,13 +2421,18 @@ argmax(dim=None, keepdim=False) -> LongTensor
 See :func:`torch.argmax`
 """)
 
-add_docstr_all('mean',
-               r"""
-mean(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor)
+add_docstr_all('mean', r"""
+mean(dim=None, keepdim=False, *, dtype=None) -> Tensor
 
 See :func:`torch.mean`
 """)
 
+add_docstr_all('nanmean', r"""
+nanmean(dim=None, keepdim=False, *, dtype=None) -> Tensor
+
+See :func:`torch.nanmean`
+""")
+
 add_docstr_all('median',
                r"""
 median(dim=None, keepdim=False) -> (Tensor, LongTensor)
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index d0728dd..f52ffd5 100644
@@ -5872,15 +5872,17 @@ Example::
     tensor([ 0,  2,  0,  1])
 """.format(**single_dim_common))
 
-add_docstr(torch.mean,
-           r"""
-mean(input) -> Tensor
+add_docstr(torch.mean, r"""
+mean(input, *, dtype=None) -> Tensor
 
 Returns the mean value of all elements in the :attr:`input` tensor.
 
 Args:
     {input}
 
+Keyword args:
+    {dtype}
+
 Example::
 
     >>> a = torch.randn(1, 3)
@@ -5889,7 +5891,7 @@ Example::
     >>> torch.mean(a)
     tensor(0.3367)
 
-.. function:: mean(input, dim, keepdim=False, *, out=None) -> Tensor
+.. function:: mean(input, dim, keepdim=False, *, dtype=None, out=None) -> Tensor
    :noindex:
 
 Returns the mean value of each row of the :attr:`input` tensor in the given
@@ -5904,8 +5906,13 @@ Args:
     {keepdim}
 
 Keyword args:
+    {dtype}
     {out}
 
+.. seealso::
+
+    :func:`torch.nanmean` computes the mean value of `non-NaN` elements.
+
 Example::
 
     >>> a = torch.randn(4, 4)
@@ -5923,6 +5930,48 @@ Example::
             [ 0.1807]])
 """.format(**multi_dim_common))
 
+add_docstr(torch.nanmean, r"""
+nanmean(input, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor
+
+Computes the mean of all `non-NaN` elements along the specified dimensions.
+
+This function is identical to :func:`torch.mean` when there are no `NaN` values
+in the :attr:`input` tensor. In the presence of `NaN`, :func:`torch.mean` will
+propagate the `NaN` to the output whereas :func:`torch.nanmean` will ignore the
+`NaN` values (`torch.nanmean(a)` is equivalent to `torch.mean(a[~a.isnan()])`).
+
+{keepdim_details}
+
+Args:
+    {input}
+    {dim} If `None`, reduces all dimensions. Default is `None`.
+    {keepdim}
+
+Keyword args:
+    {dtype}
+    {out}
+
+.. seealso::
+
+    :func:`torch.mean` computes the mean value, propagating `NaN`.
+
+Example::
+
+    >>> x = torch.tensor([[torch.nan, 1, 2], [1, 2, 3]])
+    >>> x.mean()
+    tensor(nan)
+    >>> x.nanmean()
+    tensor(1.8000)
+    >>> x.mean(dim=0)
+    tensor([   nan, 1.5000, 2.5000])
+    >>> x.nanmean(dim=0)
+    tensor([1.0000, 1.5000, 2.5000])
+
+    # If all elements in the reduced dimensions are NaN then the result is NaN
+    >>> torch.tensor([torch.nan]).nanmean()
+    tensor(nan)
+""".format(**multi_dim_common))
+
 add_docstr(torch.median,
            r"""
 median(input) -> Tensor
diff --git a/torch/overrides.py b/torch/overrides.py
index 6c545a0..f280746 100644
@@ -617,6 +617,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
                                         return_indices=False, ceil_mode=False: -1),
         torch.mean: lambda input, dim=None: -1,
+        torch.nanmean: lambda input, dim=None, keepdim=False, dtype=None, out=None: -1,
         torch.median: lambda input, dim=None: -1,
         torch.nanmedian: lambda input, dim=None: -1,
         torch.meshgrid: lambda *tensors, **kwargs: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 84ff419..da7b3a0 100644
@@ -2746,6 +2746,28 @@ def sample_inputs_max_min_reduction_no_dim(op_info, device, dtype, requires_grad
                                           requires_grad=requires_grad),))
     return inputs
 
+def _generate_nan_reduction_inputs(device, dtype, requires_grad):
+    yield from _generate_reduction_inputs(device, dtype, requires_grad)
+    yield torch.tensor([2, torch.nan, -1], device=device, dtype=dtype, requires_grad=requires_grad)
+    yield torch.tensor([[torch.nan, 2], [0, 1]], device=device, dtype=dtype, requires_grad=requires_grad)
+
+def sample_inputs_nan_reduction(supports_multiple_dims):
+    # Generates sample inputs for reduction ops that contain the input tensor
+    # and dim and keepdim kwargs. If a reduction op needs to test additional
+    # args/kwargs then create a separate sample_inputs function
+    def fn(op_info, device, dtype, requires_grad):
+        inputs = []
+
+        for t in _generate_nan_reduction_inputs(device, dtype, requires_grad):
+            # Add case without dim and keepdim kwargs
+            inputs.append(SampleInput(t))
+            for kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims):
+                inputs.append(SampleInput(t, kwargs=kwargs))
+
+        return inputs
+
+    return fn
+
 def sample_inputs_reduction_quantile(op_info, device, dtype, requires_grad):
     test_quantiles = (0.5, make_tensor((2,), device, dtype, low=0, high=1))
     test_interpolations = ['linear', 'midpoint']
@@ -9490,6 +9512,32 @@ op_db: List[OpInfo] = [
         ),
     ),
     ReductionOpInfo(
+        'nanmean',
+        nan_policy='omit',
+        assert_autodiffed=True,
+        promotes_int_to_float=True,
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        sample_inputs_func=sample_inputs_nan_reduction(supports_multiple_dims=True),
+        ref=reference_reduction_numpy(np.nanmean),
+        skips=(
+            # RuntimeError: deepEquals(input.iValue, deepCopiedInput)INTERNAL ASSERT FAILED at
+            # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":142, please report a bug to PyTorch.
+            SkipInfo('TestJit', 'test_variant_consistency_jit'),
+            # FIXME: nanmean reduces all dimensions when dim=[]
+            SkipInfo('TestReductions', 'test_dim_empty'),
+            SkipInfo('TestReductions', 'test_dim_empty_keepdim'),
+            # FIXME: improve precision
+            SkipInfo('TestReductions', 'test_noncontiguous_all',
+                     dtypes=[torch.float16]),
+            SkipInfo('TestReductions', 'test_ref_small_input',
+                     dtypes=[torch.float16]),
+            SkipInfo('TestReductions', 'test_ref_duplicate_values',
+                     device_type='cuda', dtypes=[torch.float16]),
+            SkipInfo('TestReductions', 'test_ref_extremal_values',
+                     device_type='cuda', dtypes=[torch.complex64]),
+        ),
+    ),
+    ReductionOpInfo(
         'prod',
         identity=1,
         nan_policy='propagate',