nGraph: fix TopK output shape inference (#2967)
authorTomasz Jankowski <tomasz1.jankowski@intel.com>
Fri, 13 Nov 2020 14:10:42 +0000 (15:10 +0100)
committerGitHub <noreply@github.com>
Fri, 13 Nov 2020 14:10:42 +0000 (17:10 +0300)
* nGraph: Fix TopK output shape inference

* nGraph: Correct TopK output shape inference

TopK lower bound of output shape at the axis was mistakenly calculated
basing on max_lenght instead of min_lenght.

* nGraph: Correct TopK output shape inference

* nGraph: Correct TopK type prop test

The topk_negative_axis_support type properties test was comparing
incompatible variables carrying the same value. So it was passing ok.

* nGraph: Add TopK type prop test

* nGraph: Fix code style

* nGraph: Follow review guidelines

Improve variables meaning.
Enforce rigid test pass condition.

* nGraph: Remove magic numbers

ngraph/core/src/op/topk.cpp
ngraph/test/type_prop/top_k.cpp

index 17155dc..9a47674 100644 (file)
@@ -242,7 +242,13 @@ void op::v1::TopK::validate_and_infer_types()
             auto max_k = maximum_value(input_value(1));
             if (max_k.first)
             {
-                output_shape[m_normalized_axis] &= Dimension(0, max_k.second);
+                const auto in_min = output_shape[m_normalized_axis].get_min_length();
+                const auto in_max = output_shape[m_normalized_axis].get_max_length();
+                const auto lower = std::min<Dimension::value_type>(in_min, max_k.second);
+                const auto upper = in_max < 0
+                                       ? Dimension::dynamic().get_max_length()
+                                       : std::max<Dimension::value_type>(in_max, max_k.second);
+                output_shape[m_normalized_axis] = Dimension(lower, upper);
             }
             else
             {
index bb04e2c..644b60b 100644 (file)
@@ -38,7 +38,9 @@ TYPED_TEST_P(topk_type_prop, topk_negative_axis_support)
     const auto topk = make_shared<TypeParam>(data, k, axis, "max", "value");
 
     ASSERT_EQ(topk->get_provided_axis(), axis);
-    ASSERT_EQ(topk->get_axis(), data_shape.at(1));
+    const auto expect_shape = Shape{1, 2, 2, 4};
+    ASSERT_EQ(topk->get_output_shape(0), expect_shape);
+    ASSERT_EQ(topk->get_output_shape(1), expect_shape);
 }
 
 TYPED_TEST_P(topk_type_prop, topk_negative_axis_dynamic_rank)
@@ -75,14 +77,39 @@ TYPED_TEST_P(topk_type_prop, topk_v1_partial_ouptut)
     {
         auto k = make_shared<op::Constant>(element::i32, Shape{}, 3);
         auto topk = make_shared<TypeParam>(data, k, 1, "max", "value");
+        EXPECT_EQ(topk->get_output_shape(0), Shape({2, 3}));
         EXPECT_EQ(topk->get_output_partial_shape(0), PartialShape({2, 3}));
     }
 }
 
+TYPED_TEST_P(topk_type_prop, topk_rank_static_k_unknown)
+{
+    const int64_t axis = 1;
+    const auto data_shape = Shape{1, 10, 100};
+    const auto data = make_shared<op::Parameter>(element::f32, data_shape);
+
+    {
+        const auto k = make_shared<op::Parameter>(element::i32, PartialShape({}));
+        const auto topk = make_shared<TypeParam>(data, k, axis, "max", "value");
+
+        const PartialShape fully_dynamic_axis_shape{1, Dimension::dynamic(), 100};
+        EXPECT_EQ(topk->get_output_partial_shape(0), fully_dynamic_axis_shape);
+    }
+    {
+        const auto k = make_shared<op::v0::Constant>(element::i64, Shape{}, 5);
+        const auto convert_k = make_shared<op::v0::Convert>(k, element::i32);
+        const auto topk = make_shared<TypeParam>(data, convert_k, axis, "max", "value");
+
+        const PartialShape ranged_dynamic_axis_shape{1, Dimension{5, 10}, 100};
+        EXPECT_EQ(topk->get_output_partial_shape(0), ranged_dynamic_axis_shape);
+    }
+}
+
 REGISTER_TYPED_TEST_CASE_P(topk_type_prop,
                            topk_negative_axis_support,
                            topk_negative_axis_dynamic_rank,
-                           topk_v1_partial_ouptut);
+                           topk_v1_partial_ouptut,
+                           topk_rank_static_k_unknown);
 
 typedef ::testing::Types<op::v1::TopK, op::v3::TopK> TopKTypes;
 INSTANTIATE_TYPED_TEST_CASE_P(type_prop, topk_type_prop, TopKTypes, );