arm_compute v17.12
[platform/upstream/armcl.git] / examples / graph_vgg19.cpp
1 /*
2  * Copyright (c) 2017 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/Nodes.h"
#include "support/ToolchainSupport.h"
#include "utils/GraphUtils.h"
#include "utils/Utils.h"

#include <cstdlib>
#include <iostream>
#include <string>

using namespace arm_compute::graph;
using namespace arm_compute::graph_utils;
34
35 /** Example demonstrating how to implement VGG19's network using the Compute Library's graph API
36  *
37  * @param[in] argc Number of arguments
38  * @param[in] argv Arguments ( [optional] Target (0 = NEON, 1 = OpenCL), [optional] Path to the weights folder, [optional] image, [optional] labels )
39  */
40 void main_graph_vgg19(int argc, const char **argv)
41 {
42     std::string data_path; /* Path to the trainable data */
43     std::string image;     /* Image data */
44     std::string label;     /* Label data */
45
46     constexpr float mean_r = 123.68f;  /* Mean value to subtract from red channel */
47     constexpr float mean_g = 116.779f; /* Mean value to subtract from green channel */
48     constexpr float mean_b = 103.939f; /* Mean value to subtract from blue channel */
49
50     // Set target. 0 (NEON), 1 (OpenCL). By default it is NEON
51     TargetHint            target_hint      = set_target_hint(argc > 1 ? std::strtol(argv[1], nullptr, 10) : 0);
52     ConvolutionMethodHint convolution_hint = ConvolutionMethodHint::DIRECT;
53
54     // Parse arguments
55     if(argc < 2)
56     {
57         // Print help
58         std::cout << "Usage: " << argv[0] << " [target] [path_to_data] [image] [labels]\n\n";
59         std::cout << "No data folder provided: using random values\n\n";
60     }
61     else if(argc == 2)
62     {
63         std::cout << "Usage: " << argv[0] << " " << argv[1] << " [path_to_data] [image] [labels]\n\n";
64         std::cout << "No data folder provided: using random values\n\n";
65     }
66     else if(argc == 3)
67     {
68         data_path = argv[2];
69         std::cout << "Usage: " << argv[0] << " " << argv[1] << " " << argv[2] << " [image] [labels]\n\n";
70         std::cout << "No image provided: using random values\n\n";
71     }
72     else if(argc == 4)
73     {
74         data_path = argv[2];
75         image     = argv[3];
76         std::cout << "Usage: " << argv[0] << " " << argv[1] << " " << argv[2] << " " << argv[3] << " [labels]\n\n";
77         std::cout << "No text file with labels provided: skipping output accessor\n\n";
78     }
79     else
80     {
81         data_path = argv[2];
82         image     = argv[3];
83         label     = argv[4];
84     }
85
86     Graph graph;
87
88     graph << target_hint
89           << convolution_hint
90           << Tensor(TensorInfo(TensorShape(224U, 224U, 3U, 1U), 1, DataType::F32),
91                     get_input_accessor(image, mean_r, mean_g, mean_b))
92           // Layer 1
93           << ConvolutionLayer(
94               3U, 3U, 64U,
95               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv1_1_w.npy"),
96               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv1_1_b.npy"),
97               PadStrideInfo(1, 1, 1, 1))
98           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
99           << ConvolutionLayer(
100               3U, 3U, 64U,
101               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv1_2_w.npy"),
102               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv1_2_b.npy"),
103               PadStrideInfo(1, 1, 1, 1))
104           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
105           << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 2, PadStrideInfo(2, 2, 0, 0)))
106           // Layer 2
107           << ConvolutionLayer(
108               3U, 3U, 128U,
109               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv2_1_w.npy"),
110               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv2_1_b.npy"),
111               PadStrideInfo(1, 1, 1, 1))
112           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
113           << ConvolutionLayer(
114               3U, 3U, 128U,
115               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv2_2_w.npy"),
116               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv2_2_b.npy"),
117               PadStrideInfo(1, 1, 1, 1))
118           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
119           << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 2, PadStrideInfo(2, 2, 0, 0)))
120           // Layer 3
121           << ConvolutionLayer(
122               3U, 3U, 256U,
123               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv3_1_w.npy"),
124               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv3_1_b.npy"),
125               PadStrideInfo(1, 1, 1, 1))
126           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
127           << ConvolutionLayer(
128               3U, 3U, 256U,
129               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv3_2_w.npy"),
130               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv3_2_b.npy"),
131               PadStrideInfo(1, 1, 1, 1))
132           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
133           << ConvolutionLayer(
134               3U, 3U, 256U,
135               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv3_3_w.npy"),
136               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv3_3_b.npy"),
137               PadStrideInfo(1, 1, 1, 1))
138           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
139           << ConvolutionLayer(
140               3U, 3U, 256U,
141               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv3_4_w.npy"),
142               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv3_4_b.npy"),
143               PadStrideInfo(1, 1, 1, 1))
144           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
145           << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 2, PadStrideInfo(2, 2, 0, 0)))
146           // Layer 4
147           << ConvolutionLayer(
148               3U, 3U, 512U,
149               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv4_1_w.npy"),
150               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv4_1_b.npy"),
151               PadStrideInfo(1, 1, 1, 1))
152           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
153           << ConvolutionLayer(
154               3U, 3U, 512U,
155               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv4_2_w.npy"),
156               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv4_2_b.npy"),
157               PadStrideInfo(1, 1, 1, 1))
158           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
159           << ConvolutionLayer(
160               3U, 3U, 512U,
161               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv4_3_w.npy"),
162               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv4_3_b.npy"),
163               PadStrideInfo(1, 1, 1, 1))
164           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
165           << ConvolutionLayer(
166               3U, 3U, 512U,
167               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv4_4_w.npy"),
168               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv4_4_b.npy"),
169               PadStrideInfo(1, 1, 1, 1))
170           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
171           << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 2, PadStrideInfo(2, 2, 0, 0)))
172           // Layer 5
173           << ConvolutionLayer(
174               3U, 3U, 512U,
175               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv5_1_w.npy"),
176               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv5_1_b.npy"),
177               PadStrideInfo(1, 1, 1, 1))
178           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
179           << ConvolutionLayer(
180               3U, 3U, 512U,
181               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv5_2_w.npy"),
182               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv5_2_b.npy"),
183               PadStrideInfo(1, 1, 1, 1))
184           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
185           << ConvolutionLayer(
186               3U, 3U, 512U,
187               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv5_3_w.npy"),
188               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv5_3_b.npy"),
189               PadStrideInfo(1, 1, 1, 1))
190           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
191           << ConvolutionLayer(
192               3U, 3U, 512U,
193               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv5_4_w.npy"),
194               get_weights_accessor(data_path, "/cnn_data/vgg19_model/conv5_4_b.npy"),
195               PadStrideInfo(1, 1, 1, 1))
196           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
197           << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 2, PadStrideInfo(2, 2, 0, 0)))
198           // Layer 6
199           << FullyConnectedLayer(
200               4096U,
201               get_weights_accessor(data_path, "/cnn_data/vgg19_model/fc6_w.npy"),
202               get_weights_accessor(data_path, "/cnn_data/vgg19_model/fc6_b.npy"))
203           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
204           // Layer 7
205           << FullyConnectedLayer(
206               4096U,
207               get_weights_accessor(data_path, "/cnn_data/vgg19_model/fc7_w.npy"),
208               get_weights_accessor(data_path, "/cnn_data/vgg19_model/fc7_b.npy"))
209           << ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))
210           // Layer 8
211           << FullyConnectedLayer(
212               1000U,
213               get_weights_accessor(data_path, "/cnn_data/vgg19_model/fc8_w.npy"),
214               get_weights_accessor(data_path, "/cnn_data/vgg19_model/fc8_b.npy"))
215           // Softmax
216           << SoftmaxLayer()
217           << Tensor(get_output_accessor(label, 5));
218
219     // Run graph
220     graph.run();
221 }
222
223 /** Main program for VGG19
224  *
225  * @param[in] argc Number of arguments
226  * @param[in] argv Arguments ( [optional] Target (0 = NEON, 1 = OpenCL), [optional] Path to the weights folder, [optional] image, [optional] labels )
227  */
228 int main(int argc, const char **argv)
229 {
230     return arm_compute::utils::run_example(argc, argv, main_graph_vgg19);
231 }