// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
#pragma once

#include <armnn/ILayerVisitor.hpp>
#include <armnn/LayerVisitorBase.hpp>

#include <armnnSerializer/ISerializer.hpp>

#include <cstdint>
#include <iosfwd>
#include <unordered_map>
#include <vector>

#include <ArmnnSchema_generated.h>
16 namespace armnnSerializer
19 class SerializerVisitor : public armnn::ILayerVisitor
22 SerializerVisitor() : m_layerId(0) {}
23 ~SerializerVisitor() {}
25 flatbuffers::FlatBufferBuilder& GetFlatBufferBuilder()
27 return m_flatBufferBuilder;
30 std::vector<unsigned int>& GetInputIds()
35 std::vector<unsigned int>& GetOutputIds()
40 std::vector<flatbuffers::Offset<armnnSerializer::AnyLayer>>& GetSerializedLayers()
42 return m_serializedLayers;
45 void VisitAbsLayer(const armnn::IConnectableLayer* layer,
46 const char* name = nullptr) override;
48 void VisitActivationLayer(const armnn::IConnectableLayer* layer,
49 const armnn::ActivationDescriptor& descriptor,
50 const char* name = nullptr) override;
52 void VisitAdditionLayer(const armnn::IConnectableLayer* layer,
53 const char* name = nullptr) override;
55 void VisitArgMinMaxLayer(const armnn::IConnectableLayer* layer,
56 const armnn::ArgMinMaxDescriptor& argMinMaxDescriptor,
57 const char* name = nullptr) override;
59 void VisitBatchToSpaceNdLayer(const armnn::IConnectableLayer* layer,
60 const armnn::BatchToSpaceNdDescriptor& descriptor,
61 const char* name = nullptr) override;
63 void VisitBatchNormalizationLayer(const armnn::IConnectableLayer* layer,
64 const armnn::BatchNormalizationDescriptor& BatchNormalizationDescriptor,
65 const armnn::ConstTensor& mean,
66 const armnn::ConstTensor& variance,
67 const armnn::ConstTensor& beta,
68 const armnn::ConstTensor& gamma,
69 const char* name = nullptr) override;
71 void VisitConcatLayer(const armnn::IConnectableLayer* layer,
72 const armnn::ConcatDescriptor& concatDescriptor,
73 const char* name = nullptr) override;
75 void VisitConstantLayer(const armnn::IConnectableLayer* layer,
76 const armnn::ConstTensor& input,
77 const char* = nullptr) override;
79 void VisitConvolution2dLayer(const armnn::IConnectableLayer* layer,
80 const armnn::Convolution2dDescriptor& descriptor,
81 const armnn::ConstTensor& weights,
82 const armnn::Optional<armnn::ConstTensor>& biases,
83 const char* = nullptr) override;
85 void VisitDepthwiseConvolution2dLayer(const armnn::IConnectableLayer* layer,
86 const armnn::DepthwiseConvolution2dDescriptor& descriptor,
87 const armnn::ConstTensor& weights,
88 const armnn::Optional<armnn::ConstTensor>& biases,
89 const char* name = nullptr) override;
91 void VisitDequantizeLayer(const armnn::IConnectableLayer* layer,
92 const char* name = nullptr) override;
94 void VisitDetectionPostProcessLayer(const armnn::IConnectableLayer* layer,
95 const armnn::DetectionPostProcessDescriptor& descriptor,
96 const armnn::ConstTensor& anchors,
97 const char* name = nullptr) override;
99 void VisitDivisionLayer(const armnn::IConnectableLayer* layer,
100 const char* name = nullptr) override;
102 void VisitEqualLayer(const armnn::IConnectableLayer* layer,
103 const char* name = nullptr) override;
105 void VisitFloorLayer(const armnn::IConnectableLayer *layer,
106 const char *name = nullptr) override;
108 void VisitFullyConnectedLayer(const armnn::IConnectableLayer* layer,
109 const armnn::FullyConnectedDescriptor& fullyConnectedDescriptor,
110 const armnn::ConstTensor& weights,
111 const armnn::Optional<armnn::ConstTensor>& biases,
112 const char* name = nullptr) override;
114 void VisitGatherLayer(const armnn::IConnectableLayer* layer,
115 const char* name = nullptr) override;
117 void VisitGreaterLayer(const armnn::IConnectableLayer* layer,
118 const char* name = nullptr) override;
120 void VisitInputLayer(const armnn::IConnectableLayer* layer,
121 armnn::LayerBindingId id,
122 const char* name = nullptr) override;
124 void VisitL2NormalizationLayer(const armnn::IConnectableLayer* layer,
125 const armnn::L2NormalizationDescriptor& l2NormalizationDescriptor,
126 const char* name = nullptr) override;
128 void VisitLstmLayer(const armnn::IConnectableLayer* layer,
129 const armnn::LstmDescriptor& descriptor,
130 const armnn::LstmInputParams& params,
131 const char* name = nullptr) override;
133 void VisitMeanLayer(const armnn::IConnectableLayer* layer,
134 const armnn::MeanDescriptor& descriptor,
135 const char* name) override;
137 void VisitMinimumLayer(const armnn::IConnectableLayer* layer,
138 const char* name = nullptr) override;
140 void VisitMaximumLayer(const armnn::IConnectableLayer* layer,
141 const char* name = nullptr) override;
143 void VisitMergeLayer(const armnn::IConnectableLayer* layer,
144 const char* name = nullptr) override;
146 ARMNN_DEPRECATED_MSG("Use VisitConcatLayer instead")
147 void VisitMergerLayer(const armnn::IConnectableLayer* layer,
148 const armnn::MergerDescriptor& mergerDescriptor,
149 const char* name = nullptr) override;
151 void VisitMultiplicationLayer(const armnn::IConnectableLayer* layer,
152 const char* name = nullptr) override;
154 void VisitOutputLayer(const armnn::IConnectableLayer* layer,
155 armnn::LayerBindingId id,
156 const char* name = nullptr) override;
158 void VisitPadLayer(const armnn::IConnectableLayer* layer,
159 const armnn::PadDescriptor& PadDescriptor,
160 const char* name = nullptr) override;
162 void VisitPermuteLayer(const armnn::IConnectableLayer* layer,
163 const armnn::PermuteDescriptor& PermuteDescriptor,
164 const char* name = nullptr) override;
166 void VisitPooling2dLayer(const armnn::IConnectableLayer* layer,
167 const armnn::Pooling2dDescriptor& pooling2dDescriptor,
168 const char* name = nullptr) override;
170 void VisitPreluLayer(const armnn::IConnectableLayer* layer,
171 const char* name = nullptr) override;
173 void VisitQuantizeLayer(const armnn::IConnectableLayer* layer,
174 const char* name = nullptr) override;
176 void VisitQuantizedLstmLayer(const armnn::IConnectableLayer* layer,
177 const armnn::QuantizedLstmInputParams& params,
178 const char* name = nullptr) override;
180 void VisitReshapeLayer(const armnn::IConnectableLayer* layer,
181 const armnn::ReshapeDescriptor& reshapeDescriptor,
182 const char* name = nullptr) override;
184 void VisitResizeLayer(const armnn::IConnectableLayer* layer,
185 const armnn::ResizeDescriptor& resizeDescriptor,
186 const char* name = nullptr) override;
188 ARMNN_DEPRECATED_MSG("Use VisitResizeLayer instead")
189 void VisitResizeBilinearLayer(const armnn::IConnectableLayer* layer,
190 const armnn::ResizeBilinearDescriptor& resizeDescriptor,
191 const char* name = nullptr) override;
193 void VisitRsqrtLayer(const armnn::IConnectableLayer* layer,
194 const char* name = nullptr) override;
196 void VisitSoftmaxLayer(const armnn::IConnectableLayer* layer,
197 const armnn::SoftmaxDescriptor& softmaxDescriptor,
198 const char* name = nullptr) override;
200 void VisitSpaceToBatchNdLayer(const armnn::IConnectableLayer* layer,
201 const armnn::SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
202 const char* name = nullptr) override;
204 void VisitSpaceToDepthLayer(const armnn::IConnectableLayer* layer,
205 const armnn::SpaceToDepthDescriptor& spaceToDepthDescriptor,
206 const char* name = nullptr) override;
208 void VisitNormalizationLayer(const armnn::IConnectableLayer* layer,
209 const armnn::NormalizationDescriptor& normalizationDescriptor,
210 const char* name = nullptr) override;
212 void VisitSplitterLayer(const armnn::IConnectableLayer* layer,
213 const armnn::ViewsDescriptor& viewsDescriptor,
214 const char* name = nullptr) override;
216 void VisitStackLayer(const armnn::IConnectableLayer* layer,
217 const armnn::StackDescriptor& stackDescriptor,
218 const char* name = nullptr) override;
220 void VisitStridedSliceLayer(const armnn::IConnectableLayer* layer,
221 const armnn::StridedSliceDescriptor& stridedSliceDescriptor,
222 const char* name = nullptr) override;
224 void VisitSubtractionLayer(const armnn::IConnectableLayer* layer,
225 const char* name = nullptr) override;
227 void VisitSwitchLayer(const armnn::IConnectableLayer* layer,
228 const char* name = nullptr) override;
230 void VisitTransposeConvolution2dLayer(const armnn::IConnectableLayer* layer,
231 const armnn::TransposeConvolution2dDescriptor& descriptor,
232 const armnn::ConstTensor& weights,
233 const armnn::Optional<armnn::ConstTensor>& biases,
234 const char* = nullptr) override;
238 /// Creates the Input Slots and Output Slots and LayerBase for the layer.
239 flatbuffers::Offset<armnnSerializer::LayerBase> CreateLayerBase(
240 const armnn::IConnectableLayer* layer,
241 const armnnSerializer::LayerType layerType);
243 /// Creates the serializer AnyLayer for the layer and adds it to m_serializedLayers.
244 void CreateAnyLayer(const flatbuffers::Offset<void>& layer, const armnnSerializer::Layer serializerLayer);
246 /// Creates the serializer ConstTensor for the armnn ConstTensor.
247 flatbuffers::Offset<armnnSerializer::ConstTensor> CreateConstTensorInfo(
248 const armnn::ConstTensor& constTensor);
250 template <typename T>
251 flatbuffers::Offset<flatbuffers::Vector<T>> CreateDataVector(const void* memory, unsigned int size);
253 ///Function which maps Guid to an index
254 uint32_t GetSerializedId(unsigned int guid);
256 /// Creates the serializer InputSlots for the layer.
257 std::vector<flatbuffers::Offset<armnnSerializer::InputSlot>> CreateInputSlots(
258 const armnn::IConnectableLayer* layer);
260 /// Creates the serializer OutputSlots for the layer.
261 std::vector<flatbuffers::Offset<armnnSerializer::OutputSlot>> CreateOutputSlots(
262 const armnn::IConnectableLayer* layer);
264 /// FlatBufferBuilder to create our layers' FlatBuffers.
265 flatbuffers::FlatBufferBuilder m_flatBufferBuilder;
267 /// AnyLayers required by the SerializedGraph.
268 std::vector<flatbuffers::Offset<armnnSerializer::AnyLayer>> m_serializedLayers;
270 /// Guids of all Input Layers required by the SerializedGraph.
271 std::vector<unsigned int> m_inputIds;
273 /// Guids of all Output Layers required by the SerializedGraph.
274 std::vector<unsigned int> m_outputIds;
276 /// Mapped Guids of all Layers to match our index.
277 std::unordered_map<unsigned int, uint32_t > m_guidMap;
279 /// layer within our FlatBuffer index.
283 class Serializer : public ISerializer
289 /// Serializes the network to ArmNN SerializedGraph.
290 /// @param [in] inNetwork The network to be serialized.
291 void Serialize(const armnn::INetwork& inNetwork) override;
293 /// Serializes the SerializedGraph to the stream.
294 /// @param [stream] the stream to save to
295 /// @return true if graph is Serialized to the Stream, false otherwise
296 bool SaveSerializedToStream(std::ostream& stream) override;
300 /// Visitor to contruct serialized network
301 SerializerVisitor m_SerializerVisitor;
304 } //namespace armnnSerializer