From 2966910dacc2ada8cc902c9430b37d5dbbfa3061 Mon Sep 17 00:00:00 2001
From: Tomasz Jankowski
Date: Fri, 13 Nov 2020 15:10:42 +0100
Subject: [PATCH] nGraph: fix TopK output shape inference (#2967)

* nGraph: Fix TopK output shape inference

* nGraph: Correct TopK output shape inference

The TopK lower bound of the output shape at the axis was mistakenly
calculated based on max_length instead of min_length.

* nGraph: Correct TopK output shape inference

* nGraph: Correct TopK type prop test

The topk_negative_axis_support type prop test was comparing incompatible
variables that happened to carry the same value, so it passed anyway.

* nGraph: Add TopK type prop test

* nGraph: Fix code style

* nGraph: Follow review guidelines

Improve variable naming.
Enforce a strict test pass condition.

* nGraph: Remove magic numbers
---
 ngraph/core/src/op/topk.cpp     |  8 +++++++-
 ngraph/test/type_prop/top_k.cpp | 31 +++++++++++++++++++++++++++++--
 2 files changed, 36 insertions(+), 3 deletions(-)

diff --git a/ngraph/core/src/op/topk.cpp b/ngraph/core/src/op/topk.cpp
index 17155dc..9a47674 100644
--- a/ngraph/core/src/op/topk.cpp
+++ b/ngraph/core/src/op/topk.cpp
@@ -242,7 +242,13 @@ void op::v1::TopK::validate_and_infer_types()
         auto max_k = maximum_value(input_value(1));
         if (max_k.first)
         {
-            output_shape[m_normalized_axis] &= Dimension(0, max_k.second);
+            const auto in_min = output_shape[m_normalized_axis].get_min_length();
+            const auto in_max = output_shape[m_normalized_axis].get_max_length();
+            const auto lower = std::min<Dimension::value_type>(in_min, max_k.second);
+            const auto upper = in_max < 0
+                                   ? Dimension::dynamic().get_max_length()
+                                   : std::max<Dimension::value_type>(in_max, max_k.second);
+            output_shape[m_normalized_axis] = Dimension(lower, upper);
         }
         else
         {
diff --git a/ngraph/test/type_prop/top_k.cpp b/ngraph/test/type_prop/top_k.cpp
index bb04e2c..644b60b 100644
--- a/ngraph/test/type_prop/top_k.cpp
+++ b/ngraph/test/type_prop/top_k.cpp
@@ -38,7 +38,9 @@ TYPED_TEST_P(topk_type_prop, topk_negative_axis_support)
     const auto topk = make_shared<TypeParam>(data, k, axis, "max", "value");
 
     ASSERT_EQ(topk->get_provided_axis(), axis);
-    ASSERT_EQ(topk->get_axis(), data_shape.at(1));
+    const auto expect_shape = Shape{1, 2, 2, 4};
+    ASSERT_EQ(topk->get_output_shape(0), expect_shape);
+    ASSERT_EQ(topk->get_output_shape(1), expect_shape);
 }
 
 TYPED_TEST_P(topk_type_prop, topk_negative_axis_dynamic_rank)
@@ -75,14 +77,39 @@ TYPED_TEST_P(topk_type_prop, topk_v1_partial_ouptut)
     {
         auto k = make_shared<op::Constant>(element::i32, Shape{}, 3);
         auto topk = make_shared<TypeParam>(data, k, 1, "max", "value");
+        EXPECT_EQ(topk->get_output_shape(0), Shape({2, 3}));
         EXPECT_EQ(topk->get_output_partial_shape(0), PartialShape({2, 3}));
     }
 }
 
+TYPED_TEST_P(topk_type_prop, topk_rank_static_k_unknown)
+{
+    const int64_t axis = 1;
+    const auto data_shape = Shape{1, 10, 100};
+    const auto data = make_shared<op::Parameter>(element::f32, data_shape);
+
+    {
+        const auto k = make_shared<op::Parameter>(element::i32, PartialShape({}));
+        const auto topk = make_shared<TypeParam>(data, k, axis, "max", "value");
+
+        const PartialShape fully_dynamic_axis_shape{1, Dimension::dynamic(), 100};
+        EXPECT_EQ(topk->get_output_partial_shape(0), fully_dynamic_axis_shape);
+    }
+    {
+        const auto k = make_shared<op::Constant>(element::i64, Shape{}, 5);
+        const auto convert_k = make_shared<op::Convert>(k, element::i32);
+        const auto topk = make_shared<TypeParam>(data, convert_k, axis, "max", "value");
+
+        const PartialShape ranged_dynamic_axis_shape{1, Dimension{5, 10}, 100};
+        EXPECT_EQ(topk->get_output_partial_shape(0), ranged_dynamic_axis_shape);
+    }
+}
+
 REGISTER_TYPED_TEST_CASE_P(topk_type_prop,
                            topk_negative_axis_support,
                            topk_negative_axis_dynamic_rank,
-                           topk_v1_partial_ouptut);
+                           topk_v1_partial_ouptut,
+                           topk_rank_static_k_unknown);
 
 typedef ::testing::Types<op::v1::TopK, op::v3::TopK> TopKTypes;
 INSTANTIATE_TYPED_TEST_CASE_P(type_prop, topk_type_prop, TopKTypes, );
-- 
2.7.4
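For readers who want the corrected bound arithmetic in isolation, the standalone sketch below mirrors the change to validate_and_infer_types. It is only an illustration under stated assumptions, not part of the patch: AxisBounds and infer_topk_axis are hypothetical names, and a negative max_len stands in for nGraph's unbounded (dynamic) dimension.

// Minimal, self-contained sketch (not part of the patch) of the corrected
// TopK axis-bound computation.
#include <algorithm>
#include <cstdint>
#include <iostream>

struct AxisBounds
{
    int64_t min_len; // lower bound of the dimension
    int64_t max_len; // upper bound; negative means unbounded (dynamic)
};

// Mirrors the fixed logic: the lower bound follows the *minimum* input
// length (clamped by the largest possible k), and an unbounded input axis
// stays unbounded instead of being capped at k.
AxisBounds infer_topk_axis(const AxisBounds& in, int64_t max_k)
{
    const int64_t lower = std::min(in.min_len, max_k);
    const int64_t upper = in.max_len < 0 ? -1 : std::max(in.max_len, max_k);
    return {lower, upper};
}

int main()
{
    // Same case as the new topk_rank_static_k_unknown test: a static axis of
    // length 10 with k known to be at most 5 yields bounds {5, 10}.
    const AxisBounds out = infer_topk_axis({10, 10}, 5);
    std::cout << out.min_len << ".." << out.max_len << "\n"; // prints "5..10"
    return 0;
}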