Release 18.08
[platform/upstream/armnn.git] / src / armnnOnnxParser / test / BatchNorm.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnOnnxParser/IOnnxParser.hpp"
8 #include  "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(OnnxParser)
11
12 struct BatchNormalizationMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13 {
14     BatchNormalizationMainFixture()
15     {
16         m_Prototext = R"(
17                    ir_version: 3
18                    producer_name:  "CNTK"
19                    producer_version:  "2.5.1"
20                    domain:  "ai.cntk"
21                    model_version: 1
22                    graph {
23                      name:  "CNTKGraph"
24                      input {
25                         name: "Input"
26                         type {
27                           tensor_type {
28                             elem_type: FLOAT
29                             shape {
30                               dim {
31                                 dim_value: 1
32                               }
33                               dim {
34                                 dim_value: 1
35                               }
36                               dim {
37                                 dim_value: 3
38                               }
39                               dim {
40                                 dim_value: 3
41                               }
42                             }
43                           }
44                         }
45                       }
46                       input {
47                          name: "mean"
48                          type {
49                            tensor_type {
50                              elem_type: FLOAT
51                              shape {
52                                dim {
53                                  dim_value: 1
54                                }
55                              }
56                            }
57                          }
58                        }
59                        input {
60                           name: "var"
61                           type {
62                             tensor_type {
63                               elem_type: FLOAT
64                               shape {
65                                 dim {
66                                   dim_value: 1
67                                 }
68                               }
69                             }
70                           }
71                         }
72                         input {
73                            name: "scale"
74                            type {
75                              tensor_type {
76                                elem_type: FLOAT
77                                shape {
78                                  dim {
79                                    dim_value: 1
80                                  }
81                                }
82                              }
83                            }
84                          }
85                          input {
86                             name: "bias"
87                             type {
88                               tensor_type {
89                                 elem_type: FLOAT
90                                 shape {
91                                   dim {
92                                     dim_value: 1
93                                   }
94                                 }
95                               }
96                             }
97                           }
98                      node {
99                          input: "Input"
100                          input: "scale"
101                          input: "bias"
102                          input: "mean"
103                          input: "var"
104                          output: "Output"
105                          name: "batchNorm"
106                          op_type: "BatchNormalization"
107                          attribute {
108                            name: "epsilon"
109                            f:  0.0010000000475
110                            type: FLOAT
111                          }
112                       }
113                       initializer {
114                           dims: 1
115                           data_type: FLOAT
116                           float_data: 5.0
117                           name: "mean"
118                         }
119                       initializer {
120                         dims: 1
121                         data_type: FLOAT
122                         float_data: 2.0
123                         name: "var"
124                       }
125                       initializer {
126                         dims: 1
127                         data_type: FLOAT
128                         float_data: 0.0
129                         name: "bias"
130                       }
131                       initializer {
132                         dims: 1
133                         data_type: FLOAT
134                         float_data: 1.0
135                         name: "scale"
136                       }
137                       output {
138                           name: "Output"
139                           type {
140                              tensor_type {
141                                elem_type: FLOAT
142                                shape {
143                                    dim {
144                                        dim_value: 1
145                                    }
146                                    dim {
147                                        dim_value: 1
148                                    }
149                                    dim {
150                                        dim_value: 3
151                                    }
152                                    dim {
153                                        dim_value: 3
154                                    }
155                                }
156                             }
157                         }
158                         }
159                     }
160                    opset_import {
161                       version: 7
162                     })";
163         Setup();
164     }
165 };
166
167 BOOST_FIXTURE_TEST_CASE(ValidBatchNormalizationTest, BatchNormalizationMainFixture)
168 {
169     RunTest<4>({{"Input", {1, 2, 3, 4, 5, 6, 7, 8, 9}}},             // Input data.
170                {{"Output", {-2.8277204f, -2.12079024f, -1.4138602f,
171                 -0.7069301f, 0.0f, 0.7069301f,
172                 1.4138602f, 2.12079024f, 2.8277204f}}});  // Expected output data.
173 }
174
175
176 struct BatchNormalizationBisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
177 {
178     BatchNormalizationBisFixture()
179     {
180         m_Prototext = R"(
181                    ir_version: 3
182                    producer_name:  "CNTK"
183                    producer_version:  "2.5.1"
184                    domain:  "ai.cntk"
185                    model_version: 1
186                    graph {
187                      name:  "CNTKGraph"
188                      input {
189                         name: "Input"
190                         type {
191                           tensor_type {
192                             elem_type: FLOAT
193                             shape {
194                               dim {
195                                 dim_value: 1
196                               }
197                               dim {
198                                 dim_value: 2
199                               }
200                               dim {
201                                 dim_value: 1
202                               }
203                               dim {
204                                 dim_value: 3
205                               }
206                             }
207                           }
208                         }
209                       }
210                       input {
211                          name: "mean"
212                          type {
213                            tensor_type {
214                              elem_type: FLOAT
215                              shape {
216                                dim {
217                                  dim_value: 2
218                                }
219                              }
220                            }
221                          }
222                        }
223                        input {
224                           name: "var"
225                           type {
226                             tensor_type {
227                               elem_type: FLOAT
228                               shape {
229                                 dim {
230                                   dim_value: 2
231                                 }
232                               }
233                             }
234                           }
235                         }
236                         input {
237                            name: "scale"
238                            type {
239                              tensor_type {
240                                elem_type: FLOAT
241                                shape {
242                                  dim {
243                                    dim_value: 2
244                                  }
245                                }
246                              }
247                            }
248                          }
249                          input {
250                             name: "bias"
251                             type {
252                               tensor_type {
253                                 elem_type: FLOAT
254                                 shape {
255                                   dim {
256                                     dim_value: 2
257                                   }
258                                 }
259                               }
260                             }
261                           }
262                      node {
263                          input: "Input"
264                          input: "scale"
265                          input: "bias"
266                          input: "mean"
267                          input: "var"
268                          output: "Output"
269                          name: "batchNorm"
270                          op_type: "BatchNormalization"
271                          attribute {
272                            name: "epsilon"
273                            f:  0.00001
274                            type: FLOAT
275                          }
276                       }
277                       initializer {
278                           dims: 2
279                           data_type: FLOAT
280                           float_data: 0.0
281                           float_data: 3.0
282                           name: "mean"
283                         }
284                       initializer {
285                         dims: 2
286                         data_type: FLOAT
287                         float_data: 1.0
288                         float_data: 1.5
289                         name: "var"
290                       }
291                       initializer {
292                         dims: 2
293                         data_type: FLOAT
294                         float_data: 0.0
295                         float_data: 1.0
296                         name: "bias"
297                       }
298                       initializer {
299                         dims: 2
300                         data_type: FLOAT
301                         float_data: 1.0
302                         float_data: 1.5
303                         name: "scale"
304                       }
305                       output {
306                           name: "Output"
307                           type {
308                              tensor_type {
309                                elem_type: FLOAT
310                                shape {
311                                    dim {
312                                        dim_value: 1
313                                    }
314                                    dim {
315                                        dim_value: 2
316                                    }
317                                    dim {
318                                        dim_value: 1
319                                    }
320                                    dim {
321                                        dim_value: 3
322                                    }
323                                }
324                             }
325                         }
326                         }
327                     }
328                    opset_import {
329                       version: 7
330                     })";
331         Setup();
332     }
333 };
334
335 BOOST_FIXTURE_TEST_CASE(ValidBatchNormalizationBisTest, BatchNormalizationBisFixture)
336 {
337     RunTest<4>({{"Input", {-1, 0.0, 1, 2, 3.0, 4.0}}},             // Input data.
338                {{"Output", {-0.999995f, 0.0, 0.999995f,
339                             -0.22474074f, 1.0f, 2.2247407f}}});  // Expected output data.
340 }
341
342 BOOST_AUTO_TEST_SUITE_END()