Reference Implementation for RegionYolo operator (#2474)
authorGabriele Galiero Casay <gabriele.galiero.casay@intel.com>
Thu, 15 Oct 2020 20:30:12 +0000 (22:30 +0200)
committerGitHub <noreply@github.com>
Thu, 15 Oct 2020 20:30:12 +0000 (22:30 +0200)
16 files changed:
inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/region_yolo.cpp [new file with mode: 0644]
inference-engine/tests/functional/plugin/shared/include/single_layer_tests/region_yolo.hpp [new file with mode: 0644]
inference-engine/tests/functional/plugin/shared/src/single_layer_tests/region_yolo.cpp [new file with mode: 0644]
ngraph/core/include/ngraph/op/region_yolo.hpp
ngraph/core/reference/include/ngraph/runtime/reference/region_yolo.hpp [new file with mode: 0644]
ngraph/core/src/op/region_yolo.cpp
ngraph/test/CMakeLists.txt
ngraph/test/attributes.cpp
ngraph/test/backend/region_yolo.in.cpp [new file with mode: 0644]
ngraph/test/files/region_in_yolov2_caffe.data [new file with mode: 0644]
ngraph/test/files/region_in_yolov3_mxnet.data [new file with mode: 0644]
ngraph/test/files/region_out_yolov2_caffe.data [new file with mode: 0644]
ngraph/test/files/region_out_yolov3_mxnet.data [new file with mode: 0644]
ngraph/test/runtime/ie/unit_test.manifest
ngraph/test/runtime/interpreter/int_executable.hpp
ngraph/test/runtime/interpreter/opset_int_tbl.hpp

diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/region_yolo.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/region_yolo.cpp
new file mode 100644 (file)
index 0000000..eb2e280
--- /dev/null
@@ -0,0 +1,85 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <vector>
+
+#include "single_layer_tests/region_yolo.hpp"
+#include "common_test_utils/test_constants.hpp"
+
+using namespace LayerTestsDefinitions;
+
+const std::vector<ngraph::Shape> inShapes_caffe = {
+    {1, 125, 13, 13}
+};
+
+const std::vector<ngraph::Shape> inShapes_mxnet = {
+    {1, 75, 52, 52},
+    {1, 75, 32, 32},
+    {1, 75, 26, 26},
+    {1, 75, 16, 16},
+    {1, 75, 13, 13},
+    {1, 75, 8, 8}
+};
+
+const std::vector<ngraph::Shape> inShapes_v3 = {
+    {1, 255, 52, 52},
+    {1, 255, 26, 26},
+    {1, 255, 13, 13}
+};
+
+const std::vector<std::vector<int64_t>> masks = {
+    {0, 1, 2},
+    {3, 4, 5},
+    {6, 7, 8}
+};
+
+const std::vector<bool> do_softmax = {true, false};
+const std::vector<size_t> classes = {80, 20};
+const std::vector<size_t> num_regions = {5, 9};
+const size_t coords = 4;
+const int start_axis = 1;
+const int end_axis = 3;
+
+const auto testCase_yolov3 = ::testing::Combine(
+    ::testing::ValuesIn(inShapes_v3),
+    ::testing::Values(classes[0]),
+    ::testing::Values(coords),
+    ::testing::Values(num_regions[1]),
+    ::testing::Values(do_softmax[1]),
+    ::testing::Values(masks[2]),
+    ::testing::Values(start_axis),
+    ::testing::Values(end_axis),
+    ::testing::Values(InferenceEngine::Precision::FP32),
+    ::testing::Values(CommonTestUtils::DEVICE_CPU)
+);
+
+const auto testCase_yolov3_mxnet = ::testing::Combine(
+    ::testing::ValuesIn(inShapes_mxnet),
+    ::testing::Values(classes[1]),
+    ::testing::Values(coords),
+    ::testing::Values(num_regions[1]),
+    ::testing::Values(do_softmax[1]),
+    ::testing::Values(masks[1]),
+    ::testing::Values(start_axis),
+    ::testing::Values(end_axis),
+    ::testing::Values(InferenceEngine::Precision::FP32),
+    ::testing::Values(CommonTestUtils::DEVICE_CPU)
+);
+
+const auto testCase_yolov2_caffe = ::testing::Combine(
+    ::testing::ValuesIn(inShapes_caffe),
+    ::testing::Values(classes[1]),
+    ::testing::Values(coords),
+    ::testing::Values(num_regions[0]),
+    ::testing::Values(do_softmax[0]),
+    ::testing::Values(masks[0]),
+    ::testing::Values(start_axis),
+    ::testing::Values(end_axis),
+    ::testing::Values(InferenceEngine::Precision::FP32),
+    ::testing::Values(CommonTestUtils::DEVICE_CPU)
+);
+
+INSTANTIATE_TEST_CASE_P(smoke_TestsRegionYolov3, RegionYoloLayerTest, testCase_yolov3, RegionYoloLayerTest::getTestCaseName);
+INSTANTIATE_TEST_CASE_P(smoke_TestsRegionYoloMxnet, RegionYoloLayerTest, testCase_yolov3_mxnet, RegionYoloLayerTest::getTestCaseName);
+INSTANTIATE_TEST_CASE_P(smoke_TestsRegionYoloCaffe, RegionYoloLayerTest, testCase_yolov2_caffe, RegionYoloLayerTest::getTestCaseName);
diff --git a/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/region_yolo.hpp b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/region_yolo.hpp
new file mode 100644 (file)
index 0000000..c8d74f6
--- /dev/null
@@ -0,0 +1,38 @@
+// Copyright (C) 2019 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <tuple>
+#include <string>
+#include <vector>
+
+#include "functional_test_utils/layer_test_utils.hpp"
+#include "ngraph_functions/builders.hpp"
+#include "ngraph_functions/utils/ngraph_helpers.hpp"
+
+namespace LayerTestsDefinitions {
+
+using regionYoloParamsTuple = std::tuple<
+        ngraph::Shape,                  // Input Shape
+        size_t,                         // classes
+        size_t,                         // coordinates
+        size_t,                         // num regions
+        bool,                           // do softmax
+        std::vector<int64_t>,           // mask
+        int,                            // start axis
+        int,                            // end axis
+        InferenceEngine::Precision,     // Network precision
+        std::string>;                   // Device name
+
+class RegionYoloLayerTest : public testing::WithParamInterface<regionYoloParamsTuple>,
+                            virtual public LayerTestsUtils::LayerTestsCommon {
+public:
+    static std::string getTestCaseName(const testing::TestParamInfo<regionYoloParamsTuple> &obj);
+
+protected:
+    void SetUp() override;
+};
+
+} // namespace LayerTestsDefinitions
\ No newline at end of file
diff --git a/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/region_yolo.cpp b/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/region_yolo.cpp
new file mode 100644 (file)
index 0000000..9689094
--- /dev/null
@@ -0,0 +1,63 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "ie_core.hpp"
+
+#include "common_test_utils/common_utils.hpp"
+#include "functional_test_utils/blob_utils.hpp"
+#include "functional_test_utils/precision_utils.hpp"
+#include "functional_test_utils/plugin_cache.hpp"
+#include "functional_test_utils/skip_tests_config.hpp"
+
+#include "single_layer_tests/region_yolo.hpp"
+
+namespace LayerTestsDefinitions {
+
+std::string RegionYoloLayerTest::getTestCaseName(const testing::TestParamInfo<regionYoloParamsTuple> &obj) {
+    ngraph::Shape inputShape;
+    size_t classes;
+    size_t coords;
+    size_t num_regions;
+    bool do_softmax;
+    std::vector<int64_t> mask;
+    int start_axis;
+    int end_axis;
+    InferenceEngine::Precision netPrecision;
+    std::string targetName;
+    std::tie(inputShape, classes, coords, num_regions, do_softmax , mask, start_axis, end_axis, netPrecision, targetName) = obj.param;
+    std::ostringstream result;
+    result << "IS=" << inputShape << "_";
+    result << "classes=" << classes << "_";
+    result << "coords=" << coords << "_";
+    result << "num=" << num_regions << "_";
+    result << "doSoftmax=" << do_softmax << "_";
+    result << "axis=" << start_axis << "_";
+    result << "endAxis=" << end_axis << "_";
+    result << "netPRC=" << netPrecision.name() << "_";
+    result << "targetDevice=" << targetName << "_";
+    return result.str();
+}
+
+void RegionYoloLayerTest::SetUp() {
+    ngraph::Shape inputShape;
+    size_t classes;
+    size_t coords;
+    size_t num_regions;
+    bool do_softmax;
+    std::vector<int64_t> mask;
+    int start_axis;
+    int end_axis;
+    InferenceEngine::Precision netPrecision;
+    std::tie(inputShape, classes, coords, num_regions, do_softmax, mask, start_axis, end_axis, netPrecision, targetDevice) = this->GetParam();
+    auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
+    auto param = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, inputShape);
+    auto region_yolo = std::make_shared<ngraph::op::v0::RegionYolo>(param, coords, classes, num_regions, do_softmax, mask, start_axis, end_axis);
+    function = std::make_shared<ngraph::Function>(std::make_shared<ngraph::opset1::Result>(region_yolo), ngraph::ParameterVector{param}, "RegionYolo");
+}
+
+TEST_P(RegionYoloLayerTest, CompareWithRefs) {
+    Run();
+};
+
+} // namespace LayerTestsDefinitions
\ No newline at end of file
index 8dfdbb8..b7d9181 100644 (file)
@@ -79,7 +79,7 @@ namespace ngraph
                 int m_axis;
                 int m_end_axis;
             };
-        }
+        } // namespace v0
         using v0::RegionYolo;
-    }
-}
+    } // namespace op
+} // namespace ngraph
diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/region_yolo.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/region_yolo.hpp
new file mode 100644 (file)
index 0000000..2ca3f32
--- /dev/null
@@ -0,0 +1,175 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+
+#include "ngraph/shape.hpp"
+
+namespace ngraph
+{
+    namespace runtime
+    {
+        namespace reference
+        {
+            static inline int entry_index(int width,
+                                          int height,
+                                          int coords,
+                                          int classes,
+                                          int outputs,
+                                          int batch,
+                                          int location,
+                                          int entry)
+            {
+                int n = location / (width * height);
+                int loc = location % (width * height);
+                return batch * outputs + n * width * height * (coords + classes + 1) +
+                       entry * width * height + loc;
+            }
+
+            template <typename T>
+            static inline T sigmoid(float x)
+            {
+                return static_cast<T>(1.f / (1.f + std::exp(-x)));
+            }
+            template <typename T>
+            static inline void softmax_generic(
+                const T* src_data, T* dst_data, int batches, int channels, int height, int width)
+            {
+                const int area = height * width;
+                for (unsigned int batch_idx = 0; batch_idx < batches; batch_idx++)
+                {
+                    const int offset = batch_idx * channels * area;
+                    for (unsigned int i = 0; i < height * width; i++)
+                    {
+                        T max = src_data[batch_idx * channels * area + i];
+                        for (unsigned int channel_idx = 0; channel_idx < channels; channel_idx++)
+                        {
+                            T val = src_data[offset + channel_idx * area + i];
+                            max = std::max(max, val);
+                        }
+
+                        T sum = 0;
+                        for (unsigned int channel_idx = 0; channel_idx < channels; channel_idx++)
+                        {
+                            dst_data[offset + channel_idx * area + i] =
+                                std::exp(src_data[offset + channel_idx * area + i] - max);
+                            sum += dst_data[offset + channel_idx * area + i];
+                        }
+
+                        for (unsigned int channel_idx = 0; channel_idx < channels; channel_idx++)
+                        {
+                            dst_data[offset + channel_idx * area + i] /= sum;
+                        }
+                    }
+                }
+            }
+
+            template <typename T>
+            void region_yolo(const T* input,
+                             T* output,
+                             const Shape& input_shape,
+                             const int coords,
+                             const int classes,
+                             const int regions,
+                             const bool do_softmax,
+                             const std::vector<int64_t>& mask)
+            {
+                NGRAPH_CHECK(input_shape.size() == 4);
+
+                const int batches = input_shape[0];
+                const int channels = input_shape[1];
+                const int height = input_shape[2];
+                const int width = input_shape[3];
+
+                const auto mask_size = mask.size();
+
+                std::copy(input, input + shape_size(input_shape), output);
+
+                int num_regions = 0;
+                int end_index = 0;
+
+                if (do_softmax)
+                {
+                    // Region layer (Yolo v2)
+                    num_regions = regions;
+                    end_index = width * height;
+                }
+                else
+                {
+                    // Yolo layer (Yolo v3)
+                    num_regions = mask_size;
+                    end_index = width * height * (classes + 1);
+                }
+
+                const int inputs_size = width * height * num_regions * (classes + coords + 1);
+
+                for (unsigned int batch_idx = 0; batch_idx < batches; batch_idx++)
+                {
+                    for (unsigned int n = 0; n < num_regions; n++)
+                    {
+                        int index = entry_index(width,
+                                                height,
+                                                coords,
+                                                classes,
+                                                inputs_size,
+                                                batch_idx,
+                                                n * width * height,
+                                                0);
+                        std::transform(output + index,
+                                       output + index + 2 * width * height,
+                                       output + index,
+                                       [](T elem) { return sigmoid<T>(elem); });
+
+                        index = entry_index(width,
+                                            height,
+                                            coords,
+                                            classes,
+                                            inputs_size,
+                                            batch_idx,
+                                            n * width * height,
+                                            coords);
+                        std::transform(output + index,
+                                       output + index + end_index,
+                                       output + index,
+                                       [](T elem) { return sigmoid<T>(elem); });
+                    }
+                }
+
+                if (do_softmax)
+                {
+                    int index =
+                        entry_index(width, height, coords, classes, inputs_size, 0, 0, coords + 1);
+                    int batch_offset = inputs_size / regions;
+                    for (unsigned int batch_idx = 0; batch_idx < batches * regions; batch_idx++)
+                    {
+                        softmax_generic<T>(input + index + batch_idx * batch_offset,
+                                           output + index + batch_idx * batch_offset,
+                                           1,
+                                           classes,
+                                           height,
+                                           width);
+                    }
+                }
+            }
+
+        } // namespace reference
+
+    } // namespace runtime
+
+} // namespace ngraph
\ No newline at end of file
index f260ace..4eed7f5 100644 (file)
@@ -60,6 +60,12 @@ bool ngraph::op::v0::RegionYolo::visit_attributes(AttributeVisitor& visitor)
 void op::RegionYolo::validate_and_infer_types()
 {
     auto input_et = get_input_element_type(0);
+
+    NODE_VALIDATION_CHECK(this,
+                          input_et.is_real(),
+                          "Type of input is expected to be a floating point type. Got: ",
+                          input_et);
+
     if (get_input_partial_shape(0).is_static())
     {
         Shape input_shape = get_input_partial_shape(0).to_shape();
index 6f46f14..6e3a9f5 100644 (file)
@@ -325,6 +325,7 @@ set(MULTI_TEST_SRC
     backend/reduce_min.in.cpp
     backend/reduce_prod.in.cpp
     backend/reduce_sum.in.cpp
+    backend/region_yolo.in.cpp
     backend/relu.in.cpp
     backend/reorg_yolo.in.cpp
     backend/replace_slice.in.cpp
index 322c860..64a5a60 100644 (file)
@@ -787,7 +787,7 @@ TEST(attributes, reduce_sum_op)
 TEST(attributes, region_yolo_op)
 {
     FactoryRegistry<Node>::get().register_factory<opset1::RegionYolo>();
-    auto data = make_shared<op::Parameter>(element::i64, Shape{1, 255, 26, 26});
+    auto data = make_shared<op::Parameter>(element::f32, Shape{1, 255, 26, 26});
 
     size_t num_coords = 4;
     size_t num_classes = 1;
diff --git a/ngraph/test/backend/region_yolo.in.cpp b/ngraph/test/backend/region_yolo.in.cpp
new file mode 100644 (file)
index 0000000..8d520c4
--- /dev/null
@@ -0,0 +1,86 @@
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include <fstream>
+
+#include "gtest/gtest.h"
+#include "ngraph/ngraph.hpp"
+#include "util/engine/test_engines.hpp"
+#include "util/test_case.hpp"
+#include "util/test_control.hpp"
+
+NGRAPH_SUPPRESS_DEPRECATED_START
+
+using namespace std;
+using namespace ngraph;
+
+static string s_manifest = "${MANIFEST}";
+using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});
+
+NGRAPH_TEST(${BACKEND_NAME}, region_yolo_v2_caffe)
+{
+    const size_t num = 5;
+    const size_t coords = 4;
+    const size_t classes = 20;
+    const size_t batch = 1;
+    const size_t channels = 125;
+    const size_t width = 13;
+    const size_t height = 13;
+    const size_t count = width * height * channels;
+    const std::vector<int64_t> mask{0, 1, 2};
+
+    Shape input_shape{batch, channels, height, width};
+    Shape output_shape{batch, channels * height * width};
+
+    auto A = make_shared<op::Parameter>(element::f32, input_shape);
+    auto R = make_shared<op::v0::RegionYolo>(A, coords, classes, num, true, mask, 1, 3);
+    auto f = make_shared<Function>(R, ParameterVector{A});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+
+    test_case.add_input_from_file<float>(input_shape, TEST_FILES, "region_in_yolov2_caffe.data");
+    test_case.add_expected_output_from_file<float>(
+        output_shape, TEST_FILES, "region_out_yolov2_caffe.data");
+    test_case.run_with_tolerance_as_fp(1.0e-4f);
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, region_yolo_v3_mxnet)
+{
+    const size_t num = 9;
+    const size_t coords = 4;
+    const size_t classes = 20;
+    const size_t batch = 1;
+    const size_t channels = 75;
+    const size_t width = 32;
+    const size_t height = 32;
+    const std::vector<int64_t> mask{0, 1, 2};
+
+    Shape shape{batch, channels, height, width};
+    const auto count = shape_size(shape);
+
+    const auto A = make_shared<op::Parameter>(element::f32, shape);
+    const auto R = make_shared<op::v0::RegionYolo>(A, coords, classes, num, false, mask, 1, 3);
+    const auto f = make_shared<Function>(R, ParameterVector{A});
+
+    EXPECT_EQ(R->get_output_shape(0), shape);
+
+    auto test_case = test::TestCase<TestEngine>(f);
+
+    test_case.add_input_from_file<float>(shape, TEST_FILES, "region_in_yolov3_mxnet.data");
+    test_case.add_expected_output_from_file<float>(
+        shape, TEST_FILES, "region_out_yolov3_mxnet.data");
+    test_case.run_with_tolerance_as_fp(1.0e-4f);
+}
diff --git a/ngraph/test/files/region_in_yolov2_caffe.data b/ngraph/test/files/region_in_yolov2_caffe.data
new file mode 100644 (file)
index 0000000..3111300
Binary files /dev/null and b/ngraph/test/files/region_in_yolov2_caffe.data differ
diff --git a/ngraph/test/files/region_in_yolov3_mxnet.data b/ngraph/test/files/region_in_yolov3_mxnet.data
new file mode 100644 (file)
index 0000000..7fea67d
Binary files /dev/null and b/ngraph/test/files/region_in_yolov3_mxnet.data differ
diff --git a/ngraph/test/files/region_out_yolov2_caffe.data b/ngraph/test/files/region_out_yolov2_caffe.data
new file mode 100644 (file)
index 0000000..44807ba
Binary files /dev/null and b/ngraph/test/files/region_out_yolov2_caffe.data differ
diff --git a/ngraph/test/files/region_out_yolov3_mxnet.data b/ngraph/test/files/region_out_yolov3_mxnet.data
new file mode 100644 (file)
index 0000000..b5336a7
Binary files /dev/null and b/ngraph/test/files/region_out_yolov3_mxnet.data differ
index b248835..f2ae030 100644 (file)
@@ -1466,6 +1466,8 @@ IE_GPU.matmul_2x2_2x2
 IE_GPU.matmul_2x3_3x3
 IE_GPU.matmul_3x2_3x3_transpose
 IE_GPU.matmul_3x2_2x3_transpose
+IE_GPU.region_yolo_v2_caffe
+IE_GPU.region_yolo_v3_mxnet
 
 # Unsupported collapse op with dynamic shape
 IE_GPU.builder_opset1_collapse_dyn_shape
index 0070aaa..d785188 100644 (file)
@@ -77,6 +77,7 @@
 #include "ngraph/runtime/reference/prior_box.hpp"
 #include "ngraph/runtime/reference/product.hpp"
 #include "ngraph/runtime/reference/quantize.hpp"
+#include "ngraph/runtime/reference/region_yolo.hpp"
 #include "ngraph/runtime/reference/relu.hpp"
 #include "ngraph/runtime/reference/reorg_yolo.hpp"
 #include "ngraph/runtime/reference/replace_slice.hpp"
@@ -1187,6 +1188,19 @@ protected:
 
             break;
         }
+        case OP_TYPEID::RegionYolo_v0:
+        {
+            const op::RegionYolo* region_yolo = static_cast<const op::RegionYolo*>(&node);
+            reference::region_yolo<T>(args[0]->get_data_ptr<const T>(),
+                                      out[0]->get_data_ptr<T>(),
+                                      args[0]->get_shape(),
+                                      region_yolo->get_num_coords(),
+                                      region_yolo->get_num_classes(),
+                                      region_yolo->get_num_regions(),
+                                      region_yolo->get_do_softmax(),
+                                      region_yolo->get_mask());
+            break;
+        }
         case OP_TYPEID::Relu:
         {
             size_t element_count = shape_size(node.get_output_shape(0));
index de33cda..4cfe669 100644 (file)
@@ -21,6 +21,7 @@
 #define ID_SUFFIX(NAME) NAME##_v0
 NGRAPH_OP(CTCGreedyDecoder, ngraph::op::v0)
 NGRAPH_OP(DetectionOutput, op::v0)
+NGRAPH_OP(RegionYolo, op::v0)
 NGRAPH_OP(ReorgYolo, op::v0)
 NGRAPH_OP(RNNCell, op::v0)
 #undef ID_SUFFIX