-// Copyright (C) 2018 Intel Corporation
+// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <map>
#include <set>
#include <details/ie_exception.hpp>
+#include <shape_infer/const_infer/ie_const_infer_holder.hpp>
#include "shape_infer/ie_reshape_launcher.hpp"
#include "shape_infer/ie_reshape_io_controllers.hpp"
+#include "ie_reshape_launcher.hpp"
+
+#include "built-in/ie_tensor_iterator_shape_infer.hpp"
using namespace InferenceEngine;
using namespace ShapeInfer;
}
ReshapeLauncher::ReshapeLauncher(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl,
- const DefaultInitializer::Ptr& initializer) : _layer(layer), _impl(impl) {
+ const DefaultInitializer::Ptr& initializer) : _layer(layer), _reshapeImpl(impl) {
initializer->check(layer, impl);
+ ConstInferHolder holder;
+ if (layer) _inferImpl = holder.getConstInferImpl(layer->type);
try {
_iController = initializer->createInputController(layer);
_oController = initializer->createOutputController(layer);
_iController->setShapeByName(shape, dataName);
}
+void ReshapeLauncher::setBlobByName(const Blob::CPtr& blob, const std::string& dataName) {
+ _iController->setBlobByName(blob, dataName);
+}
+
+SizeVector ReshapeLauncher::getShapeByName(const std::string& dataName) {
+ return _oController->getShapeByName(dataName);
+}
+
void ReshapeLauncher::reshape(const std::set<ReshapeLauncher::Ptr>& launchers) {
ResponseDesc resp;
std::vector<SizeVector> outShapes;
- auto sts = _impl->inferShapes(_iController->getShapes(true), _layer->params, _layer->blobs, outShapes, &resp);
+
+ // TODO: TensorIterator strongly required original layer instance because body is not presented
+ // in params map. Original subnetwork body is required for internal shape infer
+ TensorIteratorShapeProp *TI_shaper = dynamic_cast<TensorIteratorShapeProp*>(_reshapeImpl.get());
+ if (TI_shaper) {
+ TI_shaper->setOriginalLayer(_layer);
+ }
+
+ // try to call new API with input blobs
+ auto sts = _reshapeImpl->inferShapes(_iController->getBlobs(true), _layer->params, _layer->blobs, outShapes, &resp);
+ // in case of old custom shape infer function call old API
+ if (sts == NOT_IMPLEMENTED) {
+ sts = _reshapeImpl->inferShapes(_iController->getShapes(true), _layer->params, _layer->blobs, outShapes,
+ &resp);
+ }
_oController->setShapes(outShapes);
if (sts != OK)
- THROW_IE_EXCEPTION << resp.msg;
+ THROW_IE_EXCEPTION <<
+ "Failed to infer shapes for " + _layer->type + " layer (" + _layer->name + ") with error: " +
+ resp.msg;
_oController->propagateShapes(launchers);
}
checkLayer(layer);
_iController->applyChanges();
_oController->applyChanges();
+
+ // TODO: Need to finalize result of internal body shape infer and apply
+ // new shapes to body subnetwork
+ TensorIteratorShapeProp *TI_shaper = dynamic_cast<TensorIteratorShapeProp*>(_reshapeImpl.get());
+ if (TI_shaper) TI_shaper->apply();
+}
+
+void ReshapeLauncher::constInfer(const std::set<ReshapeLauncher::Ptr>& launchers) {
+ if (_iController->isDataAvailable() || _layer->type == "Const" || _layer->type == "Shape") {
+ auto outBlobs = _oController->createBlobs();
+ _oController->setBlobs(outBlobs);
+ if (!_inferImpl)
+ THROW_IE_EXCEPTION << "Failed to find reference implementation for `"
+ + _layer->name + "` Layer with `" + _layer->type + "` Type on constant propagation";
+ _inferImpl->infer(_iController->getBlobs(false), _layer->params, _layer->blobs, outBlobs);
+ _oController->propagateBlobs(launchers);
+ }
}
void ReshapeLauncher::reset() {
}
void ReshapeLauncher::setShapeInferImpl(const IShapeInferImpl::Ptr& impl) {
- _impl = impl;
+ _reshapeImpl = impl;
}
const CNNLayer* ReshapeLauncher::getLayer() const {
_oController->setShapeByName(shape, dataName);
}
+void OutputOnlyReshapeLauncher::setBlobByName(const Blob::CPtr& blob, const std::string& dataName) {
+ _oController->setBlobByName(blob, dataName);
+}
+
void OutputOnlyReshapeLauncher::setIRShapeByName(const std::string& dataName) {
SizeVector foundShape = _oController->getIRShapeByName(dataName);
_oController->setShapeByName(foundShape, dataName);
_oController->reset();
}
+void OutputOnlyReshapeLauncher::constInfer(const std::set<ReshapeLauncher::Ptr>& launchers) {
+ if (_layer->type == "Const") {
+ auto outBlobs = _oController->createBlobs();
+ _oController->setBlobs(outBlobs);
+ if (!_inferImpl)
+ THROW_IE_EXCEPTION << "Failed to find reference implementation for `"
+ + _layer->name + "` Layer with `" + _layer->type + "` Type on constant propagation";
+ _inferImpl->infer({}, _layer->params, _layer->blobs, outBlobs);
+ auto shapes = _oController->getShapes(true);
+ for (int i = 0; i < outBlobs.size(); i++) {
+ outBlobs[i]->Reshape(SizeVector(shapes[i].rbegin(), shapes[i].rend()), TensorDesc::getLayoutByDims(shapes[i]));
+ }
+ _oController->setBlobs(outBlobs);
+ _oController->propagateBlobs(launchers);
+ }
+}
+
void InputInitializer::check(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl) {
OutputOnlyInitializer::check(layer, impl);
std::string errorBase = "Failed to init reshape launcher: layer type (`" + layer->type + "`) is not";
: ReshapeLauncher(layer, impl, std::make_shared<OutMemoryInitializer>()) {
}
-void OutMemoryReshapeLauncher::reshape(const std::set<ReshapeLauncher::Ptr>& launchers) {
-}
-
void OutMemoryReshapeLauncher::applyChanges(CNNLayer* layer) {
checkLayer(layer);
_iController->applyChanges();