// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
14 #include <ie_layers.h>
15 #include "shape_infer/const_infer/ie_const_infer_impl.hpp"
16 #include "shape_infer/built-in/ie_built_in_holder.hpp"
18 namespace InferenceEngine {
19 namespace ShapeInfer {
21 class InputController;
23 class OutputController;
25 class DefaultInitializer {
27 using Ptr = std::shared_ptr<DefaultInitializer>;
29 virtual void check(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl);
31 virtual InputController* createInputController(const CNNLayer* layer);
33 virtual OutputController* createOutputController(const CNNLayer* layer);
35 virtual ~DefaultInitializer() = default;
39 * @class ReshapeLauncher
40 * @brief Helper class to infer shapes for the given CNNLayer by using specified implementation.
41 * Encapsulate input and output shapes, before applying it to the real CNNLayer and Data.
43 class ReshapeLauncher {
45 using Ptr = std::shared_ptr<ReshapeLauncher>;
49 * @param layer - const pointer to the layer for performing shape inference.
50 * It is used to obtain parameters, input/output shapes.
51 * @param impl - implementation of shape inference for the given layer
53 ReshapeLauncher(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl,
54 const DefaultInitializer::Ptr& initializer = std::make_shared<DefaultInitializer>());
56 virtual ~ReshapeLauncher();
59 * @brief Set input shape for current reshape launcher.
60 * @param shape - input shape to be set
62 virtual void setShapeByName(const SizeVector& shape, const std::string& dataName);
64 virtual void setBlobByName(const Blob::CPtr& blob, const std::string& dataName);
67 * @brief Return calculated shape for data with requested name.
68 * @return Result shape
70 virtual SizeVector getShapeByName(const std::string& dataName);
73 * @brief Set input shape from IR by Data name. If there's no Data with given name it throws exception
74 * @param dataName - name of the corresponding Data.
76 virtual void setIRShapeByName(const std::string& dataName);
79 * @brief Calculates output shapes and changed layer params using input shapes that was set
80 * @param resp Pointer to the response message that holds a description of an error if any occurred
81 * @param launchers - Map of pairs: layer name and its reshape launcher.
82 * @return Status code of the operation. OK if succeeded
84 virtual void reshape(const std::set<ReshapeLauncher::Ptr>& launchers);
86 virtual void constInfer(const std::set<ReshapeLauncher::Ptr>& launchers);
89 * @brief Apply new input shapes, calculated output shapes and changed layer's params to CNNLayer and Data.
90 * @param layer - pointer to the layer for setting changes in layer's params
92 virtual void applyChanges(CNNLayer* layer);
95 * @brief Reset all stored to the initial state: input/output shapes and layer's params.
96 * @param layer - pointer to the layer for setting changes in layer's params
100 virtual std::string getLayerName() const;
102 virtual std::string getLayerType() const;
104 virtual const CNNLayer* getLayer() const;
106 virtual void setShapeInferImpl(const IShapeInferImpl::Ptr& impl);
109 InputController* _iController = nullptr;
110 OutputController* _oController = nullptr;
111 const CNNLayer* _layer;
112 IShapeInferImpl::Ptr _reshapeImpl;
113 IConstInferImpl::Ptr _inferImpl;
117 * @brief Check that all shape infer operations were done with specified layer.
118 * @param layer - pointer to the layer to compare with
120 void checkLayer(CNNLayer* layer);
123 class FakeInitializer : public DefaultInitializer {
125 void check(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl) override;
127 InputController* createInputController(const CNNLayer* layer) override;
129 OutputController* createOutputController(const CNNLayer* layer) override;
133 * @class FakeReshapeLauncher
134 * @brief Helper class to infer shapes for layers without registered shape infer functions.
135 * Encapsulates input and output shapes, before applying it to the real CNNLayer and Data.
136 * If input shape is the same as in IR, it takes output shape from IR as is.
137 * It sets batch size to the first output dimension of all outputs if:
138 * 1) first dimension of all input layers should be the same (assume this is batch size)
139 * 2) calculated input shape of the unsupported layer is different only in a first dimension from original input shape in IR.
141 class FakeReshapeLauncher : public ReshapeLauncher {
143 using Ptr = std::shared_ptr<FakeReshapeLauncher>;
145 FakeReshapeLauncher(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl);
147 void reshape(const std::set<ReshapeLauncher::Ptr>& launchers) override;
149 void constInfer(const std::set<ReshapeLauncher::Ptr>& launchers) override {}
152 class OutputOnlyInitializer : public DefaultInitializer {
154 void check(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl) override;
156 InputController* createInputController(const CNNLayer* layer) override;
158 OutputController* createOutputController(const CNNLayer* layer) override;
162 * @class OutputOnlyReshapeLauncher
163 * @brief Helper class to infer shapes for layers without inputs. It creates output controller only, input one is null.
165 class OutputOnlyReshapeLauncher : public ReshapeLauncher {
167 using Ptr = std::shared_ptr<OutputOnlyReshapeLauncher>;
169 OutputOnlyReshapeLauncher(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl,
170 const OutputOnlyInitializer::Ptr& initializer = std::make_shared<OutputOnlyInitializer>());
172 void setShapeByName(const SizeVector& shape, const std::string& dataName) override;
174 void setIRShapeByName(const std::string& dataName) override;
176 void applyChanges(CNNLayer* layer) override;
178 void reset() override;
180 void setBlobByName(const Blob::CPtr& blob, const std::string& dataName) override;
182 void constInfer(const std::set<ReshapeLauncher::Ptr>& launchers) override;
185 class InputInitializer : public OutputOnlyInitializer {
187 void check(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl) override;
191 * @class InputReshapeLauncher
192 * @brief Helper class to infer shapes for input layers. Supported layer types: `Input` or `Memory`(as inputs only, if index=1)
193 * It takes new given input shape and propagate for connected layers. If shape is not set, it takes shapes from IR.
195 class InputReshapeLauncher : public OutputOnlyReshapeLauncher {
197 using Ptr = std::shared_ptr<InputReshapeLauncher>;
199 InputReshapeLauncher(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl,
200 const DefaultInitializer::Ptr& initializer = std::make_shared<InputInitializer>());
202 void reshape(const std::set<ReshapeLauncher::Ptr>& launchers) override;
205 class ConstInitializer : public OutputOnlyInitializer {
207 void check(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl) override;
211 * @class ConstReshapeLauncher
212 * @brief Helper class to infer shapes for layers with Const type.
213 * It checks if new given shape is the same as in IR. The launcher fails if not and propagate for connected layers otherwise.
214 * If shape is not set, it propagates shapes from IR.
216 class ConstReshapeLauncher : public OutputOnlyReshapeLauncher {
218 using Ptr = std::shared_ptr<InputReshapeLauncher>;
220 ConstReshapeLauncher(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl);
222 void reshape(const std::set<ReshapeLauncher::Ptr>& launchers) override;
225 class OutMemoryInitializer : public DefaultInitializer {
226 void check(const CNNLayer* layer, const IShapeInferImpl::Ptr& impl) override;
228 OutputController* createOutputController(const CNNLayer* layer) override;
232 * @class OutMemoryReshapeLauncher
233 * @brief Helper class to infer shapes for layers with Memory type (as outputs only, if index=0).
234 * It sets new input shapes and doesn't call propagation as this layer doesn't have childs.
236 class OutMemoryReshapeLauncher : public ReshapeLauncher {
238 using Ptr = std::shared_ptr<InputReshapeLauncher>;
240 OutMemoryReshapeLauncher(const CNNLayer* layer1, const IShapeInferImpl::Ptr& impl1);
242 void reshape(const std::set<ReshapeLauncher::Ptr>& launchers) override {}
244 void applyChanges(CNNLayer* layer) override;
246 void reset() override;
248 void constInfer(const std::set<ReshapeLauncher::Ptr>& launchers) override {}
}  // namespace ShapeInfer
}  // namespace InferenceEngine