Release 18.08
[platform/upstream/armnn.git] / src / armnnOnnxParser / test / Reshape.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnOnnxParser/IOnnxParser.hpp"
8 #include  "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(OnnxParser)
11
12 struct ReshapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13 {
14     ReshapeMainFixture(const std::string& dataType)
15     {
16         m_Prototext = R"(
17                    ir_version: 3
18                    producer_name:  "CNTK"
19                    producer_version:  "2.5.1"
20                    domain:  "ai.cntk"
21                    model_version: 1
22                    graph {
23                      name:  "CNTKGraph"
24                      input {
25                         name: "Input"
26                         type {
27                           tensor_type {
28                             elem_type: )" + dataType + R"(
29                             shape {
30                               dim {
31                                 dim_value: 4
32                               }
33                             }
34                           }
35                         }
36                       }
37                       input {
38                          name: "Shape"
39                          type {
40                            tensor_type {
41                              elem_type: INT64
42                              shape {
43                                dim {
44                                  dim_value: 2
45                                }
46                              }
47                            }
48                          }
49                        }
50                      node {
51                          input: "Input"
52                          input: "Shape"
53                          output: "Output"
54                          name: "reshape"
55                          op_type: "Reshape"
56
57                       }
58                       initializer {
59                         dims: 2
60                         data_type: INT64
61                         int64_data: 2
62                         int64_data: 2
63                         name: "Shape"
64                      }
65                       output {
66                           name: "Output"
67                           type {
68                              tensor_type {
69                                elem_type: FLOAT
70                                shape {
71                                    dim {
72                                        dim_value: 2
73                                    }
74                                    dim {
75                                        dim_value: 2
76                                    }
77                                }
78                             }
79                           }
80                        }
81                     }
82                    opset_import {
83                       version: 7
84                     })";
85     }
86 };
87
88 struct ReshapeValidFixture : ReshapeMainFixture
89 {
90     ReshapeValidFixture() : ReshapeMainFixture("FLOAT") {
91         Setup();
92     }
93 };
94
95 struct ReshapeInvalidFixture : ReshapeMainFixture
96 {
97     ReshapeInvalidFixture() : ReshapeMainFixture("FLOAT16") { }
98 };
99
100 BOOST_FIXTURE_TEST_CASE(ValidReshapeTest, ReshapeValidFixture)
101 {
102     RunTest<2>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f }}}, {{"Output", { 0.0f, 1.0f, 2.0f, 3.0f }}});
103 }
104
105 BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeReshape, ReshapeInvalidFixture)
106 {
107    BOOST_CHECK_THROW(Setup(), armnn::ParseException);
108 }
109
110 BOOST_AUTO_TEST_SUITE_END()