Release 18.03
[platform/upstream/armnn.git] / src / armnnTfParser / test / BiasAdd.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12 struct BiasAddFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
14     explicit BiasAddFixture(const std::string& dataFormat)
15     {
16         m_Prototext = R"(
17 node {
18   name: "graphInput"
19   op: "Placeholder"
20   attr {
21     key: "dtype"
22     value {
23       type: DT_FLOAT
24     }
25   }
26   attr {
27     key: "shape"
28     value {
29       shape {
30       }
31     }
32   }
33 }
34 node {
35   name: "bias"
36   op: "Const"
37   attr {
38     key: "dtype"
39     value {
40       type: DT_FLOAT
41     }
42   }
43   attr {
44     key: "value"
45     value {
46       tensor {
47         dtype: DT_FLOAT
48         tensor_shape {
49           dim {
50             size: 3
51           }
52         }
53         float_val: 1
54         float_val: 2
55         float_val: 3
56       }
57     }
58   }
59 }
60 node {
61   name: "biasAdd"
62   op : "BiasAdd"
63   input: "graphInput"
64   input: "bias"
65   attr {
66     key: "T"
67     value {
68       type: DT_FLOAT
69     }
70   }
71   attr {
72     key: "data_format"
73     value {
74       s: ")" + dataFormat + R"("
75     }
76   }
77 }
78 )";
79
80         SetupSingleInputSingleOutput({ 1, 3, 1, 3 }, "graphInput", "biasAdd");
81     }
82 };
83
84 struct BiasAddFixtureNCHW : BiasAddFixture
85 {
86     BiasAddFixtureNCHW() : BiasAddFixture("NCHW") {}
87 };
88
89 struct BiasAddFixtureNHWC : BiasAddFixture
90 {
91     BiasAddFixtureNHWC() : BiasAddFixture("NHWC") {}
92 };
93
94 BOOST_FIXTURE_TEST_CASE(ParseBiasAddNCHW, BiasAddFixtureNCHW)
95 {
96     RunTest<4>(std::vector<float>(9), { 1, 1, 1, 2, 2, 2, 3, 3, 3 });
97 }
98
99 BOOST_FIXTURE_TEST_CASE(ParseBiasAddNHWC, BiasAddFixtureNHWC)
100 {
101     RunTest<4>(std::vector<float>(9), { 1, 2, 3, 1, 2, 3, 1, 2, 3 });
102 }
103
104 BOOST_AUTO_TEST_SUITE_END()