[ONNX] Update ONNX importer to use SotfPlus-4 (#1959)
authorKatarzyna Mitrus <katarzyna.mitrus@intel.com>
Thu, 27 Aug 2020 12:55:04 +0000 (14:55 +0200)
committerGitHub <noreply@github.com>
Thu, 27 Aug 2020 12:55:04 +0000 (15:55 +0300)
* Use SoftPlus-4 in ONNX importer

* Tests update

ngraph/frontend/onnx_import/src/op/softplus.cpp
ngraph/test/onnx/onnx_import.in.cpp

index 4d21fa9..3d63925 100644 (file)
@@ -18,7 +18,7 @@
 
 #include "ngraph/node.hpp"
 #include "onnx_import/default_opset.hpp"
-#include "softplus.hpp"
+#include "onnx_import/op/softplus.hpp"
 
 namespace ngraph
 {
@@ -31,36 +31,7 @@ namespace ngraph
                 OutputVector softplus(const Node& node)
                 {
                     const auto data = node.get_ng_inputs().at(0);
-
-                    const std::shared_ptr<ngraph::Node> zero_node =
-                        default_opset::Constant::create(data.get_element_type(), Shape{}, {0.f});
-                    const std::shared_ptr<ngraph::Node> one_node =
-                        default_opset::Constant::create(data.get_element_type(), Shape{}, {1.f});
-
-                    // data + log(exp(-data) + 1)
-                    const std::shared_ptr<ngraph::Node> positive_val_node =
-                        std::make_shared<default_opset::Add>(
-                            data,
-                            std::make_shared<default_opset::Log>(
-                                std::make_shared<default_opset::Add>(
-                                    std::make_shared<default_opset::Exp>(
-                                        std::make_shared<default_opset::Negative>(data)),
-                                    one_node)));
-
-                    // log(exp(data) + 1)
-                    const std::shared_ptr<ngraph::Node> negative_val_node =
-                        std::make_shared<default_opset::Log>(std::make_shared<default_opset::Add>(
-                            std::make_shared<default_opset::Exp>(data), one_node));
-
-                    const std::shared_ptr<ngraph::Node> condition_node =
-                        std::make_shared<default_opset::Greater>(data, zero_node);
-
-                    // This equation represents:
-                    //     x + log(exp(-x) + 1) - for x > 0; to manage exponent overflow,
-                    //     log(exp(x) + 1)      - elsewhere.
-                    //
-                    return {std::make_shared<default_opset::Select>(
-                        condition_node, positive_val_node, negative_val_node)};
+                    return {std::make_shared<default_opset::SoftPlus>(data)};
                 }
 
             } // namespace set_1
index cc0c484..5cc7c62 100644 (file)
@@ -1653,20 +1653,20 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softplus)
                                      FLT_MAX,
                                      -FLT_MAX}};
 
-    std::vector<float>& input = inputs.back();
-    std::vector<float> output;
-    auto softplus_impl = [](float x) -> float {
-        if (x > 0)
-        {
-            return x + std::log(std::exp(-x) + 1);
-        }
-        else
-        {
-            return std::log(std::exp(x) + 1);
-        }
-    };
-
-    std::transform(std::begin(input), std::end(input), std::back_inserter(output), softplus_impl);
+    const auto inf = std::numeric_limits<float>::infinity();
+    std::vector<float> output{0.3132616579532623291,
+                              0.6931471824645996094,
+                              1.313261628150939941,
+                              10.0000457763671875,
+                              inf,
+                              0.0,
+                              inf,
+                              0.0,
+                              0.6931471824645996094,
+                              0.6931471824645996094,
+                              0.6931471824645996094,
+                              inf,
+                              0.0};
 
     auto test_case = test::TestCase<TestEngine>(function);
     test_case.add_multiple_inputs(inputs);