Release 18.08
[platform/upstream/armnn.git] / src / armnnOnnxParser / test / Addition.cpp
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5
#include <boost/test/unit_test.hpp>
#include "armnnOnnxParser/IOnnxParser.hpp"
#include "ParserPrototxtFixture.hpp"

#include <string>
9
10 BOOST_AUTO_TEST_SUITE(OnnxParser)
11
12 struct AddMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13 {
14     AddMainFixture(const std::string& dataType)
15     {
16         m_Prototext = R"(
17                    ir_version: 3
18                    producer_name:  "CNTK"
19                    producer_version:  "2.5.1"
20                    domain:  "ai.cntk"
21                    model_version: 1
22                    graph {
23                      name:  "CNTKGraph"
24                      input {
25                         name: "Input0"
26                         type {
27                           tensor_type {
28                             elem_type: )" + dataType + R"(
29                             shape {
30                               dim {
31                                 dim_value: 1
32                               }
33                               dim {
34                                 dim_value: 1
35                               }
36                               dim {
37                                 dim_value: 2
38                               }
39                               dim {
40                                 dim_value: 2
41                               }
42                             }
43                           }
44                         }
45                       }
46                       input {
47                          name: "Input1"
48                          type {
49                            tensor_type {
50                              elem_type: )" + dataType + R"(
51                              shape {
52                                dim {
53                                  dim_value: 1
54                                }
55                                dim {
56                                  dim_value: 1
57                                }
58                                dim {
59                                  dim_value: 2
60                                }
61                                dim {
62                                  dim_value: 2
63                                }
64                              }
65                            }
66                          }
67                        }
68                        node {
69                             input: "Input0"
70                             input: "Input1"
71                             output: "Output"
72                             name: "addition"
73                             op_type: "Add"
74                             doc_string: ""
75                             domain: ""
76                           }
77                           output {
78                               name: "Output"
79                               type {
80                                  tensor_type {
81                                    elem_type: FLOAT
82                                    shape {
83                                        dim {
84                                            dim_value: 1
85                                        }
86                                        dim {
87                                            dim_value: 1
88                                        }
89                                        dim {
90                                            dim_value: 2
91                                        }
92                                        dim {
93                                            dim_value: 2
94                                        }
95                                    }
96                                 }
97                             }
98                         }
99                     }
100                    opset_import {
101                       version: 7
102                     })";
103     }
104 };
105
106 struct AddValidFixture : AddMainFixture
107 {
108     AddValidFixture() : AddMainFixture("FLOAT") {
109         Setup();
110     }
111 };
112
113 struct AddInvalidFixture : AddMainFixture
114 {
115     AddInvalidFixture() : AddMainFixture("INT32") { }
116 };
117
118 struct AddValidBroadcastFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
119 {
120     AddValidBroadcastFixture() {
121
122         m_Prototext = R"(
123                    ir_version: 3
124                    producer_name:  "CNTK"
125                    producer_version:  "2.5.1"
126                    domain:  "ai.cntk"
127                    model_version: 1
128                    graph {
129                      name:  "CNTKGraph"
130                      input {
131                         name: "Input0"
132                         type {
133                           tensor_type {
134                             elem_type: FLOAT
135                             shape {
136                               dim {
137                                 dim_value: 1
138                               }
139                               dim {
140                                 dim_value: 1
141                               }
142                               dim {
143                                 dim_value: 1
144                               }
145                               dim {
146                                 dim_value: 4
147                               }
148                             }
149                           }
150                         }
151                       }
152                       input {
153                          name: "Input1"
154                          type {
155                            tensor_type {
156                              elem_type: FLOAT
157                              shape {
158                                  dim {
159                                    dim_value: 4
160                                  }
161                              }
162                            }
163                          }
164                        }
165                        node {
166                             input: "Input0"
167                             input: "Input1"
168                             output: "Output"
169                             name: "addition"
170                             op_type: "Add"
171                             doc_string: ""
172                             domain: ""
173                           }
174                           output {
175                               name: "Output"
176                               type {
177                                  tensor_type {
178                                    elem_type: FLOAT
179                                    shape {
180                                        dim {
181                                            dim_value: 1
182                                        }
183                                        dim {
184                                            dim_value: 1
185                                        }
186                                        dim {
187                                            dim_value: 1
188                                        }
189                                        dim {
190                                            dim_value: 4
191                                        }
192                                    }
193                                 }
194                             }
195                         }
196                     }
197                    opset_import {
198                       version: 7
199                     })";
200         Setup();
201     }
202 };
203
204 struct AddInvalidBroadcastFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
205 {
206     AddInvalidBroadcastFixture() {
207
208         m_Prototext = R"(
209                    ir_version: 3
210                    producer_name:  "CNTK"
211                    producer_version:  "2.5.1"
212                    domain:  "ai.cntk"
213                    model_version: 1
214                    graph {
215                      name:  "CNTKGraph"
216                      input {
217                         name: "Input0"
218                         type {
219                           tensor_type {
220                             elem_type: FLOAT
221                             shape {
222                               dim {
223                                 dim_value: 1
224                               }
225                               dim {
226                                 dim_value: 1
227                               }
228                               dim {
229                                 dim_value: 1
230                               }
231                               dim {
232                                 dim_value: 3
233                               }
234                             }
235                           }
236                         }
237                       }
238                       input {
239                          name: "Input1"
240                          type {
241                            tensor_type {
242                              elem_type: FLOAT
243                              shape {
244                                  dim {
245                                    dim_value: 4
246                                  }
247                              }
248                            }
249                          }
250                        }
251                        node {
252                             input: "Input0"
253                             input: "Input1"
254                             output: "Output"
255                             name: "addition"
256                             op_type: "Add"
257                             doc_string: ""
258                             domain: ""
259                           }
260                           output {
261                               name: "Output"
262                               type {
263                                  tensor_type {
264                                    elem_type: FLOAT
265                                    shape {
266                                        dim {
267                                            dim_value: 1
268                                        }
269                                        dim {
270                                            dim_value: 1
271                                        }
272                                        dim {
273                                            dim_value: 1
274                                        }
275                                        dim {
276                                            dim_value: 4
277                                        }
278                                    }
279                                 }
280                             }
281                         }
282                     }
283                    opset_import {
284                       version: 7
285                     })";
286     }
287 };
288
289 BOOST_FIXTURE_TEST_CASE(ValidAddTest, AddValidFixture)
290 {
291     RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}},
292                 {"Input1", {1.0f, 2.0f, 3.0, 4.0f}}}, {{"Output", {2.0, 4.0, 0, 0.0}}});
293 }
294
295 BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeAdd, AddInvalidFixture)
296 {
297    BOOST_CHECK_THROW(Setup(), armnn::ParseException);
298 }
299
300 BOOST_FIXTURE_TEST_CASE(InvalidBroadcastAdd, AddInvalidBroadcastFixture)
301 {
302    BOOST_CHECK_THROW(Setup(), armnn::ParseException);
303 }
304
305 BOOST_FIXTURE_TEST_CASE(ValidBroadcastAdd, AddValidBroadcastFixture)
306 {
307     RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}},
308                 {"Input1", {1.0f, 2.0f, 3.0, 4.0f}}}, {{"Output", {2.0, 4.0, 0, 0.0}}});
309 }
310
311 BOOST_AUTO_TEST_SUITE_END()