Release 18.08
[platform/upstream/armnn.git] / tests / MultipleNetworksCifar10 / MultipleNetworksCifar10.cpp
index 37138f4..ca6ff45 100644 (file)
@@ -30,25 +30,26 @@ int main(int argc, char* argv[])
 
     try
     {
-        // Configure logging for both the ARMNN library and this test program
+        // Configures logging for both the ARMNN library and this test program.
         armnn::ConfigureLogging(true, true, level);
         armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
 
         namespace po = boost::program_options;
 
-        armnn::Compute computeDevice;
+        std::vector<armnn::Compute> computeDevice;
         std::string modelDir;
         std::string dataDir;
 
         po::options_description desc("Options");
         try
         {
-            // Add generic options needed for all inference tests
+            // Adds generic options needed for all inference tests.
             desc.add_options()
                 ("help", "Display help messages")
                 ("model-dir,m", po::value<std::string>(&modelDir)->required(),
                     "Path to directory containing the Cifar10 model file")
-                ("compute,c", po::value<armnn::Compute>(&computeDevice)->default_value(armnn::Compute::CpuAcc),
+                ("compute,c", po::value<std::vector<armnn::Compute>>(&computeDevice)->default_value
+                     ({armnn::Compute::CpuAcc, armnn::Compute::CpuRef}),
                     "Which device to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc")
                 ("data-dir,d", po::value<std::string>(&dataDir)->required(),
                     "Path to directory containing the Cifar10 test data");
@@ -91,9 +92,10 @@ int main(int argc, char* argv[])
         string modelPath = modelDir + "cifar10_full_iter_60000.caffemodel";
 
         // Create runtime
-        armnn::IRuntimePtr runtime(armnn::IRuntime::Create(computeDevice));
+        armnn::IRuntime::CreationOptions options;
+        armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
 
-        // Load networks
+        // Loads networks.
         armnn::Status status;
         struct Net
         {
@@ -116,14 +118,14 @@ int main(int argc, char* argv[])
         const int networksCount = 4;
         for (int i = 0; i < networksCount; ++i)
         {
-            // Create a network from a file on disk
+            // Creates a network from a file on the disk.
             armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile(modelPath.c_str(), {}, { "prob" });
 
-            // optimize the network
+            // Optimizes the network.
             armnn::IOptimizedNetworkPtr optimizedNet(nullptr, nullptr);
             try
             {
-                optimizedNet = armnn::Optimize(*network, runtime->GetDeviceSpec());
+                optimizedNet = armnn::Optimize(*network, computeDevice, runtime->GetDeviceSpec());
             }
             catch (armnn::Exception& e)
             {
@@ -133,7 +135,7 @@ int main(int argc, char* argv[])
                 return 1;
             }
 
-            // Load the network into the runtime
+            // Loads the network into the runtime.
             armnn::NetworkId networkId;
             status = runtime->LoadNetwork(networkId, std::move(optimizedNet));
             if (status == armnn::Status::Failure)
@@ -147,7 +149,7 @@ int main(int argc, char* argv[])
                 parser->GetNetworkOutputBindingInfo("prob"));
         }
 
-        // Load a test case and test inference
+        // Loads a test case and tests inference.
         if (!ValidateDirectory(dataDir))
         {
             return 1;
@@ -156,10 +158,10 @@ int main(int argc, char* argv[])
 
         for (unsigned int i = 0; i < 3; ++i)
         {
-            // Load test case data (including image data)
+            // Loads test case data (including image data).
             std::unique_ptr<Cifar10Database::TTestCaseData> testCaseData = cifar10.GetTestCaseData(i);
 
-            // Test inference
+            // Tests inference.
             std::vector<std::array<float, 10>> outputs(networksCount);
 
             for (unsigned int k = 0; k < networksCount; ++k)
@@ -174,7 +176,7 @@ int main(int argc, char* argv[])
                 }
             }
 
-            // Compare outputs
+            // Compares outputs.
             for (unsigned int k = 1; k < networksCount; ++k)
             {
                 if (!std::equal(outputs[0].begin(), outputs[0].end(), outputs[k].begin(), outputs[k].end()))