arm_compute v18.02
[platform/upstream/armcl.git] / examples / graph_alexnet.cpp
index 8705c8e..a396c76 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 20172018 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -38,7 +38,7 @@ using namespace arm_compute::graph_utils;
 /** Example demonstrating how to implement AlexNet's network using the Compute Library's graph API
  *
  * @param[in] argc Number of arguments
- * @param[in] argv Arguments ( [optional] Target (0 = NEON, 1 = OpenCL), [optional] Path to the weights folder, [optional] image, [optional] labels )
+ * @param[in] argv Arguments ( [optional] Target (0 = NEON, 1 = OpenCL, 2 = OpenCL with Tuner), [optional] Path to the weights folder, [optional] image, [optional] labels )
  */
 class GraphAlexnetExample : public Example
 {
@@ -49,13 +49,16 @@ public:
         std::string image;     /* Image data */
         std::string label;     /* Label data */
 
-        constexpr float mean_r = 122.68f; /* Mean value to subtract from red channel */
-        constexpr float mean_g = 116.67f; /* Mean value to subtract from green channel */
-        constexpr float mean_b = 104.01f; /* Mean value to subtract from blue channel */
+        // Create a preprocessor object
+        const std::array<float, 3> mean_rgb{ { 122.68f, 116.67f, 104.01f } };
+        std::unique_ptr<IPreprocessor> preprocessor = arm_compute::support::cpp14::make_unique<CaffePreproccessor>(mean_rgb);
+
+        // Set target. 0 (NEON), 1 (OpenCL), 2 (OpenCL with Tuner). By default it is NEON
+        const int  int_target_hint = argc > 1 ? std::strtol(argv[1], nullptr, 10) : 0;
+        TargetHint target_hint     = set_target_hint(int_target_hint);
 
-        // Set target. 0 (NEON), 1 (OpenCL). By default it is NEON
-        TargetHint            target_hint      = set_target_hint(argc > 1 ? std::strtol(argv[1], nullptr, 10) : 0);
-        ConvolutionMethodHint convolution_hint = target_hint == TargetHint::NEON ? ConvolutionMethodHint::GEMM : ConvolutionMethodHint::DIRECT;
+        const bool            is_gemm_convolution5x5 = Graph::gpu_target() == arm_compute::GPUTarget::MIDGARD || target_hint == TargetHint::NEON;
+        ConvolutionMethodHint convolution_5x5_hint   = is_gemm_convolution5x5 ? ConvolutionMethodHint::GEMM : ConvolutionMethodHint::DIRECT;
 
         // Parse arguments
         if(argc < 2)
@@ -91,7 +94,7 @@ public:
 
         graph << target_hint
               << Tensor(TensorInfo(TensorShape(227U, 227U, 3U, 1U), 1, DataType::F32),
-                        get_input_accessor(image, mean_r, mean_g, mean_b))
+                        get_input_accessor(image, std::move(preprocessor)))
               // Layer 1
               << ConvolutionLayer(
                   11U, 11U, 96U,
@@ -102,7 +105,7 @@ public:
               << NormalizationLayer(NormalizationLayerInfo(NormType::CROSS_MAP, 5, 0.0001f, 0.75f))
               << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 3, PadStrideInfo(2, 2, 0, 0)))
               // Layer 2
-              << convolution_hint
+              << convolution_5x5_hint
               << ConvolutionLayer(
                   5U, 5U, 256U,
                   get_weights_accessor(data_path, "/cnn_data/alexnet_model/conv2_w.npy"),
@@ -111,6 +114,7 @@ public:
               << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
               << NormalizationLayer(NormalizationLayerInfo(NormType::CROSS_MAP, 5, 0.0001f, 0.75f))
               << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 3, PadStrideInfo(2, 2, 0, 0)))
+              << ConvolutionMethodHint::GEMM
               // Layer 3
               << ConvolutionLayer(
                   3U, 3U, 384U,
@@ -153,6 +157,9 @@ public:
               // Softmax
               << SoftmaxLayer()
               << Tensor(get_output_accessor(label, 5));
+
+        // In order to enable the OpenCL tuner, graph_init() has to be called only when all nodes have been instantiated
+        graph.graph_init(int_target_hint == 2);
     }
     void do_run() override
     {