Release 18.02
[platform/upstream/armnn.git] / src / armnnUtils / ParserPrototxtFixture.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #pragma once
7
8 #include "armnn/IRuntime.hpp"
9 #include "test/TensorHelpers.hpp"
10 #include <string>
11
12 template<typename TParser>
13 struct ParserPrototxtFixture
14 {
15     ParserPrototxtFixture()
16         : m_Parser(TParser::Create())
17         , m_Runtime(armnn::IRuntime::Create(armnn::Compute::CpuRef))
18         , m_NetworkIdentifier(-1)
19     {}
20
21     /// Parses and loads the network defined by the m_Prototext string.
22     /// @{
23     void SetupSingleInputSingleOutput(const std::string& inputName, const std::string& outputName);
24     void SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape,
25         const std::string& inputName,
26         const std::string& outputName);
27     void Setup(const std::map<std::string, armnn::TensorShape>& inputShapes,
28         const std::vector<std::string>& requestedOutputs);
29     /// @}
30
31     /// Executes the network with the given input tensor and checks the result against the given output tensor.
32     /// This overload assumes the network has a single input and a single output.
33     template <std::size_t NumOutputDimensions>
34     void RunTest(const std::vector<float>& inputData, const std::vector<float>& expectedOutputData);
35
36     /// Executes the network with the given input tensors and checks the results against the given output tensors.
37     /// This overload supports multiple inputs and multiple outputs, identified by name.
38     template <std::size_t NumOutputDimensions>
39     void RunTest(const std::map<std::string, std::vector<float>>& inputData,
40         const std::map<std::string, std::vector<float>>& expectedOutputData);
41
42     std::string                 m_Prototext;
43     std::unique_ptr<TParser, void(*)(TParser* parser)> m_Parser;
44     armnn::IRuntimePtr          m_Runtime;
45     armnn::NetworkId            m_NetworkIdentifier;
46
47     /// If the single-input-single-output overload of Setup() is called, these will store the input and output name
48     /// so they don't need to be passed to the single-input-single-output overload of RunTest().
49     /// @{
50     std::string m_SingleInputName;
51     std::string m_SingleOutputName;
52     /// @}
53 };
54
55 template<typename TParser>
56 void ParserPrototxtFixture<TParser>::SetupSingleInputSingleOutput(const std::string& inputName,
57     const std::string& outputName)
58 {
59     // Store the input and output name so they don't need to be passed to the single-input-single-output RunTest().
60     m_SingleInputName = inputName;
61     m_SingleOutputName = outputName;
62     Setup({ }, { outputName });
63 }
64
65 template<typename TParser>
66 void ParserPrototxtFixture<TParser>::SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape,
67     const std::string& inputName,
68     const std::string& outputName)
69 {
70     // Store the input and output name so they don't need to be passed to the single-input-single-output RunTest().
71     m_SingleInputName = inputName;
72     m_SingleOutputName = outputName;
73     Setup({ { inputName, inputTensorShape } }, { outputName });
74 }
75
76 template<typename TParser>
77 void ParserPrototxtFixture<TParser>::Setup(const std::map<std::string, armnn::TensorShape>& inputShapes,
78     const std::vector<std::string>& requestedOutputs)
79 {
80     armnn::INetworkPtr network =
81         m_Parser->CreateNetworkFromString(m_Prototext.c_str(), inputShapes, requestedOutputs);
82
83     auto optimized = Optimize(*network, m_Runtime->GetDeviceSpec());
84     armnn::Status ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, move(optimized));
85     if (ret != armnn::Status::Success)
86     {
87         throw armnn::Exception("LoadNetwork failed");
88     }
89 }
90
91 template<typename TParser>
92 template <std::size_t NumOutputDimensions>
93 void ParserPrototxtFixture<TParser>::RunTest(const std::vector<float>& inputData,
94     const std::vector<float>& expectedOutputData)
95 {
96     RunTest<NumOutputDimensions>({ { m_SingleInputName, inputData } }, { { m_SingleOutputName, expectedOutputData } });
97 }
98
99 template<typename TParser>
100 template <std::size_t NumOutputDimensions>
101 void ParserPrototxtFixture<TParser>::RunTest(const std::map<std::string, std::vector<float>>& inputData,
102     const std::map<std::string, std::vector<float>>& expectedOutputData)
103 {
104     using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
105
106     // Setup the armnn input tensors from the given vectors.
107     armnn::InputTensors inputTensors;
108     for (auto&& it : inputData)
109     {
110         BindingPointInfo bindingInfo = m_Parser->GetNetworkInputBindingInfo(it.first);
111         inputTensors.push_back({ bindingInfo.first, armnn::ConstTensor(bindingInfo.second, it.second.data()) });
112     }
113
114     // Allocate storage for the output tensors to be written to and setup the armnn output tensors.
115     std::map<std::string, boost::multi_array<float, NumOutputDimensions>> outputStorage;
116     armnn::OutputTensors outputTensors;
117     for (auto&& it : expectedOutputData)
118     {
119         BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(it.first);
120         outputStorage.emplace(it.first, MakeTensor<float, NumOutputDimensions>(bindingInfo.second));
121         outputTensors.push_back(
122             { bindingInfo.first, armnn::Tensor(bindingInfo.second, outputStorage.at(it.first).data()) });
123     }
124
125     m_Runtime->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors);
126
127     // Compare each output tensor to the expected values
128     for (auto&& it : expectedOutputData)
129     {
130         BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(it.first);
131         auto outputExpected = MakeTensor<float, NumOutputDimensions>(bindingInfo.second, it.second);
132         BOOST_TEST(CompareTensors(outputExpected, outputStorage[it.first]));
133     }
134 }