Added ngraph::op::v4::Interpolation
authorLiubov Batanina <piccione-mail@yandex.ru>
Fri, 12 Mar 2021 09:00:59 +0000 (12:00 +0300)
committerLiubov Batanina <piccione-mail@yandex.ru>
Fri, 12 Mar 2021 09:00:59 +0000 (12:00 +0300)
modules/dnn/src/layers/resize_layer.cpp

index a41e20a666e4d697ec2c188ac507347fc48c480a..2527ad1190798821945ec1090f65e1c50f021bde 100644 (file)
@@ -257,6 +257,7 @@ public:
     {
         auto& ieInpNode = nodes[0].dynamicCast<InfEngineNgraphNode>()->node;
 
+#if INF_ENGINE_VER_MAJOR_LE(INF_ENGINE_RELEASE_2021_2)
         ngraph::op::InterpolateAttrs attrs;
         attrs.pads_begin.push_back(0);
         attrs.pads_end.push_back(0);
@@ -275,6 +276,36 @@ public:
         std::vector<int64_t> shape = {outHeight, outWidth};
         auto out_shape = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{2}, shape.data());
         auto interp = std::make_shared<ngraph::op::Interpolate>(ieInpNode, out_shape, attrs);
+#else
+        ngraph::op::v4::Interpolate::InterpolateAttrs attrs;
+
+        if (interpolation == "nearest") {
+            attrs.mode = ngraph::op::v4::Interpolate::InterpolateMode::nearest;
+            attrs.coordinate_transformation_mode = ngraph::op::v4::Interpolate::CoordinateTransformMode::half_pixel;
+        } else if (interpolation == "bilinear") {
+            attrs.mode = ngraph::op::v4::Interpolate::InterpolateMode::linear_onnx;
+            attrs.coordinate_transformation_mode = ngraph::op::v4::Interpolate::CoordinateTransformMode::asymmetric;
+        } else {
+            CV_Error(Error::StsNotImplemented, "Unsupported interpolation: " + interpolation);
+        }
+        attrs.shape_calculation_mode = ngraph::op::v4::Interpolate::ShapeCalcMode::sizes;
+
+        if (alignCorners) {
+            attrs.coordinate_transformation_mode = ngraph::op::v4::Interpolate::CoordinateTransformMode::align_corners;
+        }
+
+        attrs.nearest_mode = ngraph::op::v4::Interpolate::NearestMode::round_prefer_floor;
+
+        std::vector<int64_t> shape = {outHeight, outWidth};
+        auto out_shape = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{2}, shape.data());
+
+        auto& input_shape = ieInpNode->get_shape();
+        std::vector<float> scales = {static_cast<float>(outHeight) / input_shape[2], static_cast<float>(outHeight) / input_shape[2]};
+        auto scales_shape = std::make_shared<ngraph::op::Constant>(ngraph::element::f32, ngraph::Shape{2}, scales.data());
+
+        auto axes = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{2}, std::vector<int64_t>{2, 3});
+        auto interp = std::make_shared<ngraph::op::v4::Interpolate>(ieInpNode, out_shape, scales_shape, axes, attrs);
+#endif
         return Ptr<BackendNode>(new InfEngineNgraphNode(interp));
     }
 #endif  // HAVE_DNN_NGRAPH