Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / shape_infer / ie_reshape_launcher.cpp
index c2651a0..d64c3bb 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright (C) 2018 Intel Corporation
+// Copyright (C) 2018-2019 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
 
 #include <map>
 #include <set>
 #include <details/ie_exception.hpp>
+#include <shape_infer/const_infer/ie_const_infer_holder.hpp>
 #include "shape_infer/ie_reshape_launcher.hpp"
 #include "shape_infer/ie_reshape_io_controllers.hpp"
+#include "ie_reshape_launcher.hpp"
+
+#include "built-in/ie_tensor_iterator_shape_infer.hpp"
 
 using namespace InferenceEngine;
 using namespace ShapeInfer;
@@ -35,8 +39,10 @@ OutputController* DefaultInitializer::createOutputController(const CNNLayer* lay
 }
 
 ReshapeLauncher::ReshapeLauncher(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl,
-                                 const DefaultInitializer::Ptr& initializer) : _layer(layer), _impl(impl) {
+                                 const DefaultInitializer::Ptr& initializer) : _layer(layer), _reshapeImpl(impl) {
     initializer->check(layer, impl);
+    ConstInferHolder holder;
+    if (layer) _inferImpl = holder.getConstInferImpl(layer->type);
     try {
         _iController = initializer->createInputController(layer);
         _oController = initializer->createOutputController(layer);
@@ -59,13 +65,37 @@ void ReshapeLauncher::setShapeByName(const SizeVector& shape, const std::string&
     _iController->setShapeByName(shape, dataName);
 }
 
+void ReshapeLauncher::setBlobByName(const Blob::CPtr& blob, const std::string& dataName) {
+    _iController->setBlobByName(blob, dataName);
+}
+
+SizeVector ReshapeLauncher::getShapeByName(const std::string& dataName) {
+    return _oController->getShapeByName(dataName);
+}
+
 void ReshapeLauncher::reshape(const std::set<ReshapeLauncher::Ptr>& launchers) {
     ResponseDesc resp;
     std::vector<SizeVector> outShapes;
-    auto sts = _impl->inferShapes(_iController->getShapes(true), _layer->params, _layer->blobs, outShapes, &resp);
+
+    // TODO: TensorIterator strongly required original layer instance because body is not presented
+    //       in params map. Original subnetwork body is required for internal shape infer
+    TensorIteratorShapeProp *TI_shaper = dynamic_cast<TensorIteratorShapeProp*>(_reshapeImpl.get());
+    if (TI_shaper) {
+        TI_shaper->setOriginalLayer(_layer);
+    }
+
+    // try to call new API with input blobs
+    auto sts = _reshapeImpl->inferShapes(_iController->getBlobs(true), _layer->params, _layer->blobs, outShapes, &resp);
+    // in case of old custom shape infer function call old API
+    if (sts == NOT_IMPLEMENTED) {
+        sts = _reshapeImpl->inferShapes(_iController->getShapes(true), _layer->params, _layer->blobs, outShapes,
+                                        &resp);
+    }
     _oController->setShapes(outShapes);
     if (sts != OK)
-        THROW_IE_EXCEPTION << resp.msg;
+        THROW_IE_EXCEPTION <<
+                           "Failed to infer shapes for " + _layer->type + " layer (" + _layer->name + ") with error: " +
+                           resp.msg;
     _oController->propagateShapes(launchers);
 }
 
@@ -73,6 +103,23 @@ void ReshapeLauncher::applyChanges(CNNLayer* layer) {
     checkLayer(layer);
     _iController->applyChanges();
     _oController->applyChanges();
+
+    // TODO: Need to finalize result of internal body shape infer and apply
+    //       new shapes to body subnetwork
+    TensorIteratorShapeProp *TI_shaper = dynamic_cast<TensorIteratorShapeProp*>(_reshapeImpl.get());
+    if (TI_shaper) TI_shaper->apply();
+}
+
+void ReshapeLauncher::constInfer(const std::set<ReshapeLauncher::Ptr>& launchers) {
+    if (_iController->isDataAvailable() || _layer->type == "Const" || _layer->type == "Shape") {
+        auto outBlobs = _oController->createBlobs();
+        _oController->setBlobs(outBlobs);
+        if (!_inferImpl)
+            THROW_IE_EXCEPTION << "Failed to find reference implementation for `"
+                                  + _layer->name + "` Layer with `" + _layer->type + "` Type on constant propagation";
+        _inferImpl->infer(_iController->getBlobs(false), _layer->params, _layer->blobs, outBlobs);
+        _oController->propagateBlobs(launchers);
+    }
 }
 
 void ReshapeLauncher::reset() {
@@ -106,7 +153,7 @@ void ReshapeLauncher::setIRShapeByName(const std::string& dataName) {
 }
 
 void ReshapeLauncher::setShapeInferImpl(const IShapeInferImpl::Ptr& impl) {
-    _impl = impl;
+    _reshapeImpl = impl;
 }
 
 const CNNLayer* ReshapeLauncher::getLayer() const {
@@ -178,6 +225,10 @@ void OutputOnlyReshapeLauncher::setShapeByName(const SizeVector& shape, const st
     _oController->setShapeByName(shape, dataName);
 }
 
+void OutputOnlyReshapeLauncher::setBlobByName(const Blob::CPtr& blob, const std::string& dataName) {
+    _oController->setBlobByName(blob, dataName);
+}
+
 void OutputOnlyReshapeLauncher::setIRShapeByName(const std::string& dataName) {
     SizeVector foundShape = _oController->getIRShapeByName(dataName);
     _oController->setShapeByName(foundShape, dataName);
@@ -192,6 +243,23 @@ void OutputOnlyReshapeLauncher::reset() {
     _oController->reset();
 }
 
+void OutputOnlyReshapeLauncher::constInfer(const std::set<ReshapeLauncher::Ptr>& launchers) {
+    if (_layer->type == "Const") {
+        auto outBlobs = _oController->createBlobs();
+        _oController->setBlobs(outBlobs);
+        if (!_inferImpl)
+            THROW_IE_EXCEPTION << "Failed to find reference implementation for `"
+                                  + _layer->name + "` Layer with `" + _layer->type + "` Type on constant propagation";
+        _inferImpl->infer({}, _layer->params, _layer->blobs, outBlobs);
+        auto shapes = _oController->getShapes(true);
+        for (int i = 0; i < outBlobs.size(); i++) {
+            outBlobs[i]->Reshape(SizeVector(shapes[i].rbegin(), shapes[i].rend()), TensorDesc::getLayoutByDims(shapes[i]));
+        }
+        _oController->setBlobs(outBlobs);
+        _oController->propagateBlobs(launchers);
+    }
+}
+
 void InputInitializer::check(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl) {
     OutputOnlyInitializer::check(layer, impl);
     std::string errorBase = "Failed to init reshape launcher: layer type (`" + layer->type + "`) is not";
@@ -263,9 +331,6 @@ OutMemoryReshapeLauncher::OutMemoryReshapeLauncher(const CNNLayer* layer, const
         : ReshapeLauncher(layer, impl, std::make_shared<OutMemoryInitializer>()) {
 }
 
-void OutMemoryReshapeLauncher::reshape(const std::set<ReshapeLauncher::Ptr>& launchers) {
-}
-
 void OutMemoryReshapeLauncher::applyChanges(CNNLayer* layer) {
     checkLayer(layer);
     _iController->applyChanges();