auto max_k = maximum_value(input_value(1));
if (max_k.first)
{
- output_shape[m_normalized_axis] &= Dimension(0, max_k.second);
+ const auto in_min = output_shape[m_normalized_axis].get_min_length();
+ const auto in_max = output_shape[m_normalized_axis].get_max_length();
+ const auto lower = std::min<Dimension::value_type>(in_min, max_k.second);
+ const auto upper = in_max < 0
+ ? Dimension::dynamic().get_max_length()
+ : std::max<Dimension::value_type>(in_max, max_k.second);
+ output_shape[m_normalized_axis] = Dimension(lower, upper);
}
else
{
const auto topk = make_shared<TypeParam>(data, k, axis, "max", "value");
ASSERT_EQ(topk->get_provided_axis(), axis);
- ASSERT_EQ(topk->get_axis(), data_shape.at(1));
+ const auto expect_shape = Shape{1, 2, 2, 4};
+ ASSERT_EQ(topk->get_output_shape(0), expect_shape);
+ ASSERT_EQ(topk->get_output_shape(1), expect_shape);
}
TYPED_TEST_P(topk_type_prop, topk_negative_axis_dynamic_rank)
{
auto k = make_shared<op::Constant>(element::i32, Shape{}, 3);
auto topk = make_shared<TypeParam>(data, k, 1, "max", "value");
+ EXPECT_EQ(topk->get_output_shape(0), Shape({2, 3}));
EXPECT_EQ(topk->get_output_partial_shape(0), PartialShape({2, 3}));
}
}
+TYPED_TEST_P(topk_type_prop, topk_rank_static_k_unknown)
+{
+ const int64_t axis = 1;
+ const auto data_shape = Shape{1, 10, 100};
+ const auto data = make_shared<op::Parameter>(element::f32, data_shape);
+
+ {
+ const auto k = make_shared<op::Parameter>(element::i32, PartialShape({}));
+ const auto topk = make_shared<TypeParam>(data, k, axis, "max", "value");
+
+ const PartialShape fully_dynamic_axis_shape{1, Dimension::dynamic(), 100};
+ EXPECT_EQ(topk->get_output_partial_shape(0), fully_dynamic_axis_shape);
+ }
+ {
+ const auto k = make_shared<op::v0::Constant>(element::i64, Shape{}, 5);
+ const auto convert_k = make_shared<op::v0::Convert>(k, element::i32);
+ const auto topk = make_shared<TypeParam>(data, convert_k, axis, "max", "value");
+
+ const PartialShape ranged_dynamic_axis_shape{1, Dimension{5, 10}, 100};
+ EXPECT_EQ(topk->get_output_partial_shape(0), ranged_dynamic_axis_shape);
+ }
+}
+
REGISTER_TYPED_TEST_CASE_P(topk_type_prop,
topk_negative_axis_support,
topk_negative_axis_dynamic_rank,
- topk_v1_partial_ouptut);
+ topk_v1_partial_ouptut,
+ topk_rank_static_k_unknown);
typedef ::testing::Types<op::v1::TopK, op::v3::TopK> TopKTypes;
INSTANTIATE_TYPED_TEST_CASE_P(type_prop, topk_type_prop, TopKTypes, );