Release 18.03
[platform/upstream/armnn.git] / src / armnnTfParser / test / Multiplication.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12 struct MultiplicationFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
14     MultiplicationFixture()
15     {
16         m_Prototext = "node { \n"
17             "    name: \"graphInput\" \n"
18             "    op: \"Placeholder\" \n"
19             "    attr { \n"
20             "      key: \"dtype\" \n"
21             "      value { \n"
22             "        type: DT_FLOAT \n"
23             "      } \n"
24             "    } \n"
25             "    attr { \n"
26             "      key: \"shape\" \n"
27             "      value { \n"
28             "        shape { \n"
29             "        } \n"
30             "      } \n"
31             "    } \n"
32             "  } \n"
33             "  node { \n"
34             "    name: \"softmax1\" \n"
35             "    op: \"Softmax\" \n"
36             "    input: \"graphInput\" \n"
37             "    attr { \n"
38             "      key: \"T\" \n"
39             "      value { \n"
40             "        type: DT_FLOAT \n"
41             "      } \n"
42             "    } \n"
43             "  }\n"
44             "  node {\n"
45             "    name: \"softmax2\"\n"
46             "    op : \"Softmax\"\n"
47             "    input: \"graphInput\"\n"
48             "    attr { \n"
49             "      key: \"T\" \n"
50             "      value { \n"
51             "        type: DT_FLOAT \n"
52             "      } \n"
53             "    } \n"
54             "  }\n"
55             "  node {\n"
56             "    name: \"multiplication\"\n"
57             "    op : \"Mul\"\n"
58             "    input: \"softmax1\"\n"
59             "    input: \"softmax2\"\n"
60             "    attr { \n"
61             "      key: \"T\" \n"
62             "      value { \n"
63             "        type: DT_FLOAT \n"
64             "      } \n"
65             "    } \n"
66             "  }\n";
67
68         SetupSingleInputSingleOutput({ 1, 7 }, "graphInput", "multiplication");
69     }
70 };
71
72 BOOST_FIXTURE_TEST_CASE(ParseMultiplication, MultiplicationFixture)
73 {
74     RunTest<2>({ 0, 0, 10000, 0, 0, 0, 0 }, { 0, 0, 1, 0, 0, 0, 0 });
75 }
76
77 struct MultiplicationBroadcastFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
78 {
79     MultiplicationBroadcastFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1)
80     {
81         m_Prototext = R"(
82 node {
83   name: "input0"
84   op: "Placeholder"
85   attr {
86     key: "dtype"
87     value {
88       type: DT_FLOAT
89     }
90   }
91   attr {
92     key: "shape"
93     value {
94       shape {
95       }
96     }
97   }
98 }
99 node {
100   name: "input1"
101   op: "Placeholder"
102   attr {
103     key: "dtype"
104     value {
105       type: DT_FLOAT
106     }
107   }
108   attr {
109     key: "shape"
110     value {
111       shape {
112       }
113     }
114   }
115 }
116 node {
117   name: "output"
118   op: "Mul"
119   input: "input0"
120   input: "input1"
121   attr {
122     key: "T"
123     value {
124       type: DT_FLOAT
125     }
126   }
127 }
128         )";
129
130         Setup({ { "input0", inputShape0 },
131                 { "input1", inputShape1 } },
132               { "output" });
133     }
134 };
135
136 struct MultiplicationBroadcastFixture4D1D : public MultiplicationBroadcastFixture
137 {
138     MultiplicationBroadcastFixture4D1D() : MultiplicationBroadcastFixture({ 1, 2, 2, 3 }, { 1 }) {}
139 };
140
141 BOOST_FIXTURE_TEST_CASE(ParseMultiplicationBroadcast4D1D, MultiplicationBroadcastFixture4D1D)
142 {
143     RunTest<4>({ { "input0", { 0.0f,  1.0f,  2.0f,
144                                3.0f,  4.0f,  5.0f,
145                                6.0f,  7.0f,  8.0f,
146                                9.0f, 10.0f, 11.0f } },
147                  { "input1", { 5.0f } } },
148                { { "output", { 0.0f,  5.0f, 10.0f,
149                               15.0f, 20.0f, 25.0f,
150                               30.0f, 35.0f, 40.0f,
151                               45.0f, 50.0f, 55.0f } } });
152 }
153
154 struct MultiplicationBroadcastFixture1D4D : public MultiplicationBroadcastFixture
155 {
156     MultiplicationBroadcastFixture1D4D() : MultiplicationBroadcastFixture({ 1 }, { 1, 2, 2, 3 }) {}
157 };
158
159 BOOST_FIXTURE_TEST_CASE(ParseMultiplicationBroadcast1D4D, MultiplicationBroadcastFixture1D4D)
160 {
161     RunTest<4>({ { "input0", { 3.0f } },
162                  { "input1", { 0.0f,  1.0f,  2.0f,
163                                3.0f,  4.0f,  5.0f,
164                                6.0f,  7.0f,  8.0f,
165                                9.0f, 10.0f, 11.0f } } },
166                { { "output", { 0.0f,  3.0f,  6.0f,
167                                9.0f, 12.0f, 15.0f,
168                               18.0f, 21.0f, 24.0f,
169                               27.0f, 30.0f, 33.0f } } });
170 }
171
172 BOOST_AUTO_TEST_SUITE_END()