Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine/src/inference_engine/shape_infer/built-in/ie_interp_shape_infer.hpp
index ebca8ff..a7efae0 100644
@@ -1,4 +1,4 @@
-// Copyright (C) 2018 Intel Corporation
+// Copyright (C) 2018-2019 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
 
@@ -11,6 +11,7 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include <limits>
 
 namespace InferenceEngine {
 namespace ShapeInfer {
@@ -22,7 +23,7 @@ class InterpShapeProp : public BuiltInShapeInferImpl {
 public:
     explicit InterpShapeProp(const std::string& type) : BuiltInShapeInferImpl(type) {}
 
-    void inferShapesImpl(const std::vector<SizeVector>& inShapes,
+    void inferShapesImpl(const std::vector<Blob::CPtr>& inBlobs,
                          const std::map<std::string, std::string>& params,
                          const std::map<std::string, Blob::Ptr>& blobs,
                          std::vector<SizeVector>& outShapes) override {
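
The change above passes inferShapesImpl the input Blobs themselves rather than precomputed SizeVectors, so an implementation can read actual data from its inputs; Interp uses that below to take the target shape from an optional second input. The inShapes[0] lookups that remain in the next hunk rely on the base class deriving per-input dims from those blobs. A minimal sketch of that derivation, assuming the public Blob/TensorDesc API (getTensorDesc().getDims()); the helper name is illustrative, not part of this change:

#include <vector>
#include <ie_blob.h>

// Illustrative only: recover per-input dims from the blobs, which is how the
// BuiltInShapeInferImpl base class is assumed to populate its inShapes member.
static std::vector<InferenceEngine::SizeVector> dimsOf(
        const std::vector<InferenceEngine::Blob::CPtr>& inBlobs) {
    std::vector<InferenceEngine::SizeVector> shapes;
    shapes.reserve(inBlobs.size());
    for (const auto& blob : inBlobs) {
        shapes.push_back(blob->getTensorDesc().getDims());
    }
    return shapes;
}
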
@@ -30,60 +31,67 @@ public:
         CNNLayer cnnLayer(lp);
         cnnLayer.params = params;
         cnnLayer.type = _type;
-        validate(&cnnLayer, inShapes, params, blobs);
-        auto factor = static_cast<size_t>(cnnLayer.GetParamAsInt("factor", 0));
-        auto shrink_factor = static_cast<size_t>(cnnLayer.GetParamAsInt("shrink_factor", 0));
-        auto zoom_factor = static_cast<size_t>(cnnLayer.GetParamAsInt("zoom_factor", 0));
-        auto height = static_cast<size_t>(cnnLayer.GetParamAsInt("height", 0));
-        auto width = static_cast<size_t>(cnnLayer.GetParamAsInt("width", 0));
+        validate(&cnnLayer, inBlobs, params, blobs);
+        SizeVector outShape;
+        if (inBlobs.size() == 2) {
+            auto* buffer = inBlobs[1]->cbuffer().as<float*>();
+            if (buffer != nullptr) {
+                for (size_t i = 0; i < inBlobs[1]->size(); i++) {
+                    outShape.push_back(static_cast<size_t>(buffer[i]));
+                }
+            } else {
+                THROW_IE_EXCEPTION << "Second input must have allocated data";
+            }
+        } else {
+            auto factor = cnnLayer.GetParamAsFloat("factor", 0);
+            auto shrink_factor = cnnLayer.GetParamAsFloat("shrink_factor", 0);
+            auto zoom_factor = cnnLayer.GetParamAsFloat("zoom_factor", 0);
+            auto height = static_cast<size_t>(cnnLayer.GetParamAsInt("height", 0));
+            auto width = static_cast<size_t>(cnnLayer.GetParamAsInt("width", 0));
+
+            auto IS_ZERO = [](float value) {
+                return std::fabs(value) < std::numeric_limits<float>::epsilon();
+            };
+
+        bool noFactor = IS_ZERO(zoom_factor) && IS_ZERO(shrink_factor) && IS_ZERO(factor);
 
-        // TODO: move to validators
-        if (!zoom_factor && !shrink_factor && !factor && (!height || !width)) {
-            THROW_IE_EXCEPTION
-                    << "Can't reshape without factor, or target resolution. "
-                    << "Supported attributes: factor, shrink_factor, zoom_factor, height, width";
-        }
         size_t N, C, H, W;
-        // TODO: validate that only one input
         N = inShapes[0][0];
         C = inShapes[0][1];
         H = inShapes[0][2];
         W = inShapes[0][3];
 
+            auto SETW = [&width, &W](size_t value) {
+                if (width) {
+                    W = width;
+                } else {
+                    W = value;
+                }
+            };
 
-        auto SETW = [&width, &W](size_t value) {
-            if (width) {
-                W = width;
-            } else {
-                W = value;
-            }
-        };
+            auto SETH = [&height, &H](size_t value) {
+                if (height) {
+                    H = height;
+                } else {
+                    H = value;
+                }
+            };
 
-        auto SETH = [&height, &H](size_t value) {
-            if (height) {
-                H = height;
+            if (noFactor) {
+                SETW(width);
+                SETH(height);
             } else {
-                H = value;
-            }
-        };
-
-        if (factor) {
-            SETH(H * factor);
-            SETW(W * factor);
-        } else if (shrink_factor || zoom_factor) {
-            if (shrink_factor) {
-                SETH(H / shrink_factor);
-                SETW(W / shrink_factor);
-            }
-            if (zoom_factor) {
-                SETH(H * zoom_factor);
-                SETW(W * zoom_factor);
+                float actualFactor = factor;
+                if (!IS_ZERO(shrink_factor) || !IS_ZERO(zoom_factor)) {
+                    if (!IS_ZERO(zoom_factor)) actualFactor = zoom_factor;
+                    if (!IS_ZERO(shrink_factor)) actualFactor /= shrink_factor;
+                }
+                SETW(W * actualFactor);
+                SETH(H * actualFactor);
             }
-        } else {
-            SETW(width);
-            SETH(height);
+            outShape = {N, C, H, W};
         }
-        outShapes.push_back({N, C, H, W});
+        outShapes.push_back(outShape);
     }
 };
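
For reference, the single-input branch above boils down to: an explicit height/width attribute always wins for its dimension; otherwise H and W are scaled by factor, or by zoom_factor divided by shrink_factor when either of those is set (zoom_factor replaces factor, shrink_factor divides it). With two inputs the output shape is instead read element by element from the second blob. Below is a self-contained sketch of the factor arithmetic, handy for checking expected output shapes outside the engine; all names are illustrative and not part of the Inference Engine API:

#include <cmath>
#include <cstddef>
#include <limits>
#include <utility>

struct InterpAttrs {
    float factor = 0.f, shrink_factor = 0.f, zoom_factor = 0.f;
    std::size_t height = 0, width = 0;  // explicit target resolution, 0 = unset
};

// Mirrors the logic above: returns the output {H, W} for the given input spatial dims.
inline std::pair<std::size_t, std::size_t>
interpOutputHW(std::size_t H, std::size_t W, const InterpAttrs& a) {
    auto isZero = [](float v) { return std::fabs(v) < std::numeric_limits<float>::epsilon(); };
    bool noFactor = isZero(a.zoom_factor) && isZero(a.shrink_factor) && isZero(a.factor);
    if (noFactor) {
        // No scale factors: the explicit attributes are taken as-is (SETH/SETW above).
        return {a.height, a.width};
    }
    float actualFactor = a.factor;
    if (!isZero(a.shrink_factor) || !isZero(a.zoom_factor)) {
        if (!isZero(a.zoom_factor))   actualFactor = a.zoom_factor;
        if (!isZero(a.shrink_factor)) actualFactor /= a.shrink_factor;
    }
    // An explicit height/width still overrides the scaled value.
    std::size_t outH = a.height ? a.height : static_cast<std::size_t>(H * actualFactor);
    std::size_t outW = a.width  ? a.width  : static_cast<std::size_t>(W * actualFactor);
    return {outH, outW};
}

// Example: zoom_factor = 3 with shrink_factor = 2 scales a 10x10 input to 15x15;
// height = 60 with factor = 2 gives H = 60 and W doubled.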