Move isnan to C++ (#15722)
authorPeter Goldsborough <psag@fb.com>
Tue, 8 Jan 2019 18:26:32 +0000 (10:26 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 8 Jan 2019 18:42:33 +0000 (10:42 -0800)
Summary:
Wanted to use `Tensor.isnan` in C++, figured it'd be nice to have, so I made it into a tiny native function.

gchanan ezyang apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15722

Differential Revision: D13591315

Pulled By: goldsborough

fbshipit-source-id: a78bd22101fde87a0257f759b9bfcf3b4208f5fa

aten/src/ATen/native/TensorCompare.cpp
aten/src/ATen/native/native_functions.yaml
test/cpp/api/integration.cpp
torch/functional.py

index b9cc8a8..6d0f7b1 100644 (file)
@@ -57,6 +57,10 @@ Tensor isclose(const Tensor& self, const Tensor& other, double rtol, double atol
   return close;
 }
 
+Tensor isnan(const Tensor& self) {
+  return self != self;
+}
+
 bool is_nonzero(const Tensor& self) {
   auto n = self.numel();
   AT_ASSERT(n >= 0);
index 1a35f06..7c6cbe9 100644 (file)
 - func: isclose(Tensor self, Tensor other, double rtol=1e-5, double atol=1e-8, bool equal_nan=False) -> Tensor
   variants: function, method
 
+- func: isnan(Tensor self) -> Tensor
+  variants: function
+  device_guard: false
+
 - func: is_distributed(Tensor self) -> bool
   variants: function, method
   device_guard: false
index 63336c2..7d424a7 100644 (file)
@@ -135,6 +135,7 @@ bool test_mnist(
       auto data = batch.data.to(device), targets = batch.target.to(device);
       torch::Tensor prediction = forward_op(std::move(data));
       torch::Tensor loss = torch::nll_loss(prediction, std::move(targets));
+      AT_ASSERT(!torch::isnan(loss).any().item<int64_t>());
       optimizer.zero_grad();
       loss.backward();
       optimizer.step();
index f04297c..81873ee 100644 (file)
@@ -1,6 +1,7 @@
 import torch
 import torch.nn.functional as F
 from torch._six import inf
+from torch._C import _add_docstr
 from operator import mul
 from functools import reduce
 from itertools import product
@@ -373,23 +374,20 @@ def stft(input, n_fft, hop_length=None, win_length=None, window=None,
     return torch._C._VariableFunctions.stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
 
 
-def isnan(tensor):
-    r"""Returns a new tensor with boolean elements representing if each element is `NaN` or not.
+isnan = _add_docstr(torch.isnan, r"""
+Returns a new tensor with boolean elements representing if each element is `NaN` or not.
 
-    Arguments:
-        tensor (Tensor): A tensor to check
+Arguments:
+    tensor (Tensor): A tensor to check
 
-    Returns:
-        Tensor: A ``torch.ByteTensor`` containing a 1 at each location of `NaN` elements.
+Returns:
+    Tensor: A ``torch.ByteTensor`` containing a 1 at each location of `NaN` elements.
 
-    Example::
+Example::
 
-        >>> torch.isnan(torch.tensor([1, float('nan'), 2]))
-        tensor([ 0,  1,  0], dtype=torch.uint8)
-    """
-    if not isinstance(tensor, torch.Tensor):
-        raise ValueError("The argument is not a tensor", str(tensor))
-    return tensor != tensor
+    >>> torch.isnan(torch.tensor([1, float('nan'), 2]))
+    tensor([ 0,  1,  0], dtype=torch.uint8)
+""")
 
 
 def unique(input, sorted=True, return_inverse=False, dim=None):