Release 18.03
[platform/upstream/armnn.git] / src / armnnTfParser / test / FusedBatchNorm.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12 struct FusedBatchNormFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
14     FusedBatchNormFixture()
15     {
16         m_Prototext = "node { \n"
17             "  name: \"graphInput\" \n"
18             "  op: \"Placeholder\" \n"
19             "  attr { \n"
20             "    key: \"dtype\" \n"
21             "    value { \n"
22             "      type: DT_FLOAT \n"
23             "    } \n"
24             "  } \n"
25             "  attr { \n"
26             "    key: \"shape\" \n"
27             "    value { \n"
28             "      shape { \n"
29             "      } \n"
30             "    } \n"
31             "  } \n"
32             "} \n"
33             "node { \n"
34             "  name: \"Const_1\" \n"
35             "  op: \"Const\" \n"
36             "  attr { \n"
37             "    key: \"dtype\" \n"
38             "    value { \n"
39             "      type: DT_FLOAT \n"
40             "    } \n"
41             "  } \n"
42             "  attr { \n"
43             "    key: \"value\" \n"
44             "    value { \n"
45             "      tensor { \n"
46             "        dtype: DT_FLOAT \n"
47             "        tensor_shape { \n"
48             "          dim { \n"
49             "            size: 1 \n"
50             "          } \n"
51             "        } \n"
52             "        float_val: 1.0 \n"
53             "      } \n"
54             "    } \n"
55             "  } \n"
56             "} \n"
57             "node { \n"
58             "  name: \"Const_2\" \n"
59             "  op: \"Const\" \n"
60             "  attr { \n"
61             "    key: \"dtype\" \n"
62             "    value { \n"
63             "      type: DT_FLOAT \n"
64             "    } \n"
65             "  } \n"
66             "  attr { \n"
67             "    key: \"value\" \n"
68             "    value { \n"
69             "      tensor { \n"
70             "        dtype: DT_FLOAT \n"
71             "        tensor_shape { \n"
72             "          dim { \n"
73             "            size: 1 \n"
74             "          } \n"
75             "        } \n"
76             "        float_val: 0.0 \n"
77             "      } \n"
78             "    } \n"
79             "  } \n"
80             "} \n"
81             "node { \n"
82             "  name: \"FusedBatchNormLayer/mean\" \n"
83             "  op: \"Const\" \n"
84             "  attr { \n"
85             "    key: \"dtype\" \n"
86             "    value { \n"
87             "      type: DT_FLOAT \n"
88             "    } \n"
89             "  } \n"
90             "  attr { \n"
91             "    key: \"value\" \n"
92             "    value { \n"
93             "      tensor { \n"
94             "        dtype: DT_FLOAT \n"
95             "        tensor_shape { \n"
96             "          dim { \n"
97             "            size: 1 \n"
98             "          } \n"
99             "        } \n"
100             "        float_val: 5.0 \n"
101             "      } \n"
102             "    } \n"
103             "  } \n"
104             "} \n"
105             "node { \n"
106             "  name: \"FusedBatchNormLayer/variance\" \n"
107             "  op: \"Const\" \n"
108             "  attr { \n"
109             "    key: \"dtype\" \n"
110             "    value { \n"
111             "      type: DT_FLOAT \n"
112             "    } \n"
113             "  } \n"
114             "  attr { \n"
115             "    key: \"value\" \n"
116             "    value { \n"
117             "      tensor { \n"
118             "        dtype: DT_FLOAT \n"
119             "        tensor_shape { \n"
120             "          dim { \n"
121             "            size: 1 \n"
122             "          } \n"
123             "        } \n"
124             "        float_val: 2.0 \n"
125             "      } \n"
126             "    } \n"
127             "  } \n"
128             "} \n"
129             "node { \n"
130             "  name: \"output\" \n"
131             "  op: \"FusedBatchNorm\" \n"
132             "  input: \"graphInput\" \n"
133             "  input: \"Const_1\" \n"
134             "  input: \"Const_2\" \n"
135             "  input: \"FusedBatchNormLayer/mean\" \n"
136             "  input: \"FusedBatchNormLayer/variance\" \n"
137             "  attr { \n"
138             "    key: \"T\" \n"
139             "    value { \n"
140             "      type: DT_FLOAT \n"
141             "    } \n"
142             "  } \n"
143             "  attr { \n"
144             "    key: \"data_format\" \n"
145             "    value { \n"
146             "      s: \"NHWC\" \n"
147             "    } \n"
148             "  } \n"
149             "  attr { \n"
150             "    key: \"epsilon\" \n"
151             "    value { \n"
152             "      f: 0.0010000000475 \n"
153             "    } \n"
154             "  } \n"
155             "  attr { \n"
156             "    key: \"is_training\" \n"
157             "    value { \n"
158             "      b: false \n"
159             "    } \n"
160             "  } \n"
161             "} \n";
162
163         SetupSingleInputSingleOutput({1, 3, 3, 1}, "graphInput", "output");
164     }
165 };
166
167 BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNorm, FusedBatchNormFixture)
168 {
169     RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9},             // input data
170                {-2.8277204f, -2.12079024f, -1.4138602f,
171                 -0.7069301f, 0.0f, 0.7069301f,
172                 1.4138602f, 2.12079024f, 2.8277204f});  // expected output data
173 }
174
175 BOOST_AUTO_TEST_SUITE_END()