From 8232bd526fbf6b3720378e8d5a4ed93f284a4812 Mon Sep 17 00:00:00 2001 From: Peter Goldsborough Date: Tue, 8 Jan 2019 10:26:32 -0800 Subject: [PATCH] Move isnan to C++ (#15722) Summary: Wanted to use `Tensor.isnan` in C++, figured it'd be nice to have, so I made it into a tiny native function. gchanan ezyang apaszke Pull Request resolved: https://github.com/pytorch/pytorch/pull/15722 Differential Revision: D13591315 Pulled By: goldsborough fbshipit-source-id: a78bd22101fde87a0257f759b9bfcf3b4208f5fa --- aten/src/ATen/native/TensorCompare.cpp | 4 ++++ aten/src/ATen/native/native_functions.yaml | 4 ++++ test/cpp/api/integration.cpp | 1 + torch/functional.py | 24 +++++++++++------------- 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index b9cc8a8..6d0f7b1 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -57,6 +57,10 @@ Tensor isclose(const Tensor& self, const Tensor& other, double rtol, double atol return close; } +Tensor isnan(const Tensor& self) { + return self != self; +} + bool is_nonzero(const Tensor& self) { auto n = self.numel(); AT_ASSERT(n >= 0); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 1a35f06..7c6cbe9 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -917,6 +917,10 @@ - func: isclose(Tensor self, Tensor other, double rtol=1e-5, double atol=1e-8, bool equal_nan=False) -> Tensor variants: function, method +- func: isnan(Tensor self) -> Tensor + variants: function + device_guard: false + - func: is_distributed(Tensor self) -> bool variants: function, method device_guard: false diff --git a/test/cpp/api/integration.cpp b/test/cpp/api/integration.cpp index 63336c2..7d424a7 100644 --- a/test/cpp/api/integration.cpp +++ b/test/cpp/api/integration.cpp @@ -135,6 +135,7 @@ bool test_mnist( auto data = batch.data.to(device), targets = batch.target.to(device); torch::Tensor prediction = forward_op(std::move(data)); torch::Tensor loss = torch::nll_loss(prediction, std::move(targets)); + AT_ASSERT(!torch::isnan(loss).any().item()); optimizer.zero_grad(); loss.backward(); optimizer.step(); diff --git a/torch/functional.py b/torch/functional.py index f04297c..81873ee 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -1,6 +1,7 @@ import torch import torch.nn.functional as F from torch._six import inf +from torch._C import _add_docstr from operator import mul from functools import reduce from itertools import product @@ -373,23 +374,20 @@ def stft(input, n_fft, hop_length=None, win_length=None, window=None, return torch._C._VariableFunctions.stft(input, n_fft, hop_length, win_length, window, normalized, onesided) -def isnan(tensor): - r"""Returns a new tensor with boolean elements representing if each element is `NaN` or not. +isnan = _add_docstr(torch.isnan, r""" +Returns a new tensor with boolean elements representing if each element is `NaN` or not. - Arguments: - tensor (Tensor): A tensor to check +Arguments: + tensor (Tensor): A tensor to check - Returns: - Tensor: A ``torch.ByteTensor`` containing a 1 at each location of `NaN` elements. +Returns: + Tensor: A ``torch.ByteTensor`` containing a 1 at each location of `NaN` elements. - Example:: +Example:: - >>> torch.isnan(torch.tensor([1, float('nan'), 2])) - tensor([ 0, 1, 0], dtype=torch.uint8) - """ - if not isinstance(tensor, torch.Tensor): - raise ValueError("The argument is not a tensor", str(tensor)) - return tensor != tensor + >>> torch.isnan(torch.tensor([1, float('nan'), 2])) + tensor([ 0, 1, 0], dtype=torch.uint8) +""") def unique(input, sorted=True, return_inverse=False, dim=None): -- 2.7.4