From: Tomasz DoĊ‚bniak Date: Fri, 16 Oct 2020 09:30:20 +0000 (+0200) Subject: Avoid unnecessary Reshape in ONNX Softmax impl (#2686) X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5eee1ea925a0beb9e7835e45fb677c598b4a6b67;p=platform%2Fupstream%2Fdldt.git Avoid unnecessary Reshape in ONNX Softmax impl (#2686) --- diff --git a/ngraph/frontend/onnx_import/src/op/softmax.cpp b/ngraph/frontend/onnx_import/src/op/softmax.cpp index b0c4fe4..87c7e51 100644 --- a/ngraph/frontend/onnx_import/src/op/softmax.cpp +++ b/ngraph/frontend/onnx_import/src/op/softmax.cpp @@ -33,16 +33,11 @@ namespace ngraph const auto coerced_data = ngraph::builder::opset1::flatten(data, axis); const auto axis_1 = default_opset::Constant::create(element::i64, Shape{1}, {1}); - const auto max = std::make_shared(coerced_data, axis_1); - - // equivalent to numpy's max.reshape((-1,1)) - const auto reshape_pattern = - default_opset::Constant::create(element::i64, Shape{2}, {0, 1}); - const auto reshaped_max = - std::make_shared(max, reshape_pattern, true); + const auto max = + std::make_shared(coerced_data, axis_1, true); const auto data_minus_max = - std::make_shared(coerced_data, reshaped_max); + std::make_shared(coerced_data, max); const auto result = std::make_shared(data_minus_max, 1); if (data.get_partial_shape().is_static())