Release 18.03
[platform/upstream/armnn.git] / include / armnnTfParser / ITfParser.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include "armnn/Types.hpp"
8 #include "armnn/Tensor.hpp"
9 #include "armnn/INetwork.hpp"
10
11 #include <map>
12 #include <memory>
13 #include <unordered_map>
14 #include <vector>
15
16 namespace armnnTfParser
17 {
18
19 using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
20
21 class ITfParser;
22 using ITfParserPtr = std::unique_ptr<ITfParser, void(*)(ITfParser* parser)>;
23
24 /// parses a directed acyclic graph from a tensorflow protobuf file
25 class ITfParser
26 {
27 public:
28     static ITfParser* CreateRaw();
29     static ITfParserPtr Create();
30     static void Destroy(ITfParser* parser);
31
32     /// Create the network from a protobuf text file on disk
33     virtual armnn::INetworkPtr CreateNetworkFromTextFile(
34         const char* graphFile,
35         const std::map<std::string, armnn::TensorShape>& inputShapes,
36         const std::vector<std::string>& requestedOutputs) = 0;
37
38     /// Create the network from a protobuf binary file on disk
39     virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(
40         const char* graphFile,
41         const std::map<std::string, armnn::TensorShape>& inputShapes,
42         const std::vector<std::string>& requestedOutputs) = 0;
43
44     /// Create the network directly from protobuf text in a string. Useful for debugging/testing
45     virtual armnn::INetworkPtr CreateNetworkFromString(
46         const char* protoText,
47         const std::map<std::string, armnn::TensorShape>& inputShapes,
48         const std::vector<std::string>& requestedOutputs) = 0;
49
50     /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name
51     virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const = 0;
52
53     /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name
54     virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const = 0;
55
56 protected:
57     virtual ~ITfParser() {};
58 };
59
60 }