Publishing R3
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / shape_infer / ie_reshape_io_controllers.hpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #pragma once
7
8 #include <string>
9 #include <vector>
10 #include <list>
11 #include <map>
12 #include <set>
13 #include <memory>
14
15 #include <ie_layers.h>
16 #include "shape_infer/built-in/ie_built_in_holder.hpp"
17 #include "../debug.h"
18 #include "ie_reshape_launcher.hpp"
19
20 namespace InferenceEngine {
21 namespace ShapeInfer {
22
23 struct ShapeDesc {
24     std::string dataName;
25     SizeVector dims;
26 };
27
28 class DefaultChecker {
29 public:
30     using Ptr = std::shared_ptr<DefaultChecker>;
31
32     virtual void run(const std::vector<DataPtr>& inData, const std::string& layerName);
33
34     virtual ~DefaultChecker() = default;
35 };
36
37 class EmptyChecker : public DefaultChecker {
38 public:
39     void run(const std::vector<DataPtr>& inData, const std::string& layerName) override {};
40 };
41
42 class InputController {
43 public:
44     InputController(const std::vector<DataPtr>& dataVec,
45                     const std::string& layerName,
46                     bool irShapesOnInit = false,
47                     const DefaultChecker::Ptr& checker = std::make_shared<DefaultChecker>());
48
49     virtual ~InputController() = default;
50
51     /**
52      * @brief Set shape for current reshape launcher by corresponding Data name.
53      * @param shape - shape to be set
54      * @param dataName - Data's name
55      */
56     virtual void setShapeByName(const SizeVector& shape, const std::string& dataName);
57
58     /**
59      * @brief Set shape for current reshape launcher by corresponding index.
60      * @param shape - shape to be set
61      * @param index - shape's index
62      */
63     virtual void setShapeByIndex(const SizeVector& shape, size_t index);
64
65     /**
66      * @brief Returns shapes that are supposed to be set by reshape algorithm.
67      * @note Shapes are in topological order.
68      * @param check - indicator whether check for correspondence of input data and shapes is required
69      * @return shapes
70      */
71     virtual std::vector<SizeVector> getShapes(bool check);
72
73     /**
74      * @brief Returns shapes from IR. If Controller was initialized irShapesOnInit=false, it accesses Data object of Layer
75      * If not, all shapes from IR are collected on Controller's construction.
76      * @note Shapes are in topological order.
77      * @return shapes from IR
78      */
79     virtual std::vector<SizeVector> getIRShapes();
80
81     /**
82      * @brief Returns shape from IR by corresponding Data's name
83      * @param dataName - name of Data object that holds requested shape
84      * @return shape from IR
85      */
86     virtual SizeVector getIRShapeByName(const std::string& dataName);
87
88     /**
89      * @brief Applies calculated shapes to the Data of the Layer
90      */
91     virtual void applyChanges();
92
93     /**
94      * @brief Reset vector of input shapes.
95      */
96     virtual void reset();
97
98     virtual void checkCorrespondence();
99
100 private:
101     /**
102      * @brief Returns shapes from IR by accessing Data object of Layer
103      * @note Shapes are in topological order.
104      * @return shapes from IR
105      */
106     std::vector<SizeVector> getIRShapesInternal();
107
108     long getPositionByName(const std::string& dataName);
109
110 protected:
111     std::vector<DataPtr> _dataVec;
112     std::vector<SizeVector> _shapes;
113     std::vector<SizeVector> _irShapes;
114     std::vector<std::string> _dataNames;
115     std::string _layerName;
116     bool _irShapesOnInit = false;
117 };
118
119 /**
120  * @brief Keeps calculated output shapes, distribute (propagate) them to the connected layers, applies output shapes to the Data object
121  */
122 class OutputController : public InputController {
123 public:
124     OutputController(const std::vector<DataPtr>& inData,
125                      const std::string& layerName,
126                      bool irShapesOnInit = false,
127                      const DefaultChecker::Ptr& checker = std::make_shared<DefaultChecker>());
128
129     /**
130      * @brief Set calculated output shapes as inputs for next layers
131      * @param launchers - Map of layer names to reshape launchers for that layer
132      */
133     virtual void propagateShapes(const std::set<ReshapeLauncher::Ptr>& launchers);
134
135     virtual void setShapes(const std::vector<SizeVector>& shapes);
136 };
137
138 }  // namespace ShapeInfer
139 }  // namespace InferenceEngine