From 001cffed9d6c4a99e1c52270cc0def78d2292eec Mon Sep 17 00:00:00 2001
From: Lara Haidar-Ahmad
Date: Fri, 15 Mar 2019 12:10:32 -0700
Subject: [PATCH] ONNX Export IsNan op

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17698

Reviewed By: zrphercule

Differential Revision: D14470646

Pulled By: houseroad

fbshipit-source-id: d3e6adc83c4f9fa288c5fe0ae4c6af71fdd47905
---
 caffe2/operators/utility_ops.cc                  |  9 +++++
 caffe2/operators/utility_ops.h                   | 24 ++++++++++++
 test/onnx/expect/TestOperators.test_isnan.expect | 50 ++++++++++++++++++++++++
 test/onnx/test_operators.py                      |  4 ++
 test/onnx/test_pytorch_onnx_caffe2.py            |  8 ++++
 torch/onnx/symbolic.py                           |  7 ++++
 6 files changed, 102 insertions(+)
 create mode 100644 test/onnx/expect/TestOperators.test_isnan.expect

diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc
index d9d6646..0edc47e 100644
--- a/caffe2/operators/utility_ops.cc
+++ b/caffe2/operators/utility_ops.cc
@@ -833,6 +833,15 @@ OPERATOR_SCHEMA(NanCheck)
         "Tensor to copy input into if no NaNs or inf."
         " Can be in-place");
 
+REGISTER_CPU_OPERATOR(IsNaN, IsNanOp<CPUContext>);
+
+OPERATOR_SCHEMA(IsNaN)
+    .NumInputs(1)
+    .NumOutputs(1)
+    .SetDoc("Returns a new tensor with boolean elements representing if each element is NaN or not.")
+    .Input(0, "tensor", "Tensor to check for NaN")
+    .Output(0, "output", "Tensor containing a 1 at each location of NaN elements.");
+
 OPERATOR_SCHEMA(Size)
     .NumInputs(1)
     .NumOutputs(1)
diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h
index 7b92b2c..1287c82 100644
--- a/caffe2/operators/utility_ops.h
+++ b/caffe2/operators/utility_ops.h
@@ -43,6 +43,30 @@ struct GetNanCheckGradient : public GradientMakerBase {
 };
 
 template <class Context>
+class IsNanOp final : public Operator<Context> {
+ public:
+  USE_OPERATOR_CONTEXT_FUNCTIONS;
+  IsNanOp(const OperatorDef& operator_def, Workspace* ws)
+      : Operator<Context>(operator_def, ws) {}
+
+  bool RunOnDevice() override {
+    return DispatchHelper<TensorTypes<float, double>>::call(this, Input(0));
+  }
+
+  template <typename T>
+  bool DoRunWithType() {
+    auto& X = Input(0);
+    auto* Y = Output(0, X.sizes(), at::dtype<uint8_t>());
+    const auto* X_data = X.template data<T>();
+    uint8_t* Y_data = Y->template mutable_data<uint8_t>();
+    for (size_t i = 0; i < X.numel(); i++) {
+      Y_data[i] = (uint8_t)(std::isnan(X_data[i]));
+    }
+    return true;
+  }
+};
+
+template <class Context>
 class WallClockTimeOp final : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
diff --git a/test/onnx/expect/TestOperators.test_isnan.expect b/test/onnx/expect/TestOperators.test_isnan.expect
new file mode 100644
index 0000000..17e62c0
--- /dev/null
+++ b/test/onnx/expect/TestOperators.test_isnan.expect
@@ -0,0 +1,50 @@
+ir_version: 4
+producer_name: "pytorch"
+producer_version: "1.1"
+graph {
+  node {
+    input: "0"
+    output: "1"
+    op_type: "IsNaN"
+  }
+  node {
+    input: "1"
+    output: "2"
+    op_type: "Cast"
+    attribute {
+      name: "to"
+      i: 2
+      type: INT
+    }
+  }
+  name: "torch-jit-export"
+  input {
+    name: "0"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "2"
+    type {
+      tensor_type {
+        elem_type: 2
+        shape {
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}
diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py
index 981f662..1770746 100644
--- a/test/onnx/test_operators.py
+++ b/test/onnx/test_operators.py
@@ -434,6 +434,10 @@ class TestOperators(TestCase):
         x = torch.randn(1, 2, 3, 4, requires_grad=True)
         self.assertONNX(lambda x: torch.flatten(x, 1), x)
 
+    def test_isnan(self):
+        x = torch.tensor([1, float('nan'), 2])
+        self.assertONNX(lambda x: torch.isnan(x), x)
+
     def test_argmax(self):
         x = torch.randn(4, 4, requires_grad=True)
         self.assertONNX(lambda x: torch.argmax(x, dim=1), x)
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index 9791949..9c26392 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -1089,6 +1089,14 @@ class TestCaffe2Backend(unittest.TestCase):
         self.run_model_test(RsubModel(), train=False, input=(x,),
                             batch_size=BATCH_SIZE, use_gpu=False)
 
+    def test_isnan(self):
+        class IsNaNModel(torch.nn.Module):
+            def forward(self, input):
+                return torch.isnan(input)
+
+        x = torch.tensor([1.0, float('nan'), 2.0])
+        self.run_model_test(IsNaNModel(), train=False, input=x, batch_size=BATCH_SIZE, use_gpu=False)
+
     def test_flatten(self):
         class FlattenModel(torch.nn.Module):
             def forward(self, input):
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index 18be2e0..b4ed832 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -1677,6 +1677,13 @@ def nonzero(g, input):
     return t(g, g.op('NonZero', input))
 
 
+@parse_args('v')
+def isnan(g, input):
+    output = g.op('IsNaN', input)
+    output = _cast_func_template(cast_pytorch_to_onnx['Byte'], g, output, None)
+    return output
+
+
 @parse_args('v', 'i', 'i', 'i')
 def narrow(g, input, dim, start, length):
     return g.op("Slice", input, axes_i=[dim], starts_i=[start], ends_i=[start + length])
-- 
2.7.4
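
A minimal usage sketch, not part of the patch itself: with a PyTorch build
that includes this change, exporting torch.isnan should produce the ONNX
IsNaN node followed by a Cast to uint8, as captured in the expect file
above. The output file name "isnan.onnx" is an arbitrary choice for
illustration.

    import torch

    # Wrap torch.isnan in a module, mirroring the Caffe2 backend test above.
    class IsNaNModel(torch.nn.Module):
        def forward(self, input):
            return torch.isnan(input)

    x = torch.tensor([1.0, float('nan'), 2.0])
    # Export to ONNX; at the time of this patch the exporter defaults to
    # opset 9, the opset in which ONNX introduced IsNaN.
    # "isnan.onnx" is an arbitrary output path.
    torch.onnx.export(IsNaNModel(), x, "isnan.onnx", verbose=True)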