Avoid unnecessary Reshape in ONNX Softmax impl (#2686)
authorTomasz Dołbniak <tomasz.dolbniak@intel.com>
Fri, 16 Oct 2020 09:30:20 +0000 (11:30 +0200)
committerGitHub <noreply@github.com>
Fri, 16 Oct 2020 09:30:20 +0000 (11:30 +0200)
ngraph/frontend/onnx_import/src/op/softmax.cpp

index b0c4fe4..87c7e51 100644 (file)
@@ -33,16 +33,11 @@ namespace ngraph
                 const auto coerced_data = ngraph::builder::opset1::flatten(data, axis);
 
                 const auto axis_1 = default_opset::Constant::create(element::i64, Shape{1}, {1});
-                const auto max = std::make_shared<default_opset::ReduceMax>(coerced_data, axis_1);
-
-                // equivalent to numpy's max.reshape((-1,1))
-                const auto reshape_pattern =
-                    default_opset::Constant::create(element::i64, Shape{2}, {0, 1});
-                const auto reshaped_max =
-                    std::make_shared<default_opset::Reshape>(max, reshape_pattern, true);
+                const auto max =
+                    std::make_shared<default_opset::ReduceMax>(coerced_data, axis_1, true);
 
                 const auto data_minus_max =
-                    std::make_shared<default_opset::Subtract>(coerced_data, reshaped_max);
+                    std::make_shared<default_opset::Subtract>(coerced_data, max);
 
                 const auto result = std::make_shared<default_opset::Softmax>(data_minus_max, 1);
                 if (data.get_partial_shape().is_static())