Release 18.08
[platform/upstream/armnn.git] / src / armnnTfParser / test / MaximumForLeakyRelu.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12 struct UnsupportedMaximumFixture
13     : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
14 {
15     UnsupportedMaximumFixture()
16     {
17         m_Prototext = R"(
18             node {
19                 name: "graphInput"
20                 op: "Placeholder"
21                 attr {
22                     key: "dtype"
23                     value {
24                         type: DT_FLOAT
25                     }
26                 }
27                 attr {
28                     key: "shape"
29                     value {
30                         shape {
31                         }
32                     }
33                 }
34             }
35             node {
36                 name: "Maximum"
37                 op: "Maximum"
38                 input: "graphInput"
39                 input: "graphInput"
40                 attr {
41                     key: "dtype"
42                     value {
43                         type: DT_FLOAT
44                     }
45                 }
46             }
47         )";
48     }
49 };
50
51 BOOST_FIXTURE_TEST_CASE(UnsupportedMaximum, UnsupportedMaximumFixture)
52 {
53     BOOST_CHECK_THROW(
54         SetupSingleInputSingleOutput({ 1, 1 }, "graphInput", "Maximum"),
55         armnn::ParseException);
56 }
57
58 struct SupportedMaximumFixture
59     : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
60 {
61     SupportedMaximumFixture(const std::string & maxInput0,
62                             const std::string & maxInput1,
63                             const std::string & mulInput0,
64                             const std::string & mulInput1)
65     {
66         m_Prototext = R"(
67             node {
68                 name: "graphInput"
69                 op: "Placeholder"
70                 attr {
71                     key: "dtype"
72                     value { type: DT_FLOAT }
73                 }
74                 attr {
75                     key: "shape"
76                     value { shape { } }
77                 }
78             }
79             node {
80                 name: "Alpha"
81                 op: "Const"
82                 attr {
83                     key: "dtype"
84                     value { type: DT_FLOAT }
85                 }
86                 attr {
87                     key: "value"
88                     value {
89                         tensor {
90                             dtype: DT_FLOAT
91                             tensor_shape {
92                                 dim { size: 1 }
93                             }
94                             float_val: 0.1
95                         }
96                     }
97                 }
98             }
99             node {
100                 name: "Mul"
101                 op: "Mul"
102                 input: ")" + mulInput0 + R"("
103                 input: ")" + mulInput1 + R"("
104                 attr {
105                     key: "T"
106                     value { type: DT_FLOAT }
107                 }
108             }
109             node {
110                 name: "Maximum"
111                 op: "Maximum"
112                 input: ")" + maxInput0 + R"("
113                 input: ")" + maxInput1 + R"("
114                 attr {
115                     key: "T"
116                     value { type: DT_FLOAT }
117                 }
118             }
119         )";
120         SetupSingleInputSingleOutput({ 1, 2 }, "graphInput", "Maximum");
121     }
122 };
123
124 struct LeakyRelu_Max_MulAT_T_Fixture : public SupportedMaximumFixture
125 {
126     LeakyRelu_Max_MulAT_T_Fixture()
127     : SupportedMaximumFixture("Mul","graphInput","Alpha","graphInput") {}
128 };
129
130 BOOST_FIXTURE_TEST_CASE(LeakyRelu_Max_MulAT_T, LeakyRelu_Max_MulAT_T_Fixture)
131 {
132     RunTest<2>(std::vector<float>({-5.0, 3.0}), {-0.5, 3.0});
133 }
134
135 struct LeakyRelu_Max_T_MulAT_Fixture : public SupportedMaximumFixture
136 {
137     LeakyRelu_Max_T_MulAT_Fixture()
138     : SupportedMaximumFixture("graphInput","Mul","Alpha","graphInput") {}
139 };
140
141
142 BOOST_FIXTURE_TEST_CASE(LeakyRelu_Max_T_MulAT, LeakyRelu_Max_T_MulAT_Fixture)
143 {
144     RunTest<2>(std::vector<float>({-10.0, 3.0}), {-1.0, 3.0});
145 }
146
147 struct LeakyRelu_Max_MulTA_T_Fixture : public SupportedMaximumFixture
148 {
149     LeakyRelu_Max_MulTA_T_Fixture()
150     : SupportedMaximumFixture("Mul", "graphInput","graphInput","Alpha") {}
151 };
152
153 BOOST_FIXTURE_TEST_CASE(LeakyRelu_Max_MulTA_T, LeakyRelu_Max_MulTA_T_Fixture)
154 {
155     RunTest<2>(std::vector<float>({-5.0, 3.0}), {-0.5, 3.0});
156 }
157
158 struct LeakyRelu_Max_T_MulTA_Fixture : public SupportedMaximumFixture
159 {
160     LeakyRelu_Max_T_MulTA_Fixture()
161     : SupportedMaximumFixture("graphInput", "Mul", "graphInput", "Alpha") {}
162 };
163
164 BOOST_FIXTURE_TEST_CASE(LeakyRelu_Max_T_MulTA, LeakyRelu_Max_T_MulTA_Fixture)
165 {
166     RunTest<2>(std::vector<float>({-10.0, 13.0}), {-1.0, 13.0});
167 }
168
169 BOOST_AUTO_TEST_SUITE_END()