Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / shape_infer / ie_reshaper.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <string>
8 #include <vector>
9 #include <list>
10 #include <map>
11 #include <set>
12 #include <memory>
13
14 #include <ie_layers.h>
15 #include <ie_context.hpp>
16 #include <builders/ie_network_builder.hpp>
17 #include "details/caseless.hpp"
18 #include "ie_reshape_launcher.hpp"
19 #include "ie_icnn_network.hpp"
20
21 namespace InferenceEngine {
22 namespace ShapeInfer {
23
24 class INFERENCE_ENGINE_API_CLASS(LauncherCreator) {
25 public:
26     using Ptr = std::shared_ptr<LauncherCreator>;
27
28     /**
29      * @brief Creates reshape launcher for the given intermediate layer with first registered implementation.
30      * Built-in implementations are first, then - custom ones.
31      * Throws exception if it fails to find implementation for the given layer.
32      * @param layer - const pointer to the CNNLayer for which shape infer is needed
33      * @param extensions - all registered extensions
34      * @return - shared_ptr to the corresponding launcher.
35      */
36     virtual ReshapeLauncher::Ptr
37     createNotInputLauncher(const CNNLayer* layer, const std::vector<IShapeInferExtensionPtr>& extensions);
38
39     /**
40      * @brief Creates reshape launcher for the given input layer. Supported types: Input, Const, Memory (as input)
41      * @param layer - const pointer to the CNNLayer for which shape infer is needed
42      * @param extensions - all registered extensions
43      * @return - shared_ptr to the corresponding launcher.
44      */
45     virtual ReshapeLauncher::Ptr
46     createInputLauncher(const CNNLayer* layer, const std::vector<IShapeInferExtensionPtr>& extensions);
47
48     virtual ~LauncherCreator() = default;
49 };
50
51 /**
52  * @class Reshaper
53  * @brief Helper class to infer shapes for the given ICNNNetwork.
54  * It delegates shape inference to the corresponding ReshapeLauncher.
55  */
56 class INFERENCE_ENGINE_API_CLASS(Reshaper) {
57 public:
58     /**
59      * @brief Constructor
60      * @param network - const reference to the ICNNNetwork for performing shape inference
61      */
62     explicit Reshaper(ICNNNetwork& network,
63             const LauncherCreator::Ptr& creator = std::make_shared<LauncherCreator>());
64
65     explicit Reshaper(std::vector<DataPtr> inputs,
66             const LauncherCreator::Ptr& launcherCreator = std::make_shared<LauncherCreator>());
67
68     Reshaper(Builder::Network* network);
69
70     virtual ~Reshaper() = default;
71
72     /**
73      * @brief Adds shape infer extension to provide implementations of shape infer functions
74      * @param extension - pointer to the shape infer extension
75      */
76     void AddExtension(const IShapeInferExtensionPtr& extension);
77
78     /**
79      * @brief Launches shape inference for the given ICNNNetworkAdds and input shapes.
80      * Throws if shape infer failed without corruption of original shapes
81      * @param inputShapes - Map of input names (data) to their input shapes.
82      */
83     StatusCode run(const std::map<std::string, SizeVector>& inputShapes, ResponseDesc* resp = nullptr);
84
85     /**
86      * @brief Perform shape inference for the given input shapes but not apply it.
87      * In case of cusses call apply() method.
88      * @param inputShapes - Map of input names (data) to their input shapes.
89      * @throws exception if shape infer failed without corruption of original shapes
90      */
91     StatusCode runNoApply(const std::map<std::string, SizeVector>& inputShapes, ResponseDesc* resp = nullptr);
92
93     /**
94      * @brief Apply shapes pre calculated by runNoApply() method.
95      */
96     StatusCode apply(ResponseDesc* resp = nullptr);
97
98     /**
99      * @brief Return newly calculated shape for provided data.
100      */
101     SizeVector getResultShapeFor(DataPtr &data, ResponseDesc* resp = nullptr);
102
103 private:
104     ReshapeLauncher::Ptr getLauncherByLayerName(const std::string& layerName) const;
105
106     StatusCode networkShapeInfer(const std::map<std::string, SizeVector>& inputShapes, ResponseDesc* resp);
107
108     static InferenceEngine::details::caseless_set<std::string> getTypeNamesFromExtension(const IShapeInferExtensionPtr& extension);
109
110     std::vector<IShapeInferExtensionPtr> _extensions;
111     std::set<ReshapeLauncher::Ptr> _launchers;
112     std::vector<CNNLayerPtr> _allSortedLayers{};
113     std::set<CNNLayerPtr> _inputLayers{};
114     InferenceEngine::details::caseless_set<std::string> _allTypes;
115
116     Builder::Network* network;
117 };
118
119 }  // namespace ShapeInfer
120 }  // namespace InferenceEngine