2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include <armnn/ILayerVisitor.hpp>
8 #include <armnn/LayerVisitorBase.hpp>
10 #include <armnnSerializer/ISerializer.hpp>
12 #include <unordered_map>
14 #include <ArmnnSchema_generated.h>
16 namespace armnnSerializer
19 class SerializerVisitor : public armnn::ILayerVisitor
22 SerializerVisitor() : m_layerId(0) {}
23 ~SerializerVisitor() {}
25 flatbuffers::FlatBufferBuilder& GetFlatBufferBuilder()
27 return m_flatBufferBuilder;
30 std::vector<unsigned int>& GetInputIds()
35 std::vector<unsigned int>& GetOutputIds()
40 std::vector<flatbuffers::Offset<armnnSerializer::AnyLayer>>& GetSerializedLayers()
42 return m_serializedLayers;
45 void VisitAbsLayer(const armnn::IConnectableLayer* layer,
46 const char* name = nullptr) override;
48 void VisitActivationLayer(const armnn::IConnectableLayer* layer,
49 const armnn::ActivationDescriptor& descriptor,
50 const char* name = nullptr) override;
52 void VisitAdditionLayer(const armnn::IConnectableLayer* layer,
53 const char* name = nullptr) override;
55 void VisitBatchToSpaceNdLayer(const armnn::IConnectableLayer* layer,
56 const armnn::BatchToSpaceNdDescriptor& descriptor,
57 const char* name = nullptr) override;
59 void VisitBatchNormalizationLayer(const armnn::IConnectableLayer* layer,
60 const armnn::BatchNormalizationDescriptor& BatchNormalizationDescriptor,
61 const armnn::ConstTensor& mean,
62 const armnn::ConstTensor& variance,
63 const armnn::ConstTensor& beta,
64 const armnn::ConstTensor& gamma,
65 const char* name = nullptr) override;
67 void VisitConcatLayer(const armnn::IConnectableLayer* layer,
68 const armnn::ConcatDescriptor& concatDescriptor,
69 const char* name = nullptr) override;
71 void VisitConstantLayer(const armnn::IConnectableLayer* layer,
72 const armnn::ConstTensor& input,
73 const char* = nullptr) override;
75 void VisitConvolution2dLayer(const armnn::IConnectableLayer* layer,
76 const armnn::Convolution2dDescriptor& descriptor,
77 const armnn::ConstTensor& weights,
78 const armnn::Optional<armnn::ConstTensor>& biases,
79 const char* = nullptr) override;
81 void VisitDepthwiseConvolution2dLayer(const armnn::IConnectableLayer* layer,
82 const armnn::DepthwiseConvolution2dDescriptor& descriptor,
83 const armnn::ConstTensor& weights,
84 const armnn::Optional<armnn::ConstTensor>& biases,
85 const char* name = nullptr) override;
87 void VisitDequantizeLayer(const armnn::IConnectableLayer* layer,
88 const char* name = nullptr) override;
90 void VisitDetectionPostProcessLayer(const armnn::IConnectableLayer* layer,
91 const armnn::DetectionPostProcessDescriptor& descriptor,
92 const armnn::ConstTensor& anchors,
93 const char* name = nullptr) override;
95 void VisitDivisionLayer(const armnn::IConnectableLayer* layer,
96 const char* name = nullptr) override;
98 void VisitEqualLayer(const armnn::IConnectableLayer* layer,
99 const char* name = nullptr) override;
101 void VisitFloorLayer(const armnn::IConnectableLayer *layer,
102 const char *name = nullptr) override;
104 void VisitFullyConnectedLayer(const armnn::IConnectableLayer* layer,
105 const armnn::FullyConnectedDescriptor& fullyConnectedDescriptor,
106 const armnn::ConstTensor& weights,
107 const armnn::Optional<armnn::ConstTensor>& biases,
108 const char* name = nullptr) override;
110 void VisitGatherLayer(const armnn::IConnectableLayer* layer,
111 const char* name = nullptr) override;
113 void VisitGreaterLayer(const armnn::IConnectableLayer* layer,
114 const char* name = nullptr) override;
116 void VisitInputLayer(const armnn::IConnectableLayer* layer,
117 armnn::LayerBindingId id,
118 const char* name = nullptr) override;
120 void VisitL2NormalizationLayer(const armnn::IConnectableLayer* layer,
121 const armnn::L2NormalizationDescriptor& l2NormalizationDescriptor,
122 const char* name = nullptr) override;
124 void VisitLstmLayer(const armnn::IConnectableLayer* layer,
125 const armnn::LstmDescriptor& descriptor,
126 const armnn::LstmInputParams& params,
127 const char* name = nullptr) override;
129 void VisitMeanLayer(const armnn::IConnectableLayer* layer,
130 const armnn::MeanDescriptor& descriptor,
131 const char* name) override;
133 void VisitMinimumLayer(const armnn::IConnectableLayer* layer,
134 const char* name = nullptr) override;
136 void VisitMaximumLayer(const armnn::IConnectableLayer* layer,
137 const char* name = nullptr) override;
139 void VisitMergeLayer(const armnn::IConnectableLayer* layer,
140 const char* name = nullptr) override;
142 ARMNN_DEPRECATED_MSG("Use VisitConcatLayer instead")
143 void VisitMergerLayer(const armnn::IConnectableLayer* layer,
144 const armnn::MergerDescriptor& mergerDescriptor,
145 const char* name = nullptr) override;
147 void VisitMultiplicationLayer(const armnn::IConnectableLayer* layer,
148 const char* name = nullptr) override;
150 void VisitOutputLayer(const armnn::IConnectableLayer* layer,
151 armnn::LayerBindingId id,
152 const char* name = nullptr) override;
154 void VisitPadLayer(const armnn::IConnectableLayer* layer,
155 const armnn::PadDescriptor& PadDescriptor,
156 const char* name = nullptr) override;
158 void VisitPermuteLayer(const armnn::IConnectableLayer* layer,
159 const armnn::PermuteDescriptor& PermuteDescriptor,
160 const char* name = nullptr) override;
162 void VisitPooling2dLayer(const armnn::IConnectableLayer* layer,
163 const armnn::Pooling2dDescriptor& pooling2dDescriptor,
164 const char* name = nullptr) override;
166 void VisitPreluLayer(const armnn::IConnectableLayer* layer,
167 const char* name = nullptr) override;
169 void VisitQuantizeLayer(const armnn::IConnectableLayer* layer,
170 const char* name = nullptr) override;
172 void VisitQuantizedLstmLayer(const armnn::IConnectableLayer* layer,
173 const armnn::QuantizedLstmInputParams& params,
174 const char* name = nullptr) override;
176 void VisitReshapeLayer(const armnn::IConnectableLayer* layer,
177 const armnn::ReshapeDescriptor& reshapeDescriptor,
178 const char* name = nullptr) override;
180 void VisitResizeLayer(const armnn::IConnectableLayer* layer,
181 const armnn::ResizeDescriptor& resizeDescriptor,
182 const char* name = nullptr) override;
184 ARMNN_DEPRECATED_MSG("Use VisitResizeLayer instead")
185 void VisitResizeBilinearLayer(const armnn::IConnectableLayer* layer,
186 const armnn::ResizeBilinearDescriptor& resizeDescriptor,
187 const char* name = nullptr) override;
189 void VisitRsqrtLayer(const armnn::IConnectableLayer* layer,
190 const char* name = nullptr) override;
192 void VisitSoftmaxLayer(const armnn::IConnectableLayer* layer,
193 const armnn::SoftmaxDescriptor& softmaxDescriptor,
194 const char* name = nullptr) override;
196 void VisitSpaceToBatchNdLayer(const armnn::IConnectableLayer* layer,
197 const armnn::SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
198 const char* name = nullptr) override;
200 void VisitSpaceToDepthLayer(const armnn::IConnectableLayer* layer,
201 const armnn::SpaceToDepthDescriptor& spaceToDepthDescriptor,
202 const char* name = nullptr) override;
204 void VisitNormalizationLayer(const armnn::IConnectableLayer* layer,
205 const armnn::NormalizationDescriptor& normalizationDescriptor,
206 const char* name = nullptr) override;
208 void VisitSplitterLayer(const armnn::IConnectableLayer* layer,
209 const armnn::ViewsDescriptor& viewsDescriptor,
210 const char* name = nullptr) override;
212 void VisitStackLayer(const armnn::IConnectableLayer* layer,
213 const armnn::StackDescriptor& stackDescriptor,
214 const char* name = nullptr) override;
216 void VisitStridedSliceLayer(const armnn::IConnectableLayer* layer,
217 const armnn::StridedSliceDescriptor& stridedSliceDescriptor,
218 const char* name = nullptr) override;
220 void VisitSubtractionLayer(const armnn::IConnectableLayer* layer,
221 const char* name = nullptr) override;
223 void VisitSwitchLayer(const armnn::IConnectableLayer* layer,
224 const char* name = nullptr) override;
226 void VisitTransposeConvolution2dLayer(const armnn::IConnectableLayer* layer,
227 const armnn::TransposeConvolution2dDescriptor& descriptor,
228 const armnn::ConstTensor& weights,
229 const armnn::Optional<armnn::ConstTensor>& biases,
230 const char* = nullptr) override;
234 /// Creates the Input Slots and Output Slots and LayerBase for the layer.
235 flatbuffers::Offset<armnnSerializer::LayerBase> CreateLayerBase(
236 const armnn::IConnectableLayer* layer,
237 const armnnSerializer::LayerType layerType);
239 /// Creates the serializer AnyLayer for the layer and adds it to m_serializedLayers.
240 void CreateAnyLayer(const flatbuffers::Offset<void>& layer, const armnnSerializer::Layer serializerLayer);
242 /// Creates the serializer ConstTensor for the armnn ConstTensor.
243 flatbuffers::Offset<armnnSerializer::ConstTensor> CreateConstTensorInfo(
244 const armnn::ConstTensor& constTensor);
246 template <typename T>
247 flatbuffers::Offset<flatbuffers::Vector<T>> CreateDataVector(const void* memory, unsigned int size);
249 ///Function which maps Guid to an index
250 uint32_t GetSerializedId(unsigned int guid);
252 /// Creates the serializer InputSlots for the layer.
253 std::vector<flatbuffers::Offset<armnnSerializer::InputSlot>> CreateInputSlots(
254 const armnn::IConnectableLayer* layer);
256 /// Creates the serializer OutputSlots for the layer.
257 std::vector<flatbuffers::Offset<armnnSerializer::OutputSlot>> CreateOutputSlots(
258 const armnn::IConnectableLayer* layer);
260 /// FlatBufferBuilder to create our layers' FlatBuffers.
261 flatbuffers::FlatBufferBuilder m_flatBufferBuilder;
263 /// AnyLayers required by the SerializedGraph.
264 std::vector<flatbuffers::Offset<armnnSerializer::AnyLayer>> m_serializedLayers;
266 /// Guids of all Input Layers required by the SerializedGraph.
267 std::vector<unsigned int> m_inputIds;
269 /// Guids of all Output Layers required by the SerializedGraph.
270 std::vector<unsigned int> m_outputIds;
272 /// Mapped Guids of all Layers to match our index.
273 std::unordered_map<unsigned int, uint32_t > m_guidMap;
275 /// layer within our FlatBuffer index.
279 class Serializer : public ISerializer
285 /// Serializes the network to ArmNN SerializedGraph.
286 /// @param [in] inNetwork The network to be serialized.
287 void Serialize(const armnn::INetwork& inNetwork) override;
289 /// Serializes the SerializedGraph to the stream.
290 /// @param [stream] the stream to save to
291 /// @return true if graph is Serialized to the Stream, false otherwise
292 bool SaveSerializedToStream(std::ostream& stream) override;
296 /// Visitor to contruct serialized network
297 SerializerVisitor m_SerializerVisitor;
300 } //namespace armnnSerializer