From 5eee1ea925a0beb9e7835e45fb677c598b4a6b67 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Tomasz=20Do=C5=82bniak?= Date: Fri, 16 Oct 2020 11:30:20 +0200 Subject: [PATCH] Avoid unnecessary Reshape in ONNX Softmax impl (#2686) --- ngraph/frontend/onnx_import/src/op/softmax.cpp | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/ngraph/frontend/onnx_import/src/op/softmax.cpp b/ngraph/frontend/onnx_import/src/op/softmax.cpp index b0c4fe4..87c7e51 100644 --- a/ngraph/frontend/onnx_import/src/op/softmax.cpp +++ b/ngraph/frontend/onnx_import/src/op/softmax.cpp @@ -33,16 +33,11 @@ namespace ngraph const auto coerced_data = ngraph::builder::opset1::flatten(data, axis); const auto axis_1 = default_opset::Constant::create(element::i64, Shape{1}, {1}); - const auto max = std::make_shared(coerced_data, axis_1); - - // equivalent to numpy's max.reshape((-1,1)) - const auto reshape_pattern = - default_opset::Constant::create(element::i64, Shape{2}, {0, 1}); - const auto reshaped_max = - std::make_shared(max, reshape_pattern, true); + const auto max = + std::make_shared(coerced_data, axis_1, true); const auto data_minus_max = - std::make_shared(coerced_data, reshaped_max); + std::make_shared(coerced_data, max); const auto result = std::make_shared(data_minus_max, 1); if (data.get_partial_shape().is_static()) -- 2.7.4