Release 18.08
[platform/upstream/armnn.git] / src / armnnTfLiteParser / test / InputOutputTensorNames.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersFixture.hpp"
8 #include "../TfLiteParser.hpp"
9
10 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
11
12 struct EmptyNetworkFixture : public ParserFlatbuffersFixture
13 {
14     explicit EmptyNetworkFixture() {
15         m_JsonString = R"(
16             {
17                 "version": 3,
18                 "operator_codes": [],
19                 "subgraphs": [ {} ]
20             })";
21     }
22 };
23
24 BOOST_FIXTURE_TEST_CASE(EmptyNetworkHasNoInputsAndOutputs, EmptyNetworkFixture)
25 {
26     Setup();
27     BOOST_TEST(m_Parser->GetSubgraphCount() == 1);
28     BOOST_TEST(m_Parser->GetSubgraphInputTensorNames(0).size() == 0);
29     BOOST_TEST(m_Parser->GetSubgraphOutputTensorNames(0).size() == 0);
30 }
31
32 struct MissingTensorsFixture : public ParserFlatbuffersFixture
33 {
34     explicit MissingTensorsFixture()
35     {
36         m_JsonString = R"(
37             {
38                 "version": 3,
39                 "operator_codes": [],
40                 "subgraphs": [{
41                     "inputs" : [ 0, 1 ],
42                     "outputs" : [ 2, 3 ],
43                 }]
44             })";
45     }
46 };
47
48 BOOST_FIXTURE_TEST_CASE(MissingTensorsThrowException, MissingTensorsFixture)
49 {
50     // this throws because it cannot do the input output tensor connections
51     BOOST_CHECK_THROW(Setup(), armnn::ParseException);
52 }
53
54 struct InvalidTensorsFixture : public ParserFlatbuffersFixture
55 {
56     explicit InvalidTensorsFixture()
57     {
58         m_JsonString = R"(
59             {
60                 "version": 3,
61                 "operator_codes": [ ],
62                 "subgraphs": [{
63                     "tensors": [ {}, {}, {}, {} ],
64                     "inputs" : [ 0, 1 ],
65                     "outputs" : [ 2, 3 ],
66                 }]
67             })";
68     }
69 };
70
71 BOOST_FIXTURE_TEST_CASE(InvalidTensorsThrowException, InvalidTensorsFixture)
72 {
73     // this throws because it cannot do the input output tensor connections
74     BOOST_CHECK_THROW(Setup(), armnn::InvalidArgumentException);
75 }
76
77 struct ValidTensorsFixture : public ParserFlatbuffersFixture
78 {
79     explicit ValidTensorsFixture()
80     {
81         m_JsonString = R"(
82             {
83                 "version": 3,
84                 "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" } ],
85                 "subgraphs": [{
86                     "tensors": [ {
87                         "shape": [ 1, 1, 1, 1 ],
88                         "type": "FLOAT32",
89                         "name": "In",
90                         "buffer": 0,
91                     }, {
92                         "shape": [ 1, 1, 1, 1 ],
93                         "type": "FLOAT32",
94                         "name": "Out",
95                         "buffer": 1,
96                     }],
97                     "inputs" : [ 0 ],
98                     "outputs" : [ 1 ],
99                     "operators": [{
100                         "opcode_index": 0,
101                         "inputs": [ 0 ],
102                         "outputs": [ 1 ],
103                         "builtin_options_type": "Pool2DOptions",
104                         "builtin_options":
105                         {
106                             "padding": "VALID",
107                             "stride_w": 1,
108                             "stride_h": 1,
109                             "filter_width": 1,
110                             "filter_height": 1,
111                             "fused_activation_function": "NONE"
112                         },
113                         "custom_options_format": "FLEXBUFFERS"
114                     }]
115                 }]
116             })";
117     }
118 };
119
120 BOOST_FIXTURE_TEST_CASE(GetValidInputOutputTensorNames, ValidTensorsFixture)
121 {
122     Setup();
123     BOOST_CHECK_EQUAL(m_Parser->GetSubgraphInputTensorNames(0).size(), 1u);
124     BOOST_CHECK_EQUAL(m_Parser->GetSubgraphOutputTensorNames(0).size(), 1u);
125     BOOST_CHECK_EQUAL(m_Parser->GetSubgraphInputTensorNames(0)[0], "In");
126     BOOST_CHECK_EQUAL(m_Parser->GetSubgraphOutputTensorNames(0)[0], "Out");
127 }
128
129 BOOST_FIXTURE_TEST_CASE(ThrowIfSubgraphIdInvalidForInOutNames, ValidTensorsFixture)
130 {
131     Setup();
132
133     // these throw because of the invalid subgraph id
134     BOOST_CHECK_THROW(m_Parser->GetSubgraphInputTensorNames(1), armnn::ParseException);
135     BOOST_CHECK_THROW(m_Parser->GetSubgraphOutputTensorNames(1), armnn::ParseException);
136 }
137
138 BOOST_AUTO_TEST_SUITE_END()