2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include <armnn/DescriptorsFwd.hpp>
8 #include <armnn/LstmParams.hpp>
9 #include <armnn/QuantizedLstmParams.hpp>
10 #include <armnn/TensorFwd.hpp>
11 #include <armnn/Types.hpp>
13 #include <armnn/INetwork.hpp>
26 /// Private implementation of INetwork.
/// Declared `final`. Each Add*Layer declaration below, other than the three *Impl
/// helpers near the end, is marked `override`, returns an IConnectableLayer* and
/// defaults its `name` argument to nullptr. The graph is held through the
/// std::unique_ptr<Graph> member m_Graph.
/// NOTE(review): this listing appears to omit some original lines (braces, access
/// specifiers, and the tail of each *Impl parameter list), so read these
/// declarations against the complete header before relying on them.
27 class Network final : public INetwork
/// Read-only access to the graph. Dereferences m_Graph directly, so m_Graph must
/// be non-null when this is called.
33 const Graph& GetGraph() const { return *m_Graph; }
/// Override returning a Status; the definition is not visible in this header.
35 Status PrintGraph() override;
// ---- Layer-adding overrides -------------------------------------------------
37 IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name=nullptr) override;
39 IConnectableLayer* AddBatchToSpaceNdLayer(const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
40 const char* name = nullptr) override;
42 IConnectableLayer* AddConcatLayer(const ConcatDescriptor& concatDescriptor,
43 const char* name = nullptr) override;
// Non-deprecated convolution overload: biases are passed as Optional<ConstTensor>.
45 IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
46 const ConstTensor& weights,
47 const Optional<ConstTensor>& biases,
48 const char* name = nullptr) override;
// Deprecated overload that takes weights only (no biases argument).
50 ARMNN_DEPRECATED_MSG("This AddConvolution2dLayer overload is deprecated")
51 IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
52 const ConstTensor& weights,
53 const char* name = nullptr) override;
// Deprecated overload that takes biases as a plain ConstTensor rather than an Optional.
55 ARMNN_DEPRECATED_MSG("This AddConvolution2dLayer overload is deprecated")
56 IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
57 const ConstTensor& weights,
58 const ConstTensor& biases,
59 const char* name = nullptr) override;
// Non-deprecated depthwise overload: biases are passed as Optional<ConstTensor>.
61 IConnectableLayer* AddDepthwiseConvolution2dLayer(
62 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
63 const ConstTensor& weights,
64 const Optional<ConstTensor>& biases,
65 const char* name = nullptr) override;
// Deprecated overload that takes weights only (no biases argument).
67 ARMNN_DEPRECATED_MSG("This AddDepthwiseConvolution2dLayer overload is deprecated")
68 IConnectableLayer* AddDepthwiseConvolution2dLayer(
69 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
70 const ConstTensor& weights,
71 const char* name = nullptr) override;
// Deprecated overload that takes biases as a plain ConstTensor rather than an Optional.
73 ARMNN_DEPRECATED_MSG("This AddDepthwiseConvolution2dLayer overload is deprecated")
74 IConnectableLayer* AddDepthwiseConvolution2dLayer(
75 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
76 const ConstTensor& weights,
77 const ConstTensor& biases,
78 const char* name = nullptr) override;
80 IConnectableLayer* AddDequantizeLayer(const char* name = nullptr) override;
82 IConnectableLayer* AddDetectionPostProcessLayer(
83 const DetectionPostProcessDescriptor& descriptor,
84 const ConstTensor& anchors,
85 const char* name = nullptr) override;
// Non-deprecated fully-connected overload: biases are passed as Optional<ConstTensor>.
87 IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
88 const ConstTensor& weights,
89 const Optional<ConstTensor>& biases,
90 const char* name = nullptr) override;
// Deprecated overload that takes weights only (no biases argument).
92 ARMNN_DEPRECATED_MSG("This AddFullyConnectedLayer overload is deprecated")
93 IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
94 const ConstTensor& weights,
95 const char* name = nullptr) override;
// Deprecated overload that takes biases as a plain ConstTensor rather than an Optional.
97 ARMNN_DEPRECATED_MSG("This AddFullyConnectedLayer overload is deprecated")
98 IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
99 const ConstTensor& weights,
100 const ConstTensor& biases,
101 const char* name = nullptr) override;
103 IConnectableLayer* AddGatherLayer(const char* name = nullptr) override;
105 IConnectableLayer* AddPermuteLayer(const PermuteDescriptor& permuteDescriptor,
106 const char* name = nullptr) override;
108 IConnectableLayer* AddPooling2dLayer(const Pooling2dDescriptor& pooling2dDescriptor,
109 const char* name = nullptr) override;
111 IConnectableLayer* AddActivationLayer(const ActivationDescriptor& activationDescriptor,
112 const char* name = nullptr) override;
114 IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor,
115 const char* name = nullptr) override;
117 IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor,
118 const char* name = nullptr) override;
120 IConnectableLayer* AddSplitterLayer(const ViewsDescriptor& splitterDescriptor,
121 const char* name = nullptr) override;
// Deprecated; the deprecation message names AddConcatLayer as the replacement.
123 ARMNN_DEPRECATED_MSG("Use AddConcatLayer instead")
124 IConnectableLayer* AddMergerLayer(const MergerDescriptor& mergerDescriptor,
125 const char* name = nullptr) override;
127 IConnectableLayer* AddAbsLayer(const char* name = nullptr) override;
129 IConnectableLayer* AddAdditionLayer(const char* name = nullptr) override;
131 IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr) override;
133 IConnectableLayer* AddBatchNormalizationLayer(const BatchNormalizationDescriptor& desc,
134 const ConstTensor& mean,
135 const ConstTensor& variance,
136 const ConstTensor& beta,
137 const ConstTensor& gamma,
138 const char* name = nullptr) override;
// Deprecated; the deprecation message names AddResizeLayer as the replacement.
140 ARMNN_DEPRECATED_MSG("Use AddResizeLayer instead")
141 IConnectableLayer* AddResizeBilinearLayer(const ResizeBilinearDescriptor& resizeDesc,
142 const char* name = nullptr) override;
144 IConnectableLayer* AddResizeLayer(const ResizeDescriptor& resizeDescriptor,
145 const char* name = nullptr) override;
147 IConnectableLayer* AddL2NormalizationLayer(const L2NormalizationDescriptor& desc,
148 const char* name = nullptr) override;
150 IConnectableLayer* AddConstantLayer(const ConstTensor& input, const char* name = nullptr) override;
152 IConnectableLayer* AddReshapeLayer(const ReshapeDescriptor& reshapeDescriptor,
153 const char* name = nullptr) override;
155 IConnectableLayer* AddSpaceToBatchNdLayer(const SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
156 const char* name = nullptr) override;
158 IConnectableLayer* AddSpaceToDepthLayer(const SpaceToDepthDescriptor& spaceToDepthDescriptor,
159 const char* name = nullptr) override;
161 IConnectableLayer* AddFloorLayer(const char* name = nullptr) override;
163 IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr) override;
165 IConnectableLayer* AddLstmLayer(const LstmDescriptor& descriptor,
166 const LstmInputParams& params,
167 const char* name = nullptr) override;
169 IConnectableLayer* AddDivisionLayer(const char* name = nullptr) override;
171 IConnectableLayer* AddSubtractionLayer(const char* name = nullptr) override;
173 IConnectableLayer* AddMaximumLayer(const char* name = nullptr) override;
175 IConnectableLayer* AddMeanLayer(const MeanDescriptor& meanDescriptor, const char* name = nullptr) override;
177 IConnectableLayer* AddPadLayer(const PadDescriptor& padDescriptor, const char* name = nullptr) override;
179 IConnectableLayer* AddQuantizeLayer(const char* name = nullptr) override;
181 IConnectableLayer* AddStridedSliceLayer(const StridedSliceDescriptor& stridedSliceDescriptor,
182 const char* name = nullptr) override;
184 IConnectableLayer* AddMinimumLayer(const char* name = nullptr) override;
186 IConnectableLayer* AddGreaterLayer(const char* name = nullptr) override;
188 IConnectableLayer* AddEqualLayer(const char* name = nullptr) override;
190 IConnectableLayer* AddRsqrtLayer(const char* name = nullptr) override;
192 IConnectableLayer* AddMergeLayer(const char* name = nullptr) override;
194 IConnectableLayer* AddSwitchLayer(const char* name = nullptr) override;
196 IConnectableLayer* AddPreluLayer(const char* name = nullptr) override;
198 IConnectableLayer* AddTransposeConvolution2dLayer(const TransposeConvolution2dDescriptor& descriptor,
199 const ConstTensor& weights,
200 const Optional<ConstTensor>& biases,
201 const char* name = nullptr) override;
203 IConnectableLayer* AddStackLayer(const StackDescriptor& stackDescriptor,
204 const char* name = nullptr) override;
206 IConnectableLayer* AddQuantizedLstmLayer(const QuantizedLstmInputParams& params,
207 const char* name = nullptr) override;
/// Const override taking an ILayerVisitor by non-const reference.
/// NOTE(review): presumably walks the layers of the graph — definition not visible; confirm.
209 void Accept(ILayerVisitor& visitor) const override;
// ---- *Impl helpers ----------------------------------------------------------
// Each takes weights plus Optional<ConstTensor> biases, mirroring the non-deprecated
// public overload of the same layer type.
// NOTE(review): presumably the shared implementation behind the public overloads;
// definitions are not visible here. The parameter lists below are cut off in this
// listing (each ends on a comma), so the remaining parameter(s) are unknown from here.
212 IConnectableLayer* AddFullyConnectedLayerImpl(const FullyConnectedDescriptor& fullyConnectedDescriptor,
213 const ConstTensor& weights,
214 const Optional<ConstTensor>& biases,
217 IConnectableLayer* AddConvolution2dLayerImpl(const Convolution2dDescriptor& convolution2dDescriptor,
218 const ConstTensor& weights,
219 const Optional<ConstTensor>& biases,
222 IConnectableLayer* AddDepthwiseConvolution2dLayerImpl(
223 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
224 const ConstTensor& weights,
225 const Optional<ConstTensor>& biases,
/// Owning pointer to the Graph that GetGraph() dereferences.
228 std::unique_ptr<Graph> m_Graph;
/// Implementation of IOptimizedNetwork, declared `final`. Holds its graph through
/// the std::unique_ptr<Graph> member m_Graph.
/// NOTE(review): this listing appears to omit some original lines (braces, access
/// specifiers and possibly other members), so read it against the complete header.
231 class OptimizedNetwork final : public IOptimizedNetwork
/// Accepts a std::unique_ptr<Graph> by value, so the caller hands over ownership
/// of the graph; presumably it is stored in m_Graph (definition not visible — confirm).
/// NOTE(review): this single-argument constructor is not `explicit`; check whether
/// implicit conversion from std::unique_ptr<Graph> is intended.
234 OptimizedNetwork(std::unique_ptr<Graph> graph);
/// Override returning a Status; the definition is not visible in this header.
237 Status PrintGraph() override;
/// Const override that takes the destination std::ostream by reference and returns a Status.
238 Status SerializeToDot(std::ostream& stream) const override;
/// Mutable access to the graph. Dereferences m_Graph directly, so m_Graph must be
/// non-null when this is called.
240 Graph& GetGraph() { return *m_Graph; }
/// Owning pointer to the Graph that GetGraph() dereferences.
243 std::unique_ptr<Graph> m_Graph;
/// Result type returned by SelectTensorHandleStrategy (declared further down).
/// NOTE(review): the struct body is not visible in this listing, so its members and
/// semantics cannot be documented from here — consult the complete header.
248 struct OptimizationResult
/// Ordered map from BackendId to an owning pointer to that backend's IBackendInternal.
/// The elaborated `class IBackendInternal` specifier lets the alias name the type
/// without requiring its definition at this point.
259 using BackendsMap = std::map<BackendId, std::unique_ptr<class IBackendInternal>>;
/// Returns a BackendsMap by value.
/// Both arguments are taken by non-const reference, so the callee is able to modify
/// the handle-factory registry and the backend settings. `struct BackendSettings`
/// is named via an elaborated type specifier, so no prior definition is needed here.
/// NOTE(review): presumably instantiates the backends listed in backendSettings and
/// registers their tensor-handle factories — definition not visible; confirm.
261 BackendsMap CreateSupportedBackends(TensorHandleFactoryRegistry& handleFactoryRegistry,
262 struct BackendSettings& backendSettings);
/// Returns an OptimizationResult.
/// The graph, backends map and registry are all taken by non-const reference, so
/// the callee is able to modify each of them. `errMessages` is an Optional wrapping
/// a reference to a std::vector<std::string>.
/// NOTE(review): presumably errMessages collects diagnostics when supplied —
/// definition not visible; confirm.
264 OptimizationResult SelectTensorHandleStrategy(Graph& optGraph,
265 BackendsMap& backends,
266 TensorHandleFactoryRegistry& registry,
267 Optional<std::vector<std::string>&> errMessages);