Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / shape_infer / ie_reshape_io_controllers.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <string>
8 #include <vector>
9 #include <list>
10 #include <map>
11 #include <set>
12 #include <memory>
13
14 #include <ie_layers.h>
15 #include "shape_infer/built-in/ie_built_in_holder.hpp"
16 #include "../debug.h"
17 #include "ie_reshape_launcher.hpp"
18
19 namespace InferenceEngine {
20 namespace ShapeInfer {
21
22 struct ShapeDesc {
23     std::string dataName;
24     SizeVector dims;
25 };
26
27 class DefaultChecker {
28 public:
29     using Ptr = std::shared_ptr<DefaultChecker>;
30
31     virtual void run(const std::vector<DataPtr>& inData, const std::string& layerName);
32
33     virtual ~DefaultChecker() = default;
34 };
35
36 class EmptyChecker : public DefaultChecker {
37 public:
38     void run(const std::vector<DataPtr>& inData, const std::string& layerName) override {};
39 };
40
41 class InputController {
42 public:
43     InputController(const std::vector<DataPtr>& dataVec,
44                     const std::string& layerName,
45                     const DefaultChecker::Ptr& checker = std::make_shared<DefaultChecker>());
46
47     virtual ~InputController() = default;
48
49     /**
50      * @brief Set shape for current reshape launcher by corresponding Data name.
51      * @param shape - shape to be set
52      * @param dataName - Data's name
53      */
54     virtual void setShapeByName(const SizeVector& shape, const std::string& dataName);
55
56     /**
57      * @brief Return calculated shape for name.
58      */
59     virtual SizeVector getShapeByName(const std::string& dataName);
60
61     /**
62      * @brief Set shape for current reshape launcher by corresponding index.
63      * @param shape - shape to be set
64      * @param index - shape's index
65      */
66     virtual void setShapeByIndex(const SizeVector& shape, size_t index);
67
68     /**
69      * @brief Returns shapes that are supposed to be set by reshape algorithm.
70      * @note Shapes are in topological order.
71      * @param check - indicator whether check for correspondence of input data and shapes is required
72      * @return shapes
73      */
74     virtual std::vector<SizeVector> getShapes(bool check);
75
76     /**
77      * @brief Returns shapes from IR. If Controller was initialized irShapesOnInit=false, it accesses Data object of Layer
78      * If not, all shapes from IR are collected on Controller's construction.
79      * @note Shapes are in topological order.
80      * @return shapes from IR
81      */
82     virtual std::vector<SizeVector> getIRShapes();
83
84     /**
85      * @brief Returns shape from IR by corresponding Data's name
86      * @param dataName - name of Data object that holds requested shape
87      * @return shape from IR
88      */
89     virtual SizeVector getIRShapeByName(const std::string& dataName);
90
91     /**
92      * @brief Applies calculated shapes to the Data of the Layer
93      */
94     virtual void applyChanges();
95
96     /**
97      * @brief Reset vector of input shapes.
98      */
99     virtual void reset();
100
101     virtual void checkCorrespondence();
102
103     virtual bool isDataAvailable();
104
105     virtual std::vector<Blob::CPtr> getBlobs(bool check);
106
107     virtual void setBlobByName(const Blob::CPtr& blob, const std::string& name);
108
109 private:
110     long getPositionByName(const std::string& dataName);
111
112 protected:
113     std::vector<DataPtr> _dataVec;
114     std::vector<SizeVector> _shapes;
115     std::vector<SizeVector> _irShapes;
116     std::vector<std::string> _dataNames;
117     std::string _layerName;
118     std::vector<Blob::CPtr> _inferedData;
119 };
120
121 /**
122  * @brief Keeps calculated output shapes, distribute (propagate) them to the connected layers, applies output shapes to the Data object
123  */
124 class OutputController : public InputController {
125 public:
126     OutputController(const std::vector<DataPtr>& inData,
127                      const std::string& layerName,
128                      const DefaultChecker::Ptr& checker = std::make_shared<DefaultChecker>());
129
130     /**
131      * @brief Set calculated output shapes as inputs for next layers
132      * @param launchers - Map of layer names to reshape launchers for that layer
133      */
134     virtual void propagateShapes(const std::set<ReshapeLauncher::Ptr>& launchers);
135
136     virtual void setShapes(const std::vector<SizeVector>& shapes);
137
138     virtual void setBlobs(const std::vector<Blob::Ptr>& blobs);
139
140     std::vector<Blob::Ptr> createBlobs();
141
142     void propagateBlobs(const std::set<ReshapeLauncher::Ptr>& set);
143 };
144
145 }  // namespace ShapeInfer
146 }  // namespace InferenceEngine