const auto coerced_data = ngraph::builder::opset1::flatten(data, axis);
const auto axis_1 = default_opset::Constant::create(element::i64, Shape{1}, {1});
- const auto max = std::make_shared<default_opset::ReduceMax>(coerced_data, axis_1);
-
- // equivalent to numpy's max.reshape((-1,1))
- const auto reshape_pattern =
- default_opset::Constant::create(element::i64, Shape{2}, {0, 1});
- const auto reshaped_max =
- std::make_shared<default_opset::Reshape>(max, reshape_pattern, true);
+ const auto max =
+ std::make_shared<default_opset::ReduceMax>(coerced_data, axis_1, true);
const auto data_minus_max =
- std::make_shared<default_opset::Subtract>(coerced_data, reshaped_max);
+ std::make_shared<default_opset::Subtract>(coerced_data, max);
const auto result = std::make_shared<default_opset::Softmax>(data_minus_max, 1);
if (data.get_partial_shape().is_static())