Release 18.08
[platform/upstream/armnn.git] / src / armnnTfLiteParser / TfLiteParser.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include "armnn/INetwork.hpp"
8 #include "armnnTfLiteParser/ITfLiteParser.hpp"
9
10 #include <schema_generated.h>
11 #include <functional>
12 #include <vector>
13
14 namespace armnnTfLiteParser
15 {
16
17 class TfLiteParser : public ITfLiteParser
18 {
19 public:
20     // Shorthands for TfLite types
21     using ModelPtr = std::unique_ptr<tflite::ModelT>;
22     using SubGraphPtr = std::unique_ptr<tflite::SubGraphT>;
23     using OperatorPtr = std::unique_ptr<tflite::OperatorT>;
24     using OperatorCodePtr = std::unique_ptr<tflite::OperatorCodeT>;
25     using TensorPtr = std::unique_ptr<tflite::TensorT>;
26     using TensorRawPtr = const tflite::TensorT *;
27     using TensorRawPtrVector = std::vector<TensorRawPtr>;
28     using TensorIdRawPtr = std::pair<size_t, TensorRawPtr>;
29     using TensorIdRawPtrVector = std::vector<TensorIdRawPtr>;
30     using BufferPtr = std::unique_ptr<tflite::BufferT>;
31     using BufferRawPtr = const tflite::BufferT *;
32
33 public:
34     /// Create the network from a flatbuffers binary file on disk
35     virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) override;
36
37     /// Create the network from a flatbuffers binary
38     virtual armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent) override;
39
40
41     /// Retrieve binding info (layer id and tensor info) for the network input identified by
42     /// the given layer name and subgraph id
43     virtual BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId,
44                                                         const std::string& name) const override;
45
46     /// Retrieve binding info (layer id and tensor info) for the network output identified by
47     /// the given layer name and subgraph id
48     virtual BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId,
49                                                          const std::string& name) const override;
50
51     /// Return the number of subgraphs in the parsed model
52     virtual size_t GetSubgraphCount() const override;
53
54     /// Return the input tensor names for a given subgraph
55     virtual std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const override;
56
57     /// Return the output tensor names for a given subgraph
58     virtual std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const override;
59
60     TfLiteParser();
61     virtual ~TfLiteParser() {}
62
63 public:
64     // testable helpers
65     static ModelPtr LoadModelFromFile(const char * fileName);
66     static ModelPtr LoadModelFromBinary(const uint8_t * binaryContent, size_t len);
67     static TensorRawPtrVector GetInputs(const ModelPtr & model, size_t subgraphIndex, size_t operatorIndex);
68     static TensorRawPtrVector GetOutputs(const ModelPtr & model, size_t subgraphIndex, size_t operatorIndex);
69     static TensorIdRawPtrVector GetSubgraphInputs(const ModelPtr & model, size_t subgraphIndex);
70     static TensorIdRawPtrVector GetSubgraphOutputs(const ModelPtr & model, size_t subgraphIndex);
71     static std::vector<int32_t>& GetInputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
72     static std::vector<int32_t>& GetOutputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
73
74     static BufferRawPtr GetBuffer(const ModelPtr& model, size_t bufferIndex);
75     static armnn::TensorInfo OutputShapeOfSqueeze(const std::vector<uint32_t> & squeezeDims,
76                                                   const armnn::TensorInfo & inputTensorInfo);
77
78
79 private:
80     // No copying allowed until it is wanted and properly implemented
81     TfLiteParser(const TfLiteParser &) = delete;
82     TfLiteParser & operator=(const TfLiteParser &) = delete;
83
84     /// Create the network from an already loaded flatbuffers model
85     armnn::INetworkPtr CreateNetworkFromModel();
86
87     // signature for the parser functions
88     using OperatorParsingFunction = void(TfLiteParser::*)(size_t subgraphIndex, size_t operatorIndex);
89
90     void ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex);
91     void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex);
92     void ParseConv2D(size_t subgraphIndex, size_t operatorIndex);
93     void ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex);
94     void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex);
95     void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex);
96
97     void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot);
98     void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot);
99     void RegisterInputSlots(size_t subgraphIndex,
100                             size_t operatorIndex,
101                             armnn::IConnectableLayer* layer,
102                             const std::vector<unsigned int>& tensorIndexes);
103     void RegisterOutputSlots(size_t subgraphIndex,
104                              size_t operatorIndex,
105                              armnn::IConnectableLayer* layer,
106                              const std::vector<unsigned int>& tensorIndexes);
107
108     void SetupInputLayers(size_t subgraphIndex);
109     void SetupOutputLayers(size_t subgraphIndex);
110
111     void ResetParser();
112
113     /// Attach an activation layer to the one passed as a parameter
114     armnn::IConnectableLayer* AddActivationLayer(armnn::IConnectableLayer* layer,
115                                                  unsigned int outputSlot,
116                                                  tflite::ActivationFunctionType activationType);
117
118     // SupportedDataStorage's purpose is to hold data till we pass over to the network.
119     // We don't care about the content, and we want a single datatype to simplify the code.
120     struct SupportedDataStorage
121     {
122         std::unique_ptr<float[]>    m_FloatData;
123         std::unique_ptr<uint8_t[]>  m_Uint8Data;
124         std::unique_ptr<int32_t[]>  m_Int32Data;
125
126         SupportedDataStorage(std::unique_ptr<float[]> && data);
127         SupportedDataStorage(std::unique_ptr<uint8_t[]> && data);
128         SupportedDataStorage(std::unique_ptr<int32_t[]> && data);
129     };
130
131     std::pair<armnn::ConstTensor, SupportedDataStorage> CreateConstTensor(TensorRawPtr tensorPtr,
132                                                                           armnn::TensorInfo & tensorInfo,
133                                                                           bool convertFromTfToArmnnFormat);
134
135     /// The network we're building. Gets cleared after it is passed to the user
136     armnn::INetworkPtr                    m_Network;
137     std::vector<OperatorParsingFunction>  m_ParserFunctions;
138     ModelPtr                              m_Model;
139
140     /// A mapping of an output slot to each of the input slots it should be connected to
141     /// The outputSlot is from the layer that creates this tensor as one of its ouputs
142     /// The inputSlots are from the layers that use this tensor as one of their inputs
143     struct TensorSlots
144     {
145         armnn::IOutputSlot* outputSlot;
146         std::vector<armnn::IInputSlot*> inputSlots;
147
148         TensorSlots() : outputSlot(nullptr) { }
149     };
150     typedef std::vector<TensorSlots> TensorConnections;
151     /// Connections for tensors in each subgraph
152     /// The first index is the subgraph ID, the second index is the tensor ID
153     std::vector<TensorConnections> m_SubgraphConnections;
154 };
155
156 }