2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
7 #include "armnn/Types.hpp"
8 #include "armnn/Tensor.hpp"
9 #include "armnn/INetwork.hpp"
13 #include <unordered_map>
16 namespace armnnTfParser
19 using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
22 using ITfParserPtr = std::unique_ptr<ITfParser, void(*)(ITfParser* parser)>;
/// Parses a directed acyclic graph from a TensorFlow protobuf file.
28 static ITfParser* CreateRaw();
29 static ITfParserPtr Create();
30 static void Destroy(ITfParser* parser);
32 /// Create the network from a protobuf text file on disk
33 virtual armnn::INetworkPtr CreateNetworkFromTextFile(
34 const char* graphFile,
35 const std::map<std::string, armnn::TensorShape>& inputShapes,
36 const std::vector<std::string>& requestedOutputs) = 0;
38 /// Create the network from a protobuf binary file on disk
39 virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(
40 const char* graphFile,
41 const std::map<std::string, armnn::TensorShape>& inputShapes,
42 const std::vector<std::string>& requestedOutputs) = 0;
44 /// Create the network directly from protobuf text in a string. Useful for debugging/testing
45 virtual armnn::INetworkPtr CreateNetworkFromString(
46 const char* protoText,
47 const std::map<std::string, armnn::TensorShape>& inputShapes,
48 const std::vector<std::string>& requestedOutputs) = 0;
50 /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name
51 virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const = 0;
53 /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name
54 virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const = 0;
57 virtual ~ITfParser() {};