#include "ngraph/node.hpp"
#include "onnx_import/default_opset.hpp"
-#include "softplus.hpp"
+#include "onnx_import/op/softplus.hpp"
namespace ngraph
{
OutputVector softplus(const Node& node)
{
const auto data = node.get_ng_inputs().at(0);
-
- const std::shared_ptr<ngraph::Node> zero_node =
- default_opset::Constant::create(data.get_element_type(), Shape{}, {0.f});
- const std::shared_ptr<ngraph::Node> one_node =
- default_opset::Constant::create(data.get_element_type(), Shape{}, {1.f});
-
- // data + log(exp(-data) + 1)
- const std::shared_ptr<ngraph::Node> positive_val_node =
- std::make_shared<default_opset::Add>(
- data,
- std::make_shared<default_opset::Log>(
- std::make_shared<default_opset::Add>(
- std::make_shared<default_opset::Exp>(
- std::make_shared<default_opset::Negative>(data)),
- one_node)));
-
- // log(exp(data) + 1)
- const std::shared_ptr<ngraph::Node> negative_val_node =
- std::make_shared<default_opset::Log>(std::make_shared<default_opset::Add>(
- std::make_shared<default_opset::Exp>(data), one_node));
-
- const std::shared_ptr<ngraph::Node> condition_node =
- std::make_shared<default_opset::Greater>(data, zero_node);
-
- // This equation represents:
- // x + log(exp(-x) + 1) - for x > 0; to manage exponent overflow,
- // log(exp(x) + 1) - elsewhere.
- //
- return {std::make_shared<default_opset::Select>(
- condition_node, positive_val_node, negative_val_node)};
+ return {std::make_shared<default_opset::SoftPlus>(data)};
}
} // namespace set_1
FLT_MAX,
-FLT_MAX}};
- std::vector<float>& input = inputs.back();
- std::vector<float> output;
- auto softplus_impl = [](float x) -> float {
- if (x > 0)
- {
- return x + std::log(std::exp(-x) + 1);
- }
- else
- {
- return std::log(std::exp(x) + 1);
- }
- };
-
- std::transform(std::begin(input), std::end(input), std::back_inserter(output), softplus_impl);
+ const auto inf = std::numeric_limits<float>::infinity();
+ std::vector<float> output{0.3132616579532623291,
+ 0.6931471824645996094,
+ 1.313261628150939941,
+ 10.0000457763671875,
+ inf,
+ 0.0,
+ inf,
+ 0.0,
+ 0.6931471824645996094,
+ 0.6931471824645996094,
+ 0.6931471824645996094,
+ inf,
+ 0.0};
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_multiple_inputs(inputs);