Release 18.08
[platform/upstream/armnn.git] / tests / MultipleNetworksCifar10 / MultipleNetworksCifar10.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #include <iostream>
6 #include <chrono>
7 #include <vector>
8 #include <array>
9 #include <boost/log/trivial.hpp>
10
11 #include "armnn/ArmNN.hpp"
12 #include "armnn/Utils.hpp"
13 #include "armnn/INetwork.hpp"
14 #include "armnnCaffeParser/ICaffeParser.hpp"
15 #include "../Cifar10Database.hpp"
16 #include "../InferenceTest.hpp"
17 #include "../InferenceModel.hpp"
18
19 using namespace std;
20 using namespace std::chrono;
21 using namespace armnn::test;
22
23 int main(int argc, char* argv[])
24 {
25 #ifdef NDEBUG
26     armnn::LogSeverity level = armnn::LogSeverity::Info;
27 #else
28     armnn::LogSeverity level = armnn::LogSeverity::Debug;
29 #endif
30
31     try
32     {
33         // Configures logging for both the ARMNN library and this test program.
34         armnn::ConfigureLogging(true, true, level);
35         armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
36
37         namespace po = boost::program_options;
38
39         std::vector<armnn::Compute> computeDevice;
40         std::string modelDir;
41         std::string dataDir;
42
43         po::options_description desc("Options");
44         try
45         {
46             // Adds generic options needed for all inference tests.
47             desc.add_options()
48                 ("help", "Display help messages")
49                 ("model-dir,m", po::value<std::string>(&modelDir)->required(),
50                     "Path to directory containing the Cifar10 model file")
51                 ("compute,c", po::value<std::vector<armnn::Compute>>(&computeDevice)->default_value
52                      ({armnn::Compute::CpuAcc, armnn::Compute::CpuRef}),
53                     "Which device to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc")
54                 ("data-dir,d", po::value<std::string>(&dataDir)->required(),
55                     "Path to directory containing the Cifar10 test data");
56         }
57         catch (const std::exception& e)
58         {
59             // Coverity points out that default_value(...) can throw a bad_lexical_cast,
60             // and that desc.add_options() can throw boost::io::too_few_args.
61             // They really won't in any of these cases.
62             BOOST_ASSERT_MSG(false, "Caught unexpected exception");
63             std::cerr << "Fatal internal error: " << e.what() << std::endl;
64             return 1;
65         }
66
67         po::variables_map vm;
68
69         try
70         {
71             po::store(po::parse_command_line(argc, argv, desc), vm);
72
73             if (vm.count("help"))
74             {
75                 std::cout << desc << std::endl;
76                 return 1;
77             }
78
79             po::notify(vm);
80         }
81         catch (po::error& e)
82         {
83             std::cerr << e.what() << std::endl << std::endl;
84             std::cerr << desc << std::endl;
85             return 1;
86         }
87
88         if (!ValidateDirectory(modelDir))
89         {
90             return 1;
91         }
92         string modelPath = modelDir + "cifar10_full_iter_60000.caffemodel";
93
94         // Create runtime
95         armnn::IRuntime::CreationOptions options;
96         armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
97
98         // Loads networks.
99         armnn::Status status;
100         struct Net
101         {
102             Net(armnn::NetworkId netId,
103                 const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& in,
104                 const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& out)
105             : m_Network(netId)
106             , m_InputBindingInfo(in)
107             , m_OutputBindingInfo(out)
108             {}
109
110             armnn::NetworkId m_Network;
111             std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_InputBindingInfo;
112             std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_OutputBindingInfo;
113         };
114         std::vector<Net> networks;
115
116         armnnCaffeParser::ICaffeParserPtr parser(armnnCaffeParser::ICaffeParser::Create());
117
118         const int networksCount = 4;
119         for (int i = 0; i < networksCount; ++i)
120         {
121             // Creates a network from a file on the disk.
122             armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile(modelPath.c_str(), {}, { "prob" });
123
124             // Optimizes the network.
125             armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
126             try
127             {
128                 optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec());
129             }
130             catch (armnn::Exception& e)
131             {
132                 std::stringstream message;
133                 message << "armnn::Exception ("<<e.what()<<") caught from optimize.";
134                 BOOST_LOG_TRIVIAL(fatal) << message.str();
135                 return 1;
136             }
137
138             // Loads the network into the runtime.
139             armnn::NetworkId networkId;
140             status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
141             if (status == armnn::Status::Failure)
142             {
143                 BOOST_LOG_TRIVIAL(fatal) << "armnn::IRuntime: Failed to load network";
144                 return 1;
145             }
146
147             networks.emplace_back(networkId,
148                 parser->GetNetworkInputBindingInfo("data"),
149                 parser->GetNetworkOutputBindingInfo("prob"));
150         }
151
152         // Loads a test case and tests inference.
153         if (!ValidateDirectory(dataDir))
154         {
155             return 1;
156         }
157         Cifar10Database cifar10(dataDir);
158
159         for (unsigned int i = 0; i < 3; ++i)
160         {
161             // Loads test case data (including image data).
162             std::unique_ptr<Cifar10Database::TTestCaseData> testCaseData = cifar10.GetTestCaseData(i);
163
164             // Tests inference.
165             std::vector<std::array<float, 10>> outputs(networksCount);
166
167             for (unsigned int k = 0; k < networksCount; ++k)
168             {
169                 status = runtime->EnqueueWorkload(networks[k].m_Network,
170                     MakeInputTensors(networks[k].m_InputBindingInfo, testCaseData->m_InputImage),
171                     MakeOutputTensors(networks[k].m_OutputBindingInfo, outputs[k]));
172                 if (status == armnn::Status::Failure)
173                 {
174                     BOOST_LOG_TRIVIAL(fatal) << "armnn::IRuntime: Failed to enqueue workload";
175                     return 1;
176                 }
177             }
178
179             // Compares outputs.
180             for (unsigned int k = 1; k < networksCount; ++k)
181             {
182                 if (!std::equal(outputs[0].begin(), outputs[0].end(), outputs[k].begin(), outputs[k].end()))
183                 {
184                     BOOST_LOG_TRIVIAL(error) << "Multiple networks inference failed!";
185                     return 1;
186                 }
187             }
188         }
189
190         BOOST_LOG_TRIVIAL(info) << "Multiple networks inference ran successfully!";
191         return 0;
192     }
193     catch (armnn::Exception const& e)
194     {
195         // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
196         // exception of type std::length_error.
197         // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
198         std::cerr << "Armnn Error: " << e.what() << std::endl;
199         return 1;
200     }
201     catch (const std::exception& e)
202     {
203         // Coverity fix: various boost exceptions can be thrown by methods called by this test. 
204         std::cerr << "WARNING: MultipleNetworksCifar10: An error has occurred when running the "
205                      "multiple networks inference tests: " << e.what() << std::endl;
206         return 1;
207     }
208 }