1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
14 #include <ie_layers.h>
15 #include <ie_context.hpp>
16 #include <builders/ie_network_builder.hpp>
17 #include "details/caseless.hpp"
18 #include "ie_reshape_launcher.hpp"
19 #include "ie_icnn_network.hpp"
21 namespace InferenceEngine {
22 namespace ShapeInfer {
24 class INFERENCE_ENGINE_API_CLASS(LauncherCreator) {
/** Shared-ownership handle (std::shared_ptr) to a LauncherCreator. */
26 using Ptr = std::shared_ptr<LauncherCreator>;
29 * @brief Creates reshape launcher for the given intermediate layer with first registered implementation.
30 * Built-in implementations are tried first, then custom ones.
31 * Throws exception if it fails to find implementation for the given layer.
32 * @param layer - const pointer to the CNNLayer for which shape infer is needed
33 * @param extensions - all registered extensions
34 * @return - shared_ptr to the corresponding launcher.
/**
 * Declared virtual, so a class derived from LauncherCreator can substitute its own launcher construction.
 * `layer` is a raw pointer-to-const and `extensions` is taken by const reference.
 * NOTE(review): handling of a null `layer` is not visible in this header - confirm in the implementation.
 */
36 virtual ReshapeLauncher::Ptr
37 createNotInputLauncher(const CNNLayer* layer, const std::vector<IShapeInferExtensionPtr>& extensions);
40 * @brief Creates reshape launcher for the given input layer. Supported types: Input, Const, Memory (as input)
41 * @param layer - const pointer to the CNNLayer for which shape infer is needed
42 * @param extensions - all registered extensions
43 * @return - shared_ptr to the corresponding launcher.
/**
 * Declared virtual with the same parameter list and return type as createNotInputLauncher().
 * NOTE(review): what happens for a layer type outside the supported list is not visible in this header - confirm in the implementation.
 */
45 virtual ReshapeLauncher::Ptr
46 createInputLauncher(const CNNLayer* layer, const std::vector<IShapeInferExtensionPtr>& extensions);
/** Virtual defaulted destructor: the class has virtual methods and is held through LauncherCreator::Ptr, so destruction via a base handle must be well-defined. */
48 virtual ~LauncherCreator() = default;
53 * @brief Helper class to infer shapes for the given ICNNNetwork.
54 * It delegates shape inference to the corresponding ReshapeLauncher.
56 class INFERENCE_ENGINE_API_CLASS(Reshaper) {
60 * @param network - reference to the ICNNNetwork for performing shape inference
/**
 * Explicit constructor taking the network by non-const reference.
 * `creator` defaults to a LauncherCreator newly created with std::make_shared;
 * pass a different LauncherCreator::Ptr to customise how launchers are created.
 */
62 explicit Reshaper(ICNNNetwork& network,
63 const LauncherCreator::Ptr& creator = std::make_shared<LauncherCreator>());
/**
 * @brief Explicit constructor taking a vector of DataPtr instead of a network.
 * @param inputs - vector of DataPtr, taken by value.
 *                 NOTE(review): presumably the input data of the graph to reshape - confirm in the implementation.
 * @param launcherCreator - creator used to build launchers; defaults to a LauncherCreator newly created with std::make_shared
 */
65 explicit Reshaper(std::vector<DataPtr> inputs,
66 const LauncherCreator::Ptr& launcherCreator = std::make_shared<LauncherCreator>());
/**
 * @brief Constructor taking a raw pointer to a Builder::Network.
 * @param network - raw pointer to the builder network.
 *                  NOTE(review): ownership and required lifetime are not expressed by the type - confirm against the implementation.
 * NOTE(review): unlike the other two constructors this one is not marked `explicit`, so a Builder::Network*
 * converts implicitly to a Reshaper. Confirm whether that is intended before adding `explicit`,
 * since existing callers may rely on the conversion.
 */
68 Reshaper(Builder::Network* network);
/** Virtual defaulted destructor, so a Reshaper can be destroyed safely through a pointer to a derived class's base. */
70 virtual ~Reshaper() = default;
73 * @brief Adds shape infer extension to provide implementations of shape infer functions
74 * @param extension - pointer to the shape infer extension
/**
 * Takes the extension handle by const reference and returns nothing.
 * NOTE(review): how a null or already-registered extension is handled is not visible in this header - confirm in the implementation.
 */
76 void AddExtension(const IShapeInferExtensionPtr& extension);
79 * @brief Launches shape inference for the given ICNNNetwork and input shapes.
80 * Throws if shape inference fails; the original shapes are not corrupted in that case.
81 * @param inputShapes - Map of input names (data) to their input shapes.
/**
 * @param resp - optional response descriptor; defaults to nullptr
 * @return StatusCode of the operation
 * NOTE(review): the description mentions throwing, while the signature also returns a StatusCode and accepts
 * `resp` - confirm which failures are reported through the return value and which through exceptions.
 */
83 StatusCode run(const std::map<std::string, SizeVector>& inputShapes, ResponseDesc* resp = nullptr);
86 * @brief Perform shape inference for the given input shapes but not apply it.
87 * In case of success, call the apply() method.
88 * @param inputShapes - Map of input names (data) to their input shapes.
89 * @throws exception if shape inference fails; the original shapes are not corrupted in that case
/**
 * Same parameter list and return type as run().
 * @param resp - optional response descriptor; defaults to nullptr
 * @return StatusCode of the operation
 */
91 StatusCode runNoApply(const std::map<std::string, SizeVector>& inputShapes, ResponseDesc* resp = nullptr);
94 * @brief Apply shapes pre calculated by runNoApply() method.
/**
 * @param resp - optional response descriptor; defaults to nullptr
 * @return StatusCode of the operation
 * NOTE(review): behaviour when called without a preceding runNoApply() is not visible in this header - confirm in the implementation.
 */
96 StatusCode apply(ResponseDesc* resp = nullptr);
99 * @brief Return newly calculated shape for provided data.
/**
 * @param data - data handle, taken by non-const lvalue reference (a temporary cannot be passed)
 * @param resp - optional response descriptor; defaults to nullptr
 * @return the shape as a SizeVector, by value
 * NOTE(review): the return type is not a StatusCode, so how a failed lookup is signalled
 * (via `resp`, an empty vector, or an exception) must be confirmed in the implementation.
 */
101 SizeVector getResultShapeFor(DataPtr &data, ResponseDesc* resp = nullptr);
/**
 * Const member that returns a launcher handle for the given layer name.
 * NOTE(review): the result when no launcher matches `layerName` is not visible in this header - confirm in the implementation.
 */
104 ReshapeLauncher::Ptr getLauncherByLayerName(const std::string& layerName) const;
/**
 * Takes the same input-shape map as run(); here `resp` has no default value and must be passed explicitly.
 * NOTE(review): presumably the code path used when the Reshaper was built from a Builder::Network - confirm in the implementation.
 */
106 StatusCode networkShapeInfer(const std::map<std::string, SizeVector>& inputShapes, ResponseDesc* resp);
/**
 * Static helper (needs no Reshaper instance) returning a caseless set of strings, by value, for the given extension.
 * NOTE(review): judging by the name these are the layer type names the extension provides - confirm in the implementation.
 */
108 static InferenceEngine::details::caseless_set<std::string> getTypeNamesFromExtension(const IShapeInferExtensionPtr& extension);
/** Shape-infer extension handles. NOTE(review): presumably appended by AddExtension() - confirm in the implementation. */
110 std::vector<IShapeInferExtensionPtr> _extensions;
/** Set of launcher handles; getLauncherByLayerName() returns the same handle type. */
111 std::set<ReshapeLauncher::Ptr> _launchers;
/** Layer handles. NOTE(review): the name suggests a sorted order over all layers - confirm which order in the implementation. */
112 std::vector<CNNLayerPtr> _allSortedLayers{};
/** Layer handles for the input layers (per the name). */
113 std::set<CNNLayerPtr> _inputLayers{};
/** Case-insensitive set of strings. NOTE(review): presumably the known layer type names - confirm in the implementation. */
114 InferenceEngine::details::caseless_set<std::string> _allTypes;
/**
 * Raw pointer to a Builder::Network, as accepted by the Reshaper(Builder::Network*) constructor.
 * Defaults to nullptr so the pointer never holds an indeterminate value when a constructor does not set it;
 * a constructor that initializes it explicitly still overrides this default.
 */
116 Builder::Network* network{nullptr};
119 } // namespace ShapeInfer
120 } // namespace InferenceEngine