1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
15 #include <ie_layers.h>
16 #include "shape_infer/built-in/ie_built_in_holder.hpp"
18 #include "ie_reshape_launcher.hpp"
20 namespace InferenceEngine {
21 namespace ShapeInfer {
28 class DefaultChecker {
30 using Ptr = std::shared_ptr<DefaultChecker>;
32 virtual void run(const std::vector<DataPtr>& inData, const std::string& layerName);
34 virtual ~DefaultChecker() = default;
37 class EmptyChecker : public DefaultChecker {
39 void run(const std::vector<DataPtr>& inData, const std::string& layerName) override {};
42 class InputController {
44 InputController(const std::vector<DataPtr>& dataVec,
45 const std::string& layerName,
46 bool irShapesOnInit = false,
47 const DefaultChecker::Ptr& checker = std::make_shared<DefaultChecker>());
49 virtual ~InputController() = default;
52 * @brief Set shape for current reshape launcher by corresponding Data name.
53 * @param shape - shape to be set
54 * @param dataName - Data's name
56 virtual void setShapeByName(const SizeVector& shape, const std::string& dataName);
59 * @brief Set shape for current reshape launcher by corresponding index.
60 * @param shape - shape to be set
61 * @param index - shape's index
63 virtual void setShapeByIndex(const SizeVector& shape, size_t index);
66 * @brief Returns shapes that are supposed to be set by reshape algorithm.
67 * @note Shapes are in topological order.
68 * @param check - indicator whether check for correspondence of input data and shapes is required
71 virtual std::vector<SizeVector> getShapes(bool check);
74 * @brief Returns shapes from IR. If Controller was initialized irShapesOnInit=false, it accesses Data object of Layer
75 * If not, all shapes from IR are collected on Controller's construction.
76 * @note Shapes are in topological order.
77 * @return shapes from IR
79 virtual std::vector<SizeVector> getIRShapes();
82 * @brief Returns shape from IR by corresponding Data's name
83 * @param dataName - name of Data object that holds requested shape
84 * @return shape from IR
86 virtual SizeVector getIRShapeByName(const std::string& dataName);
89 * @brief Applies calculated shapes to the Data of the Layer
91 virtual void applyChanges();
94 * @brief Reset vector of input shapes.
98 virtual void checkCorrespondence();
102 * @brief Returns shapes from IR by accessing Data object of Layer
103 * @note Shapes are in topological order.
104 * @return shapes from IR
106 std::vector<SizeVector> getIRShapesInternal();
108 long getPositionByName(const std::string& dataName);
111 std::vector<DataPtr> _dataVec;
112 std::vector<SizeVector> _shapes;
113 std::vector<SizeVector> _irShapes;
114 std::vector<std::string> _dataNames;
115 std::string _layerName;
116 bool _irShapesOnInit = false;
120 * @brief Keeps calculated output shapes, distribute (propagate) them to the connected layers, applies output shapes to the Data object
122 class OutputController : public InputController {
124 OutputController(const std::vector<DataPtr>& inData,
125 const std::string& layerName,
126 bool irShapesOnInit = false,
127 const DefaultChecker::Ptr& checker = std::make_shared<DefaultChecker>());
130 * @brief Set calculated output shapes as inputs for next layers
131 * @param launchers - Map of layer names to reshape launchers for that layer
133 virtual void propagateShapes(const std::set<ReshapeLauncher::Ptr>& launchers);
135 virtual void setShapes(const std::vector<SizeVector>& shapes);
138 } // namespace ShapeInfer
139 } // namespace InferenceEngine