2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include <armnn/DescriptorsFwd.hpp>
8 #include <armnn/LstmParams.hpp>
9 #include <armnn/QuantizedLstmParams.hpp>
10 #include <armnn/TensorFwd.hpp>
11 #include <armnn/Types.hpp>
13 #include <armnn/INetwork.hpp>
26 /// Private implementation of INetwork.
27 class Network final : public INetwork
33 const Graph& GetGraph() const { return *m_Graph; }
35 Status PrintGraph() override;
37 IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name=nullptr) override;
39 IConnectableLayer* AddArgMinMaxLayer(const ArgMinMaxDescriptor& desc,
40 const char* name = nullptr) override;
42 IConnectableLayer* AddBatchToSpaceNdLayer(const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
43 const char* name = nullptr) override;
45 IConnectableLayer* AddConcatLayer(const ConcatDescriptor& concatDescriptor,
46 const char* name = nullptr) override;
48 IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
49 const ConstTensor& weights,
50 const Optional<ConstTensor>& biases,
51 const char* name = nullptr) override;
53 ARMNN_DEPRECATED_MSG("This AddConvolution2dLayer overload is deprecated")
54 IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
55 const ConstTensor& weights,
56 const char* name = nullptr) override;
58 ARMNN_DEPRECATED_MSG("This AddConvolution2dLayer overload is deprecated")
59 IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
60 const ConstTensor& weights,
61 const ConstTensor& biases,
62 const char* name = nullptr) override;
64 IConnectableLayer* AddDepthwiseConvolution2dLayer(
65 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
66 const ConstTensor& weights,
67 const Optional<ConstTensor>& biases,
68 const char* name = nullptr) override;
70 ARMNN_DEPRECATED_MSG("This AddDepthwiseConvolution2dLayer overload is deprecated")
71 IConnectableLayer* AddDepthwiseConvolution2dLayer(
72 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
73 const ConstTensor& weights,
74 const char* name = nullptr) override;
76 ARMNN_DEPRECATED_MSG("This AddDepthwiseConvolution2dLayer overload is deprecated")
77 IConnectableLayer* AddDepthwiseConvolution2dLayer(
78 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
79 const ConstTensor& weights,
80 const ConstTensor& biases,
81 const char* name = nullptr) override;
83 IConnectableLayer* AddDequantizeLayer(const char* name = nullptr) override;
85 IConnectableLayer* AddDetectionPostProcessLayer(
86 const DetectionPostProcessDescriptor& descriptor,
87 const ConstTensor& anchors,
88 const char* name = nullptr) override;
90 IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
91 const ConstTensor& weights,
92 const Optional<ConstTensor>& biases,
93 const char* name = nullptr) override;
95 ARMNN_DEPRECATED_MSG("This AddFullyConnectedLayer overload is deprecated")
96 IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
97 const ConstTensor& weights,
98 const char* name = nullptr) override;
100 ARMNN_DEPRECATED_MSG("This AddFullyConnectedLayer overload is deprecated")
101 IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
102 const ConstTensor& weights,
103 const ConstTensor& biases,
104 const char* name = nullptr) override;
106 IConnectableLayer* AddGatherLayer(const char* name = nullptr) override;
108 IConnectableLayer* AddPermuteLayer(const PermuteDescriptor& permuteDescriptor,
109 const char* name = nullptr) override;
111 IConnectableLayer* AddPooling2dLayer(const Pooling2dDescriptor& pooling2dDescriptor,
112 const char* name = nullptr) override;
114 IConnectableLayer* AddActivationLayer(const ActivationDescriptor& activationDescriptor,
115 const char* name = nullptr) override;
117 IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor,
118 const char* name = nullptr) override;
120 IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor,
121 const char* name = nullptr) override;
123 IConnectableLayer* AddSplitterLayer(const ViewsDescriptor& splitterDescriptor,
124 const char* name = nullptr) override;
126 ARMNN_DEPRECATED_MSG("Use AddConcatLayer instead")
127 IConnectableLayer* AddMergerLayer(const MergerDescriptor& mergerDescriptor,
128 const char* name = nullptr) override;
130 IConnectableLayer* AddAbsLayer(const char* name = nullptr) override;
132 IConnectableLayer* AddAdditionLayer(const char* name = nullptr) override;
134 IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr) override;
136 IConnectableLayer* AddBatchNormalizationLayer(const BatchNormalizationDescriptor& desc,
137 const ConstTensor& mean,
138 const ConstTensor& variance,
139 const ConstTensor& beta,
140 const ConstTensor& gamma,
141 const char* name = nullptr) override;
143 ARMNN_DEPRECATED_MSG("Use AddResizeLayer instead")
144 IConnectableLayer* AddResizeBilinearLayer(const ResizeBilinearDescriptor& resizeDesc,
145 const char* name = nullptr) override;
147 IConnectableLayer* AddResizeLayer(const ResizeDescriptor& resizeDescriptor,
148 const char* name = nullptr) override;
150 IConnectableLayer* AddL2NormalizationLayer(const L2NormalizationDescriptor& desc,
151 const char* name = nullptr) override;
153 IConnectableLayer* AddConstantLayer(const ConstTensor& input, const char* name = nullptr) override;
155 IConnectableLayer* AddReshapeLayer(const ReshapeDescriptor& reshapeDescriptor,
156 const char* name = nullptr) override;
158 IConnectableLayer* AddSpaceToBatchNdLayer(const SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
159 const char* name = nullptr) override;
161 IConnectableLayer* AddSpaceToDepthLayer(const SpaceToDepthDescriptor& spaceToDepthDescriptor,
162 const char* name = nullptr) override;
164 IConnectableLayer* AddFloorLayer(const char* name = nullptr) override;
166 IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr) override;
168 IConnectableLayer* AddLstmLayer(const LstmDescriptor& descriptor,
169 const LstmInputParams& params,
170 const char* name = nullptr) override;
172 IConnectableLayer* AddDivisionLayer(const char* name = nullptr) override;
174 IConnectableLayer* AddSubtractionLayer(const char* name = nullptr) override;
176 IConnectableLayer* AddMaximumLayer(const char* name = nullptr) override;
178 IConnectableLayer* AddMeanLayer(const MeanDescriptor& meanDescriptor, const char* name = nullptr) override;
180 IConnectableLayer* AddPadLayer(const PadDescriptor& padDescriptor, const char* name = nullptr) override;
182 IConnectableLayer* AddQuantizeLayer(const char* name = nullptr) override;
184 IConnectableLayer* AddStridedSliceLayer(const StridedSliceDescriptor& stridedSliceDescriptor,
185 const char* name = nullptr) override;
187 IConnectableLayer* AddMinimumLayer(const char* name = nullptr) override;
189 IConnectableLayer* AddGreaterLayer(const char* name = nullptr) override;
191 IConnectableLayer* AddEqualLayer(const char* name = nullptr) override;
193 IConnectableLayer* AddRsqrtLayer(const char* name = nullptr) override;
195 IConnectableLayer* AddMergeLayer(const char* name = nullptr) override;
197 IConnectableLayer* AddSwitchLayer(const char* name = nullptr) override;
199 IConnectableLayer* AddPreluLayer(const char* name = nullptr) override;
201 IConnectableLayer* AddTransposeConvolution2dLayer(const TransposeConvolution2dDescriptor& descriptor,
202 const ConstTensor& weights,
203 const Optional<ConstTensor>& biases,
204 const char* name = nullptr) override;
206 IConnectableLayer* AddStackLayer(const StackDescriptor& stackDescriptor,
207 const char* name = nullptr) override;
209 IConnectableLayer* AddQuantizedLstmLayer(const QuantizedLstmInputParams& params,
210 const char* name = nullptr) override;
212 void Accept(ILayerVisitor& visitor) const override;
215 IConnectableLayer* AddFullyConnectedLayerImpl(const FullyConnectedDescriptor& fullyConnectedDescriptor,
216 const ConstTensor& weights,
217 const Optional<ConstTensor>& biases,
220 IConnectableLayer* AddConvolution2dLayerImpl(const Convolution2dDescriptor& convolution2dDescriptor,
221 const ConstTensor& weights,
222 const Optional<ConstTensor>& biases,
225 IConnectableLayer* AddDepthwiseConvolution2dLayerImpl(
226 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
227 const ConstTensor& weights,
228 const Optional<ConstTensor>& biases,
231 std::unique_ptr<Graph> m_Graph;
234 class OptimizedNetwork final : public IOptimizedNetwork
237 OptimizedNetwork(std::unique_ptr<Graph> graph);
240 Status PrintGraph() override;
241 Status SerializeToDot(std::ostream& stream) const override;
243 Graph& GetGraph() { return *m_Graph; }
246 std::unique_ptr<Graph> m_Graph;
251 struct OptimizationResult
/// Map from a BackendId to a backend object.
/// The map owns each backend through std::unique_ptr, so a backend is destroyed
/// when its entry is erased or when the map itself is destroyed.
/// The elaborated `class IBackendInternal` specifier lets this alias be declared
/// without the full IBackendInternal definition. IBackendInternal must still be
/// a complete type wherever the map (and thus the unique_ptr) is destroyed.
262 using BackendsMap = std::map<BackendId, std::unique_ptr<class IBackendInternal>>;
/// Builds the BackendsMap used during optimization.
/// NOTE(review): presumably creates one backend per backend ID held in
/// backendSettings and registers their tensor handle factories with
/// handleFactoryRegistry. Confirm against the implementation.
/// @param handleFactoryRegistry  Taken by non-const reference, so the call may modify it.
/// @param backendSettings        Taken by non-const reference, so the call may update it
///                               (TODO confirm what is written back). The elaborated
///                               `struct` specifier means this declaration does not need
///                               BackendSettings' definition to be visible.
/// @return A BackendsMap that owns (via std::unique_ptr) the backend objects it holds.
264 BackendsMap CreateSupportedBackends(TensorHandleFactoryRegistry& handleFactoryRegistry,
265 struct BackendSettings& backendSettings);
/// Tensor-handle strategy selection step applied to an optimized graph.
/// NOTE(review): the semantics here are inferred from the name (choosing how tensor
/// handles are created/shared between layers using the available backends and
/// factories). Confirm against the implementation.
/// @param optGraph     Graph to process; taken by non-const reference, so it may be modified.
/// @param backends     Backend objects available to the step (see BackendsMap);
///                     taken by non-const reference.
/// @param registry     Tensor handle factory registry consulted by the step;
///                     taken by non-const reference.
/// @param errMessages  Optional reference to a vector of strings. Presumably error
///                     text is appended to it when one is supplied. TODO confirm.
/// @return An OptimizationResult describing the outcome (its fields are not visible here).
267 OptimizationResult SelectTensorHandleStrategy(Graph& optGraph,
268 BackendsMap& backends,
269 TensorHandleFactoryRegistry& registry,
270 Optional<std::vector<std::string>&> errMessages);