Release 18.03
[platform/upstream/armnn.git] / src / armnnTfParser / test / MultiOutput.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12 struct MultiOutMatchFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
14     MultiOutMatchFixture()
15     {
16         m_Prototext = R"(
17 node {
18     name: "input"
19     op: "Placeholder"
20     attr {
21         key: "dtype"
22         value {
23             type: DT_FLOAT
24         }
25     }
26     attr {
27         key: "shape"
28         value {
29             shape {
30             }
31         }
32     }
33 }
34 node {
35     name: "softmax1"
36     op: "Softmax"
37     input: "input:0"
38     attr {
39         key: "T"
40         value {
41             type: DT_FLOAT
42         }
43     }
44 }
45         )";
46         SetupSingleInputSingleOutput({ 1, 7 }, "input", "softmax1");
47     }
48 };
49
50 BOOST_FIXTURE_TEST_CASE(MultiOutMatch, MultiOutMatchFixture)
51 {
52     // Note that the point of this test is to verify the parsing went well.
53     // Here we make sure the softmax has really connected to the input layer.
54     RunTest<2>({ 0, 0, 10000, 0, 0, 0, 0 }, { 0, 0, 1, 0, 0, 0, 0 });
55 }
56
57 struct MultiOutFailFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
58 {
59     MultiOutFailFixture()
60     {
61         m_Prototext = R"(
62 node {
63     name: "input"
64     op: "Placeholder"
65     attr {
66         key: "dtype"
67         value {
68             type: DT_FLOAT
69         }
70     }
71     attr {
72         key: "shape"
73         value {
74             shape {
75             }
76         }
77     }
78 }
79 node {
80     name: "softmax1"
81     op: "Softmax"
82     input: "input:1"
83     attr {
84         key: "T"
85         value {
86             type: DT_FLOAT
87         }
88     }
89 }
90         )";
91         BOOST_CHECK_THROW(SetupSingleInputSingleOutput({ 1, 7 }, "input", "softmax1"), armnn::ParseException);
92     }
93 };
94
95 BOOST_FIXTURE_TEST_CASE(MultiOutFail, MultiOutFailFixture)
96 {
97     // Not running the graph because this is expected to throw an exception during parsing.
98 }
99
100 struct MultiOutInvalidFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
101 {
102     MultiOutInvalidFixture()
103     {
104         m_Prototext = R"(
105 node {
106     name: "input"
107     op: "Placeholder"
108     attr {
109         key: "dtype"
110         value {
111             type: DT_FLOAT
112         }
113     }
114     attr {
115         key: "shape"
116         value {
117             shape {
118             }
119         }
120     }
121 }
122 node {
123     name: "softmax1"
124     op: "Softmax"
125     input: "input:-1"
126     attr {
127         key: "T"
128         value {
129             type: DT_FLOAT
130         }
131     }
132 }
133         )";
134         BOOST_CHECK_THROW(SetupSingleInputSingleOutput({ 1, 7 }, "input", "softmax1"), armnn::ParseException);
135     }
136 };
137
138 BOOST_FIXTURE_TEST_CASE(MultiOutInvalid, MultiOutInvalidFixture)
139 {
140     // Not running the graph because this is expected to throw an exception during parsing.
141 }
142
143
144 BOOST_AUTO_TEST_SUITE_END()