2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
9 #include <boost/log/trivial.hpp>
11 #include "armnn/ArmNN.hpp"
12 #include "armnn/Utils.hpp"
13 #include "armnn/INetwork.hpp"
14 #include "armnnCaffeParser/ICaffeParser.hpp"
15 #include "../Cifar10Database.hpp"
16 #include "../InferenceTest.hpp"
17 #include "../InferenceModel.hpp"
20 using namespace std::chrono;
21 using namespace armnn::test;
23 int main(int argc, char* argv[])
26 armnn::LogSeverity level = armnn::LogSeverity::Info;
28 armnn::LogSeverity level = armnn::LogSeverity::Debug;
33 // Configures logging for both the ARMNN library and this test program.
34 armnn::ConfigureLogging(true, true, level);
35 armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
37 namespace po = boost::program_options;
39 std::vector<armnn::Compute> computeDevice;
43 po::options_description desc("Options");
46 // Adds generic options needed for all inference tests.
48 ("help", "Display help messages")
49 ("model-dir,m", po::value<std::string>(&modelDir)->required(),
50 "Path to directory containing the Cifar10 model file")
51 ("compute,c", po::value<std::vector<armnn::Compute>>(&computeDevice)->default_value
52 ({armnn::Compute::CpuAcc, armnn::Compute::CpuRef}),
53 "Which device to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc")
54 ("data-dir,d", po::value<std::string>(&dataDir)->required(),
55 "Path to directory containing the Cifar10 test data");
57 catch (const std::exception& e)
59 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
60 // and that desc.add_options() can throw boost::io::too_few_args.
61 // They really won't in any of these cases.
62 BOOST_ASSERT_MSG(false, "Caught unexpected exception");
63 std::cerr << "Fatal internal error: " << e.what() << std::endl;
71 po::store(po::parse_command_line(argc, argv, desc), vm);
75 std::cout << desc << std::endl;
83 std::cerr << e.what() << std::endl << std::endl;
84 std::cerr << desc << std::endl;
88 if (!ValidateDirectory(modelDir))
92 string modelPath = modelDir + "cifar10_full_iter_60000.caffemodel";
95 armnn::IRuntime::CreationOptions options;
96 armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
102 Net(armnn::NetworkId netId,
103 const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& in,
104 const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& out)
106 , m_InputBindingInfo(in)
107 , m_OutputBindingInfo(out)
110 armnn::NetworkId m_Network;
111 std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_InputBindingInfo;
112 std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_OutputBindingInfo;
114 std::vector<Net> networks;
116 armnnCaffeParser::ICaffeParserPtr parser(armnnCaffeParser::ICaffeParser::Create());
118 const int networksCount = 4;
119 for (int i = 0; i < networksCount; ++i)
121 // Creates a network from a file on the disk.
122 armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile(modelPath.c_str(), {}, { "prob" });
124 // Optimizes the network.
125 armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
128 optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec());
130 catch (armnn::Exception& e)
132 std::stringstream message;
133 message << "armnn::Exception ("<<e.what()<<") caught from optimize.";
134 BOOST_LOG_TRIVIAL(fatal) << message.str();
138 // Loads the network into the runtime.
139 armnn::NetworkId networkId;
140 status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
141 if (status == armnn::Status::Failure)
143 BOOST_LOG_TRIVIAL(fatal) << "armnn::IRuntime: Failed to load network";
147 networks.emplace_back(networkId,
148 parser->GetNetworkInputBindingInfo("data"),
149 parser->GetNetworkOutputBindingInfo("prob"));
152 // Loads a test case and tests inference.
153 if (!ValidateDirectory(dataDir))
157 Cifar10Database cifar10(dataDir);
159 for (unsigned int i = 0; i < 3; ++i)
161 // Loads test case data (including image data).
162 std::unique_ptr<Cifar10Database::TTestCaseData> testCaseData = cifar10.GetTestCaseData(i);
165 std::vector<std::array<float, 10>> outputs(networksCount);
167 for (unsigned int k = 0; k < networksCount; ++k)
169 status = runtime->EnqueueWorkload(networks[k].m_Network,
170 MakeInputTensors(networks[k].m_InputBindingInfo, testCaseData->m_InputImage),
171 MakeOutputTensors(networks[k].m_OutputBindingInfo, outputs[k]));
172 if (status == armnn::Status::Failure)
174 BOOST_LOG_TRIVIAL(fatal) << "armnn::IRuntime: Failed to enqueue workload";
180 for (unsigned int k = 1; k < networksCount; ++k)
182 if (!std::equal(outputs[0].begin(), outputs[0].end(), outputs[k].begin(), outputs[k].end()))
184 BOOST_LOG_TRIVIAL(error) << "Multiple networks inference failed!";
190 BOOST_LOG_TRIVIAL(info) << "Multiple networks inference ran successfully!";
193 catch (armnn::Exception const& e)
195 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
196 // exception of type std::length_error.
197 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
198 std::cerr << "Armnn Error: " << e.what() << std::endl;
201 catch (const std::exception& e)
203 // Coverity fix: various boost exceptions can be thrown by methods called by this test.
204 std::cerr << "WARNING: MultipleNetworksCifar10: An error has occurred when running the "
205 "multiple networks inference tests: " << e.what() << std::endl;