2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include <armnn/DescriptorsFwd.hpp>
8 #include <armnn/LstmParams.hpp>
9 #include <armnn/TensorFwd.hpp>
10 #include <armnn/Types.hpp>
12 #include <armnn/INetwork.hpp>
25 /// Private implementation of INetwork.
/// Owns its Graph through m_Graph (a std::unique_ptr<Graph>). Every Add*Layer method
/// returns an IConnectableLayer* (presumably the layer just added; ownership presumably
/// stays with the network — confirm before storing or freeing the pointer).
/// NOTE(review): the `name = nullptr` defaults below sit on virtual overrides. Default
/// arguments bind to the static type at the call site, so they must stay in sync with
/// the defaults declared in INetwork.
26 class Network final : public INetwork
/// Read-only access to the owned graph. Dereferences m_Graph without a null check.
32 const Graph& GetGraph() const { return *m_Graph; }
/// INetwork override; reports success/failure as a Status.
/// NOTE(review): presumably prints the graph to a log or stdout — confirm in Network.cpp.
34 Status PrintGraph() override;
/// Adds an input layer to the network.
/// @param id   Binding id associated with this input (LayerBindingId).
/// @param name Optional layer name; defaults to nullptr.
36 IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name = nullptr) override;
/// Adds a BatchToSpaceNd layer configured by @p batchToSpaceNdDescriptor.
/// @param name Optional layer name; defaults to nullptr.
38 IConnectableLayer* AddBatchToSpaceNdLayer(const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
39 const char* name = nullptr) override;
/// Adds a Concat layer configured by @p concatDescriptor (AddMergerLayer is deprecated in its favour).
/// @param name Optional layer name; defaults to nullptr.
41 IConnectableLayer* AddConcatLayer(const ConcatDescriptor& concatDescriptor,
42 const char* name = nullptr) override;
/// Adds a 2D convolution layer.
/// @param convolution2dDescriptor Convolution configuration.
/// @param weights                 Constant weight tensor.
/// @param biases                  Optional constant bias tensor (presumably no bias is used when empty).
/// @param name                    Optional layer name; defaults to nullptr.
44 IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
45 const ConstTensor& weights,
46 const Optional<ConstTensor>& biases,
47 const char* name = nullptr) override;
/// Deprecated overload taking weights only (no bias argument).
/// Presumably equivalent to the Optional<ConstTensor> overload with an empty Optional.
49 ARMNN_DEPRECATED_MSG("This AddConvolution2dLayer overload is deprecated")
50 IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
51 const ConstTensor& weights,
52 const char* name = nullptr) override;
/// Deprecated overload taking the bias as a plain (non-optional) ConstTensor.
54 ARMNN_DEPRECATED_MSG("This AddConvolution2dLayer overload is deprecated")
55 IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
56 const ConstTensor& weights,
57 const ConstTensor& biases,
58 const char* name = nullptr) override;
/// Adds a depthwise 2D convolution layer.
/// @param convolution2dDescriptor Depthwise convolution configuration (note: the parameter
///                                name does not say "depthwise", but the type does).
/// @param weights                 Constant weight tensor.
/// @param biases                  Optional constant bias tensor (presumably no bias is used when empty).
/// @param name                    Optional layer name; defaults to nullptr.
60 IConnectableLayer* AddDepthwiseConvolution2dLayer(
61 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
62 const ConstTensor& weights,
63 const Optional<ConstTensor>& biases,
64 const char* name = nullptr) override;
/// Deprecated overload taking weights only (no bias argument).
66 ARMNN_DEPRECATED_MSG("This AddDepthwiseConvolution2dLayer overload is deprecated")
67 IConnectableLayer* AddDepthwiseConvolution2dLayer(
68 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
69 const ConstTensor& weights,
70 const char* name = nullptr) override;
/// Deprecated overload taking the bias as a plain (non-optional) ConstTensor.
72 ARMNN_DEPRECATED_MSG("This AddDepthwiseConvolution2dLayer overload is deprecated")
73 IConnectableLayer* AddDepthwiseConvolution2dLayer(
74 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
75 const ConstTensor& weights,
76 const ConstTensor& biases,
77 const char* name = nullptr) override;
/// Adds a Dequantize layer; it takes no descriptor.
/// @param name Optional layer name; defaults to nullptr.
79 IConnectableLayer* AddDequantizeLayer(const char* name = nullptr) override;
/// Adds a DetectionPostProcess layer.
/// @param descriptor Post-processing configuration.
/// @param anchors    Constant anchor tensor consumed by the layer.
/// @param name       Optional layer name; defaults to nullptr.
81 IConnectableLayer* AddDetectionPostProcessLayer(
82 const DetectionPostProcessDescriptor& descriptor,
83 const ConstTensor& anchors,
84 const char* name = nullptr) override;
/// Adds a fully connected layer.
/// @param fullyConnectedDescriptor Layer configuration.
/// @param weights                  Constant weight tensor.
/// @param biases                   Optional constant bias tensor (presumably no bias is used when empty).
/// @param name                     Optional layer name; defaults to nullptr.
86 IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
87 const ConstTensor& weights,
88 const Optional<ConstTensor>& biases,
89 const char* name = nullptr) override;
/// Deprecated overload taking weights only (no bias argument).
91 ARMNN_DEPRECATED_MSG("This AddFullyConnectedLayer overload is deprecated")
92 IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
93 const ConstTensor& weights,
94 const char* name = nullptr) override;
/// Deprecated overload taking the bias as a plain (non-optional) ConstTensor.
96 ARMNN_DEPRECATED_MSG("This AddFullyConnectedLayer overload is deprecated")
97 IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
98 const ConstTensor& weights,
99 const ConstTensor& biases,
100 const char* name = nullptr) override;
/// Adds a Gather layer; it takes no descriptor.
102 IConnectableLayer* AddGatherLayer(const char* name = nullptr) override;
/// Adds a Permute layer configured by @p permuteDescriptor.
104 IConnectableLayer* AddPermuteLayer(const PermuteDescriptor& permuteDescriptor,
105 const char* name = nullptr) override;
/// Adds a 2D pooling layer configured by @p pooling2dDescriptor.
107 IConnectableLayer* AddPooling2dLayer(const Pooling2dDescriptor& pooling2dDescriptor,
108 const char* name = nullptr) override;
/// Adds an activation layer configured by @p activationDescriptor.
110 IConnectableLayer* AddActivationLayer(const ActivationDescriptor& activationDescriptor,
111 const char* name = nullptr) override;
/// Adds a normalization layer configured by @p normalizationDescriptor.
113 IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor,
114 const char* name = nullptr) override;
/// Adds a softmax layer configured by @p softmaxDescriptor.
116 IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor,
117 const char* name = nullptr) override;
/// Adds a Splitter layer; its views are described by a ViewsDescriptor.
119 IConnectableLayer* AddSplitterLayer(const ViewsDescriptor& splitterDescriptor,
120 const char* name = nullptr) override;
/// Deprecated: use AddConcatLayer instead (per ARMNN_DEPRECATED_MSG).
122 ARMNN_DEPRECATED_MSG("Use AddConcatLayer instead")
123 IConnectableLayer* AddMergerLayer(const MergerDescriptor& mergerDescriptor,
124 const char* name = nullptr) override;
/// Adds an Addition layer; it takes no descriptor.
126 IConnectableLayer* AddAdditionLayer(const char* name = nullptr) override;
/// Adds a Multiplication layer; it takes no descriptor.
128 IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr) override;
/// Adds a batch normalization layer.
/// @param desc     Layer configuration.
/// @param mean     Constant mean tensor.
/// @param variance Constant variance tensor.
/// @param beta     Constant beta tensor.
/// @param gamma    Constant gamma tensor.
/// @param name     Optional layer name; defaults to nullptr.
130 IConnectableLayer* AddBatchNormalizationLayer(const BatchNormalizationDescriptor& desc,
131 const ConstTensor& mean,
132 const ConstTensor& variance,
133 const ConstTensor& beta,
134 const ConstTensor& gamma,
135 const char* name = nullptr) override;
/// Deprecated: use AddResizeLayer instead (per ARMNN_DEPRECATED_MSG).
137 ARMNN_DEPRECATED_MSG("Use AddResizeLayer instead")
138 IConnectableLayer* AddResizeBilinearLayer(const ResizeBilinearDescriptor& resizeDesc,
139 const char* name = nullptr) override;
/// Adds a Resize layer configured by @p resizeDescriptor.
141 IConnectableLayer* AddResizeLayer(const ResizeDescriptor& resizeDescriptor,
142 const char* name = nullptr) override;
/// Adds an L2 normalization layer configured by @p desc.
144 IConnectableLayer* AddL2NormalizationLayer(const L2NormalizationDescriptor& desc,
145 const char* name = nullptr) override;
/// Adds a Constant layer backed by the constant tensor @p input
/// (presumably the value the layer outputs — confirm).
147 IConnectableLayer* AddConstantLayer(const ConstTensor& input, const char* name = nullptr) override;
/// Adds a Reshape layer configured by @p reshapeDescriptor.
149 IConnectableLayer* AddReshapeLayer(const ReshapeDescriptor& reshapeDescriptor,
150 const char* name = nullptr) override;
/// Adds a SpaceToBatchNd layer configured by @p spaceToBatchNdDescriptor.
152 IConnectableLayer* AddSpaceToBatchNdLayer(const SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
153 const char* name = nullptr) override;
/// Adds a SpaceToDepth layer configured by @p spaceToDepthDescriptor.
155 IConnectableLayer* AddSpaceToDepthLayer(const SpaceToDepthDescriptor& spaceToDepthDescriptor,
156 const char* name = nullptr) override;
/// Adds a Floor layer; it takes no descriptor.
158 IConnectableLayer* AddFloorLayer(const char* name = nullptr) override;
/// Adds an output layer associated with binding id @p id (counterpart of AddInputLayer).
160 IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr) override;
/// Adds an LSTM layer.
/// @param descriptor LSTM configuration.
/// @param params     Input parameter set (LstmInputParams, from armnn/LstmParams.hpp).
/// @param name       Optional layer name; defaults to nullptr.
162 IConnectableLayer* AddLstmLayer(const LstmDescriptor& descriptor,
163 const LstmInputParams& params,
164 const char* name = nullptr) override;
// Descriptor-free layers: each takes only an optional name.
166 IConnectableLayer* AddDivisionLayer(const char* name = nullptr) override;
168 IConnectableLayer* AddSubtractionLayer(const char* name = nullptr) override;
170 IConnectableLayer* AddMaximumLayer(const char* name = nullptr) override;
/// Adds a Mean layer configured by @p meanDescriptor.
172 IConnectableLayer* AddMeanLayer(const MeanDescriptor& meanDescriptor, const char* name = nullptr) override;
/// Adds a Pad layer configured by @p padDescriptor.
174 IConnectableLayer* AddPadLayer(const PadDescriptor& padDescriptor, const char* name = nullptr) override;
/// Adds a Quantize layer; it takes no descriptor.
176 IConnectableLayer* AddQuantizeLayer(const char* name = nullptr) override;
/// Adds a StridedSlice layer configured by @p stridedSliceDescriptor.
178 IConnectableLayer* AddStridedSliceLayer(const StridedSliceDescriptor& stridedSliceDescriptor,
179 const char* name = nullptr) override;
// More descriptor-free layers: each takes only an optional name.
181 IConnectableLayer* AddMinimumLayer(const char* name = nullptr) override;
183 IConnectableLayer* AddGreaterLayer(const char* name = nullptr) override;
185 IConnectableLayer* AddEqualLayer(const char* name = nullptr) override;
187 IConnectableLayer* AddRsqrtLayer(const char* name = nullptr) override;
189 IConnectableLayer* AddMergeLayer(const char* name = nullptr) override;
191 IConnectableLayer* AddSwitchLayer(const char* name = nullptr) override;
193 IConnectableLayer* AddPreluLayer(const char* name = nullptr) override;
/// Adds a transposed 2D convolution layer.
/// @param descriptor Layer configuration.
/// @param weights    Constant weight tensor.
/// @param biases     Optional constant bias tensor (presumably no bias is used when empty).
/// @param name       Optional layer name; defaults to nullptr.
195 IConnectableLayer* AddTransposeConvolution2dLayer(const TransposeConvolution2dDescriptor& descriptor,
196 const ConstTensor& weights,
197 const Optional<ConstTensor>& biases,
198 const char* name = nullptr) override;
/// Visitor entry point (INetwork override). Const-qualified, so it cannot modify the
/// network through `this`; presumably invokes @p visitor for each layer — confirm.
200 void Accept(ILayerVisitor& visitor) const override;
// Private helpers taking Optional<ConstTensor> biases. Presumably these are the shared
// implementations that the public (including deprecated) overloads of the matching
// Add*Layer methods forward to — confirm in Network.cpp.
203 IConnectableLayer* AddFullyConnectedLayerImpl(const FullyConnectedDescriptor& fullyConnectedDescriptor,
204 const ConstTensor& weights,
205 const Optional<ConstTensor>& biases,
208 IConnectableLayer* AddConvolution2dLayerImpl(const Convolution2dDescriptor& convolution2dDescriptor,
209 const ConstTensor& weights,
210 const Optional<ConstTensor>& biases,
213 IConnectableLayer* AddDepthwiseConvolution2dLayerImpl(
214 const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
215 const ConstTensor& weights,
216 const Optional<ConstTensor>& biases,
/// Owned graph; GetGraph() dereferences it without a null check.
219 std::unique_ptr<Graph> m_Graph;
/// Concrete IOptimizedNetwork. Owns its Graph through m_Graph (a std::unique_ptr<Graph>).
222 class OptimizedNetwork final : public IOptimizedNetwork
/// Takes ownership of @p graph (passed by value as a std::unique_ptr, so the caller must
/// std::move it in; presumably stored in m_Graph).
/// Marked explicit so that a std::unique_ptr<Graph> is never silently converted into an
/// OptimizedNetwork; direct-initialisation such as `new OptimizedNetwork(std::move(g))`
/// is unaffected.
225 explicit OptimizedNetwork(std::unique_ptr<Graph> graph);
/// IOptimizedNetwork override; reports success/failure as a Status.
228 Status PrintGraph() override;
/// Serializes the graph to @p stream. The name suggests Graphviz DOT format — confirm.
229 Status SerializeToDot(std::ostream& stream) const override;
/// Mutable access to the owned graph; dereferences m_Graph without a null check.
/// NOTE(review): unlike Network::GetGraph() there is no const overload, so callers
/// holding a const OptimizedNetwork& cannot reach the graph.
231 Graph& GetGraph() { return *m_Graph; }
/// Owned optimized graph.
234 std::unique_ptr<Graph> m_Graph;
/// Result type returned by SelectTensorHandleStrategy() below.
239 struct OptimizationResult
/// Maps each BackendId to the IBackendInternal instance it owns (std::unique_ptr).
/// `class IBackendInternal` in the template argument doubles as a forward declaration.
250 using BackendsMap = std::map<BackendId, std::unique_ptr<class IBackendInternal>>;
/// Builds a BackendsMap of backend instances.
/// @param handleFactoryRegistry Taken by non-const reference; presumably backends register
///                              their tensor handle factories into it — confirm.
/// @param backendSettings       Taken by non-const reference (BackendSettings is forward-declared
///                              inline); presumably used to decide which backends are supported
///                              and may be updated — confirm.
/// @return Map from BackendId to owned backend instance.
252 BackendsMap CreateSupportedBackends(TensorHandleFactoryRegistry& handleFactoryRegistry,
253 struct BackendSettings& backendSettings);
/// Selects a tensor-handle strategy for @p optGraph and reports the outcome as an
/// OptimizationResult.
/// @param optGraph    Graph to process (non-const; presumably annotated in place — confirm).
/// @param backends    Backend instances, as produced by CreateSupportedBackends().
/// @param registry    Tensor handle factory registry to choose from.
/// @param errMessages Optional reference to a vector of strings; presumably error messages
///                    are appended to it when present — confirm.
255 OptimizationResult SelectTensorHandleStrategy(Graph& optGraph,
256 BackendsMap& backends,
257 TensorHandleFactoryRegistry& registry,
258 Optional<std::vector<std::string>&> errMessages);