Release 18.08
[platform/upstream/armnn.git] / src / armnnTfParser / test / Shape.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12 struct ShapeFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
14     ShapeFixture()
15     {
16         m_Prototext =
17             "node { \n"
18             "  name: \"Placeholder\" \n"
19             "  op: \"Placeholder\" \n"
20             "  attr { \n"
21             "    key: \"dtype\" \n"
22             "    value { \n"
23             "      type: DT_FLOAT \n"
24             "    } \n"
25             "  } \n"
26             "  attr { \n"
27             "    key: \"shape\" \n"
28             "    value { \n"
29             "      shape { \n"
30             "        dim { \n"
31             "          size: 1 \n"
32             "        } \n"
33             "        dim { \n"
34             "          size: 1 \n"
35             "        } \n"
36             "        dim { \n"
37             "          size: 1 \n"
38             "        } \n"
39             "        dim { \n"
40             "          size: 4 \n"
41             "        } \n"
42             "      } \n"
43             "    } \n"
44             "  } \n"
45             "} \n"
46             "node { \n"
47             "  name: \"shapeTest\" \n"
48             "  op: \"Shape\" \n"
49             "  input: \"Placeholder\" \n"
50             "  attr { \n"
51             "    key: \"T\" \n"
52             "    value { \n"
53             "      type: DT_FLOAT \n"
54             "    } \n"
55             "  } \n"
56             "  attr { \n"
57             "    key: \"out_type\" \n"
58             "    value { \n"
59             "      type: DT_INT32 \n"
60             "    } \n"
61             "  } \n"
62             "} \n"
63             "node { \n"
64             "  name: \"Reshape\" \n"
65             "  op: \"Reshape\" \n"
66             "  input: \"Placeholder\" \n"
67             "  input: \"shapeTest\" \n"
68             "  attr { \n"
69             "    key: \"T\" \n"
70             "    value { \n"
71             "      type: DT_FLOAT \n"
72             "    } \n"
73             "  } \n"
74             "  attr { \n"
75             "    key: \"Tshape\" \n"
76             "    value { \n"
77             "      type: DT_INT32 \n"
78             "    } \n"
79             "  } \n"
80             "} \n";
81
82         SetupSingleInputSingleOutput({1, 4}, "Placeholder", "Reshape");
83     }
84 };
85
86 BOOST_FIXTURE_TEST_CASE(ParseShape, ShapeFixture)
87 {
88     // Note: the test's output cannot be an int32 const layer, because ARMNN only supports u8 and float layers.
89     //       For that reason I added a reshape layer which reshapes the input to its original dimensions.
90     RunTest<2>({ 0.0f, 1.0f, 2.0f, 3.0f }, { 0.0f, 1.0f, 2.0f, 3.0f });
91 }
92
93 BOOST_AUTO_TEST_SUITE_END()