1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
6 #include <ie_layer_validators.hpp>
12 #include <details/ie_exception.hpp>
13 #include <shape_infer/const_infer/ie_const_infer_holder.hpp>
14 #include "shape_infer/ie_reshape_launcher.hpp"
15 #include "shape_infer/ie_reshape_io_controllers.hpp"
16 #include "ie_reshape_launcher.hpp"
18 #include "built-in/ie_tensor_iterator_shape_infer.hpp"
20 using namespace InferenceEngine;
21 using namespace ShapeInfer;
23 void DefaultInitializer::check(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl) {
24 std::string errorBase = "Failed to init reshape launcher: ";
25 if (!layer) THROW_IE_EXCEPTION << errorBase + " pointer to the layer is null";
26 if (!impl) THROW_IE_EXCEPTION << errorBase + " shape infer implementation is null";
29 InputController* DefaultInitializer::createInputController(const CNNLayer* layer) {
30 std::vector<DataPtr> data;
31 for (auto const& insData : layer->insData) {
32 data.push_back(insData.lock());
34 return new InputController(data, layer->name);
37 OutputController* DefaultInitializer::createOutputController(const CNNLayer* layer) {
38 return new OutputController(layer->outData, layer->name);
// Constructs a launcher for `layer`. It validates (layer, impl) through `initializer`, looks up a
// constant-inference implementation for the layer type and stores it in `_inferImpl`, then asks
// the initializer to create the input and output controllers.
// NOTE(review): `_inferImpl` may stay empty when no implementation is registered for the type;
// presumably constInfer() checks for that (its guard line is not visible in this excerpt).
41 ReshapeLauncher::ReshapeLauncher(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl,
42 const DefaultInitializer::Ptr& initializer) : _layer(layer), _reshapeImpl(impl) {
43 initializer->check(layer, impl);
44 ConstInferHolder holder;
45 if (layer) _inferImpl = holder.getConstInferImpl(layer->type);
47 _iController = initializer->createInputController(layer);
48 _oController = initializer->createOutputController(layer);
// NOTE(review): the try/catch lines around controller creation, and any cleanup between capturing
// and rethrowing, are not visible in this excerpt. Presumably partially created controllers are
// released there before the exception is rethrown unchanged. Confirm against the full file.
50 auto exception = std::current_exception();
53 std::rethrow_exception(exception);
// Destructor. The controllers are raw pointers created with `new` by the initializer factories.
// NOTE(review): only the pointer resets are visible in this excerpt. Confirm that the missing lines
// release the controllers before the resets; otherwise they leak.
57 ReshapeLauncher::~ReshapeLauncher() {
60 _iController = nullptr;
61 _oController = nullptr;
64 void ReshapeLauncher::setShapeByName(const SizeVector& shape, const std::string& dataName) {
65 _iController->setShapeByName(shape, dataName);
68 void ReshapeLauncher::setBlobByName(const Blob::CPtr& blob, const std::string& dataName) {
69 _iController->setBlobByName(blob, dataName);
72 SizeVector ReshapeLauncher::getShapeByName(const std::string& dataName) {
73 return _oController->getShapeByName(dataName);
// Runs shape inference for this layer and propagates the resulting output shapes to `launchers`.
// The blob-based inferShapes overload is tried first; when it reports NOT_IMPLEMENTED, the
// shape-based overload for older custom implementations is called instead.
76 void ReshapeLauncher::reshape(const std::set<ReshapeLauncher::Ptr>& launchers) {
78 std::vector<SizeVector> outShapes;
80 // TODO: TensorIterator strictly requires the original layer instance because its body is not present
81 // in the params map. The original subnetwork body is required for internal shape infer.
82 TensorIteratorShapeProp *TI_shaper = dynamic_cast<TensorIteratorShapeProp*>(_reshapeImpl.get());
// NOTE(review): dynamic_cast yields nullptr for any non-TensorIterator implementation. The guard
// around the next line is not visible in this excerpt; confirm it is reached only when non-null.
84 TI_shaper->setOriginalLayer(_layer);
87 // try to call new API with input blobs
88 auto sts = _reshapeImpl->inferShapes(_iController->getBlobs(true), _layer->params, _layer->blobs, outShapes, &resp);
89 // in case of old custom shape infer function call old API
90 if (sts == NOT_IMPLEMENTED) {
91 sts = _reshapeImpl->inferShapes(_iController->getShapes(true), _layer->params, _layer->blobs, outShapes,
// Inferred shapes are handed to the output controller here; the status check that raises the
// error message below follows (its condition is not visible in this excerpt).
94 _oController->setShapes(outShapes);
97 "Failed to infer shapes for " + _layer->type + " layer (" + _layer->name + ") with error: " +
99 _oController->propagateShapes(launchers);
// Commits the pending changes held by both controllers. For TensorIterator layers it then lets the
// TensorIterator shape-infer implementation apply its internal results.
// NOTE(review): `layer` is not used in the visible lines. One body line is missing from this
// excerpt; presumably it validates `layer` (e.g. via checkLayer). Confirm.
102 void ReshapeLauncher::applyChanges(CNNLayer* layer) {
104 _iController->applyChanges();
105 _oController->applyChanges();
107 // TODO: Need to finalize result of internal body shape infer and apply
108 // new shapes to body subnetwork
109 TensorIteratorShapeProp *TI_shaper = dynamic_cast<TensorIteratorShapeProp*>(_reshapeImpl.get());
110 if (TI_shaper) TI_shaper->apply();
// Constant propagation for this layer. It runs only when the input controller reports available
// data, or the layer type is "Const" or "Shape". It creates output blobs, registers them with the
// output controller, fills them with the reference implementation `_inferImpl`, and propagates
// them to `launchers`.
113 void ReshapeLauncher::constInfer(const std::set<ReshapeLauncher::Ptr>& launchers) {
114 if (_iController->isDataAvailable() || _layer->type == "Const" || _layer->type == "Shape") {
115 auto outBlobs = _oController->createBlobs();
116 _oController->setBlobs(outBlobs);
// NOTE(review): the condition guarding this throw is not visible in this excerpt; presumably it
// checks that `_inferImpl` was found in the constructor. Confirm.
118 THROW_IE_EXCEPTION << "Failed to find reference implementation for `"
119 + _layer->name + "` Layer with `" + _layer->type + "` Type on constant propagation";
120 _inferImpl->infer(_iController->getBlobs(false), _layer->params, _layer->blobs, outBlobs);
121 _oController->propagateBlobs(launchers);
125 void ReshapeLauncher::reset() {
126 _iController->reset();
127 _oController->reset();
// Accessor for the launcher's layer name.
// NOTE(review): body not visible in this excerpt; presumably returns `_layer->name`. Confirm.
130 std::string ReshapeLauncher::getLayerName() const {
// Accessor for the launcher's layer type.
// NOTE(review): body not visible in this excerpt; presumably returns `_layer->type`. Confirm.
134 std::string ReshapeLauncher::getLayerType() const {
138 void ReshapeLauncher::checkLayer(CNNLayer* layer) {
139 if ((nullptr == _layer || layer == nullptr)) {
140 THROW_IE_EXCEPTION << "Can't apply changes for empty layer";
142 auto oldParams = _layer->params;
143 auto newParams = layer->params;
144 if ((!oldParams.empty() && !newParams.empty() && !std::equal(oldParams.begin(), oldParams.end(), newParams.begin()))
145 || (_layer->name != layer->name) || (_layer->type != layer->type) || oldParams.size() != newParams.size()) {
146 THROW_IE_EXCEPTION << "Can't apply changes for layer with another params";
150 void ReshapeLauncher::setIRShapeByName(const std::string& dataName) {
151 SizeVector foundShape = _iController->getIRShapeByName(dataName);
152 _iController->setShapeByName(foundShape, dataName);
// Replaces the shape-infer implementation used by this launcher.
// NOTE(review): body not visible in this excerpt; presumably assigns `impl` to `_reshapeImpl`. Confirm.
155 void ReshapeLauncher::setShapeInferImpl(const IShapeInferImpl::Ptr& impl) {
// Accessor for the (non-owning) layer pointer this launcher was constructed with.
// NOTE(review): body not visible in this excerpt; presumably returns `_layer`. Confirm.
159 const CNNLayer* ReshapeLauncher::getLayer() const {
163 InputController* FakeInitializer::createInputController(const CNNLayer* layer) {
164 std::vector<DataPtr> outData;
165 for (auto const& insData : layer->insData) {
166 outData.push_back(insData.lock());
168 return new InputController(outData, layer->name);
171 void FakeInitializer::check(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl) {
172 std::string errorBase = "Failed to init reshape launcher: ";
173 if (!layer) THROW_IE_EXCEPTION << errorBase + " pointer to the layer is null";
176 OutputController* FakeInitializer::createOutputController(const CNNLayer* layer) {
177 return new OutputController(layer->outData, layer->name);
180 FakeReshapeLauncher::FakeReshapeLauncher(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl)
181 : ReshapeLauncher(layer, impl, std::make_shared<FakeInitializer>()) {
// Reshape for layers that have no registered shape-infer function: it can only reproduce the IR.
// Each current input shape is compared with its IR counterpart and the launcher throws on a
// mismatch (the guard line is not visible in this excerpt). Otherwise the IR output shapes are set
// and propagated to `launchers`.
184 void FakeReshapeLauncher::reshape(const std::set<ReshapeLauncher::Ptr>& launchers) {
185 auto iShapesIR = _iController->getIRShapes();
186 auto oShapesIR = _oController->getIRShapes();
187 auto iShapes = _iController->getShapes(true);
// NOTE(review): `int i` against size_t size() is a signed/unsigned comparison, and iShapesIR is
// assumed to have at least iShapes.size() entries.
189 for (int i = 0; i < iShapes.size(); i++) {
190 auto newInShape = iShapes[i];
191 auto irInShape = iShapesIR[i];
// NOTE(review): this 3-iterator std::equal compares only newInShape.size() elements. A shorter new
// shape that is a prefix of the IR shape compares "equal", and a longer one reads past the end of
// irInShape (UB). `newInShape == irInShape` would also compare sizes.
192 bool equal = std::equal(newInShape.begin(), newInShape.end(), irInShape.begin());
195 << "Failed to infer shapes for layer with type: " << _layer->type
196 << ". Use @IShapeInferExtension class to register shape infer function for this layer";
200 _oController->setShapes(oShapesIR);
201 _oController->propagateShapes(launchers);
204 void OutputOnlyInitializer::check(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl) {
205 std::string errorBase = "Failed to init reshape launcher: ";
206 if (!layer) THROW_IE_EXCEPTION << errorBase + " pointer to the layer is null";
207 if (!layer->insData.empty())
208 THROW_IE_EXCEPTION << "Failed to init reshape launcher: "
209 << "layer type (`" + layer->type + "`) is supposed to not have inputs, but actually it has";
// Input-controller factory for output-only layers (check() guarantees they have no inputs).
// NOTE(review): body not visible in this excerpt. Presumably no input controller is created,
// which is why OutputOnlyReshapeLauncher routes its setters to `_oController`. Confirm.
212 InputController* OutputOnlyInitializer::createInputController(const CNNLayer* layer) {
216 OutputController* OutputOnlyInitializer::createOutputController(const CNNLayer* layer) {
217 return new OutputController(layer->outData, layer->name);
220 OutputOnlyReshapeLauncher::OutputOnlyReshapeLauncher(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl,
221 const OutputOnlyInitializer::Ptr& initializer)
222 : ReshapeLauncher(layer, impl, initializer) {}
224 void OutputOnlyReshapeLauncher::setShapeByName(const SizeVector& shape, const std::string& dataName) {
225 _oController->setShapeByName(shape, dataName);
228 void OutputOnlyReshapeLauncher::setBlobByName(const Blob::CPtr& blob, const std::string& dataName) {
229 _oController->setBlobByName(blob, dataName);
232 void OutputOnlyReshapeLauncher::setIRShapeByName(const std::string& dataName) {
233 SizeVector foundShape = _oController->getIRShapeByName(dataName);
234 _oController->setShapeByName(foundShape, dataName);
// Commits the pending changes of the output controller only (output-only layers have no inputs).
// NOTE(review): `layer` is not used in the visible lines. One body line is missing from this
// excerpt; presumably it validates `layer` (e.g. via checkLayer). Confirm.
237 void OutputOnlyReshapeLauncher::applyChanges(CNNLayer* layer) {
239 _oController->applyChanges();
242 void OutputOnlyReshapeLauncher::reset() {
243 _oController->reset();
// Constant propagation for output-only layers. Only layers of type "Const" are handled in the
// visible code. Output blobs are created and registered, filled by `_inferImpl` with no input blobs,
// reshaped to the controller's current output shapes, registered again, and propagated to `launchers`.
246 void OutputOnlyReshapeLauncher::constInfer(const std::set<ReshapeLauncher::Ptr>& launchers) {
247 if (_layer->type == "Const") {
248 auto outBlobs = _oController->createBlobs();
249 _oController->setBlobs(outBlobs);
// NOTE(review): the condition guarding this throw is not visible in this excerpt; presumably it
// checks that `_inferImpl` is set. Confirm.
251 THROW_IE_EXCEPTION << "Failed to find reference implementation for `"
252 + _layer->name + "` Layer with `" + _layer->type + "` Type on constant propagation";
253 _inferImpl->infer({}, _layer->params, _layer->blobs, outBlobs);
254 auto shapes = _oController->getShapes(true);
// NOTE(review): assumes `shapes` has at least outBlobs.size() entries; `int i` vs size_t is a
// signed/unsigned comparison. Dims are passed to Blob::Reshape in reversed order, presumably
// because that legacy API expects reversed dims. Confirm.
255 for (int i = 0; i < outBlobs.size(); i++) {
256 outBlobs[i]->Reshape(SizeVector(shapes[i].rbegin(), shapes[i].rend()), TensorDesc::getLayoutByDims(shapes[i]));
258 _oController->setBlobs(outBlobs);
259 _oController->propagateBlobs(launchers);
263 void InputInitializer::check(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl) {
264 OutputOnlyInitializer::check(layer, impl);
265 std::string errorBase = "Failed to init reshape launcher: layer type (`" + layer->type + "`) is not";
266 if (details::equal(layer->type, "memory")) {
267 if (!layer->GetParamAsInt("index"))
268 THROW_IE_EXCEPTION << errorBase << " `Memory`(as input)";
269 } else if (!::details::equal(layer->type, "input")) {
270 THROW_IE_EXCEPTION << errorBase << " `Input`";
274 InputReshapeLauncher::InputReshapeLauncher(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl,
275 const DefaultInitializer::Ptr& initializer)
276 : OutputOnlyReshapeLauncher(layer, impl, initializer) {}
278 void InputReshapeLauncher::reshape(const std::set<ReshapeLauncher::Ptr>& launchers) {
279 auto oShapes = _oController->getShapes(false);
280 auto oIRShapes = _oController->getIRShapes();
281 for (size_t i = 0; i < oShapes.size(); i++) {
282 if (oShapes[i].empty()) {
283 _oController->setShapeByIndex(oIRShapes[i], i);
286 _oController->propagateShapes(launchers);
289 void ConstInitializer::check(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl) {
290 OutputOnlyInitializer::check(layer, impl);
291 if (!::details::equal(layer->type, "const"))
292 THROW_IE_EXCEPTION << "Failed to init reshape launcher: layer type (`" + layer->type + "`) is not `Const`";
295 ConstReshapeLauncher::ConstReshapeLauncher(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl)
296 : OutputOnlyReshapeLauncher(layer, impl, std::make_shared<ConstInitializer>()) {}
// Const layers cannot be reshaped. When no current output shapes are reported, the IR shapes are
// installed. Current shapes that differ from the IR shapes are rejected with an exception. The
// (unchanged) shapes are then propagated to `launchers`.
// NOTE(review): `oShapes` is captured before setShapes(), so any later comparison still sees the
// pre-update value. The brace lines that fix the exact control flow are not visible here. Confirm.
298 void ConstReshapeLauncher::reshape(const std::set<ReshapeLauncher::Ptr>& launchers) {
299 auto oShapesIR = _oController->getIRShapes();
300 auto oShapes = _oController->getShapes(false);
302 if (oShapes.empty()) {
303 _oController->setShapes(oShapesIR);
305 if (oShapes != oShapesIR) {
306 THROW_IE_EXCEPTION << "Failed to set different shapes for Const layer,"
307 << " original shapes:" << details::dumpVec(oShapesIR)
308 << " new shapes:" << details::dumpVec(oShapes);
310 _oController->propagateShapes(launchers);
// Validates a layer for the output-side Memory launcher: the layer must be non-null, pass the
// Memory/index check below, and have no outputs.
313 void OutMemoryInitializer::check(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl) {
314 std::string errorBase = "Failed to init reshape launcher: ";
315 if (!layer) THROW_IE_EXCEPTION << errorBase + " pointer to the layer is null";
// NOTE(review): "index" is read before the type is checked. For a layer without that parameter,
// GetParamAsInt presumably fails before the type message below can be reported. Confirm.
316 int index = layer->GetParamAsInt("index");
// NOTE(review): this throws only when the layer is NOT Memory *and* index != 0. A Memory layer with
// a non-zero index passes, yet InputInitializer::check treats exactly that case as a Memory
// *input*. If the intent is "must be an output Memory", `||` with `index != 0` looks intended.
// The throw statement's first line is not visible in this excerpt.
317 if (!::details::equal(layer->type, "memory") && index)
319 << "Failed to init reshape launcher: layer type (`" + layer->type + "`) is not `Memory` as output";
320 if (!layer->outData.empty())
321 THROW_IE_EXCEPTION << "Failed to init reshape launcher: "
322 << "layer type (`" + layer->type +
323 "`) is supposed to not have outputs, but actually it has";
// Output-controller factory for output-side Memory layers (check() requires them to have no outputs).
// NOTE(review): body not visible in this excerpt. Presumably no output controller is created, which
// is why OutMemoryReshapeLauncher only touches `_iController`. Confirm.
326 OutputController* OutMemoryInitializer::createOutputController(const CNNLayer* layer) {
330 OutMemoryReshapeLauncher::OutMemoryReshapeLauncher(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl)
331 : ReshapeLauncher(layer, impl, std::make_shared<OutMemoryInitializer>()) {
// Commits the pending changes of the input controller only (output-side Memory layers have no outputs).
// NOTE(review): `layer` is not used in the visible lines. One body line is missing from this
// excerpt; presumably it validates `layer` (e.g. via checkLayer). Confirm.
334 void OutMemoryReshapeLauncher::applyChanges(CNNLayer* layer) {
336 _iController->applyChanges();
// Resets the input controller; output-side Memory layers have no output controller to reset.
// (The remainder of this definition lies past the end of this excerpt.)
339 void OutMemoryReshapeLauncher::reset() {
340 _iController->reset();