1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
14 #include <ie_layers.h>
15 #include "shape_infer/built-in/ie_built_in_holder.hpp"
17 #include "ie_reshape_launcher.hpp"
19 namespace InferenceEngine {
20 namespace ShapeInfer {
27 class DefaultChecker {
29 using Ptr = std::shared_ptr<DefaultChecker>;
31 virtual void run(const std::vector<DataPtr>& inData, const std::string& layerName);
33 virtual ~DefaultChecker() = default;
36 class EmptyChecker : public DefaultChecker {
38 void run(const std::vector<DataPtr>& inData, const std::string& layerName) override {};
41 class InputController {
43 InputController(const std::vector<DataPtr>& dataVec,
44 const std::string& layerName,
45 const DefaultChecker::Ptr& checker = std::make_shared<DefaultChecker>());
47 virtual ~InputController() = default;
50 * @brief Set shape for current reshape launcher by corresponding Data name.
51 * @param shape - shape to be set
52 * @param dataName - Data's name
54 virtual void setShapeByName(const SizeVector& shape, const std::string& dataName);
57 * @brief Return calculated shape for name.
59 virtual SizeVector getShapeByName(const std::string& dataName);
62 * @brief Set shape for current reshape launcher by corresponding index.
63 * @param shape - shape to be set
64 * @param index - shape's index
66 virtual void setShapeByIndex(const SizeVector& shape, size_t index);
69 * @brief Returns shapes that are supposed to be set by reshape algorithm.
70 * @note Shapes are in topological order.
71 * @param check - indicator whether check for correspondence of input data and shapes is required
74 virtual std::vector<SizeVector> getShapes(bool check);
77 * @brief Returns shapes from IR. If Controller was initialized irShapesOnInit=false, it accesses Data object of Layer
78 * If not, all shapes from IR are collected on Controller's construction.
79 * @note Shapes are in topological order.
80 * @return shapes from IR
82 virtual std::vector<SizeVector> getIRShapes();
85 * @brief Returns shape from IR by corresponding Data's name
86 * @param dataName - name of Data object that holds requested shape
87 * @return shape from IR
89 virtual SizeVector getIRShapeByName(const std::string& dataName);
92 * @brief Applies calculated shapes to the Data of the Layer
94 virtual void applyChanges();
97 * @brief Reset vector of input shapes.
101 virtual void checkCorrespondence();
103 virtual bool isDataAvailable();
105 virtual std::vector<Blob::CPtr> getBlobs(bool check);
107 virtual void setBlobByName(const Blob::CPtr& blob, const std::string& name);
110 long getPositionByName(const std::string& dataName);
113 std::vector<DataPtr> _dataVec;
114 std::vector<SizeVector> _shapes;
115 std::vector<SizeVector> _irShapes;
116 std::vector<std::string> _dataNames;
117 std::string _layerName;
118 std::vector<Blob::CPtr> _inferedData;
122 * @brief Keeps calculated output shapes, distribute (propagate) them to the connected layers, applies output shapes to the Data object
124 class OutputController : public InputController {
126 OutputController(const std::vector<DataPtr>& inData,
127 const std::string& layerName,
128 const DefaultChecker::Ptr& checker = std::make_shared<DefaultChecker>());
131 * @brief Set calculated output shapes as inputs for next layers
132 * @param launchers - Map of layer names to reshape launchers for that layer
134 virtual void propagateShapes(const std::set<ReshapeLauncher::Ptr>& launchers);
136 virtual void setShapes(const std::vector<SizeVector>& shapes);
138 virtual void setBlobs(const std::vector<Blob::Ptr>& blobs);
140 std::vector<Blob::Ptr> createBlobs();
142 void propagateBlobs(const std::set<ReshapeLauncher::Ptr>& set);
145 } // namespace ShapeInfer
146 } // namespace InferenceEngine