Release 18.08
[platform/upstream/armnn.git] / src / armnnTfParser / test / Identity.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12 struct IdentitySimpleFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
14     IdentitySimpleFixture()
15     {
16         m_Prototext = "node{ "
17             "  name: \"Placeholder\""
18             "  op: \"Placeholder\""
19             "  attr {"
20             "    key: \"dtype\""
21             "    value {"
22             "      type: DT_FLOAT"
23             "    }"
24             "  }"
25             "  attr {"
26             "    key: \"shape\""
27             "    value {"
28             "      shape {"
29             "        unknown_rank: true"
30             "      }"
31             "    }"
32             "  }"
33             "}"
34             "node {"
35             "  name: \"Identity\""
36             "  op: \"Identity\""
37             "  input: \"Placeholder\""
38             "  attr {"
39             "    key: \"T\""
40             "    value {"
41             "      type: DT_FLOAT"
42             "    }"
43             "  }"
44             "}";
45         SetupSingleInputSingleOutput({ 4 }, "Placeholder", "Identity");
46     }
47 };
48
49 BOOST_FIXTURE_TEST_CASE(IdentitySimple, IdentitySimpleFixture)
50 {
51     RunTest<1>({ 1.0f, 2.0f, 3.0f, 4.0f }, { 1.0f, 2.0f, 3.0f, 4.0f });
52 }
53
54 struct IdentityFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
55 {
56     IdentityFixture()
57     {
58         m_Prototext = "node{ "
59             "  name: \"Placeholder\""
60             "  op: \"Placeholder\""
61             "  attr {"
62             "    key: \"dtype\""
63             "    value {"
64             "      type: DT_FLOAT"
65             "    }"
66             "  }"
67             "  attr {"
68             "    key: \"shape\""
69             "    value {"
70             "      shape {"
71             "        unknown_rank: true"
72             "      }"
73             "    }"
74             "  }"
75             "}"
76             "node {"
77             "  name: \"Identity\""
78             "  op: \"Identity\""
79             "  input: \"Placeholder\""
80             "  attr {"
81             "    key: \"T\""
82             "    value {"
83             "      type: DT_FLOAT"
84             "    }"
85             "  }"
86             "}"
87             "node {"
88             "  name: \"Add\""
89             "  op: \"Add\""
90             "  input: \"Identity\""
91             "  input: \"Identity\""
92             "  attr {"
93             "    key: \"T\""
94             "    value {"
95             "      type: DT_FLOAT"
96             "    }"
97             "  }"
98             "}";
99         SetupSingleInputSingleOutput({ 4 }, "Placeholder", "Add");
100     }
101 };
102
103 BOOST_FIXTURE_TEST_CASE(ParseIdentity, IdentityFixture)
104 {
105     RunTest<1>({ 1.0f, 2.0f, 3.0f, 4.0f }, { 2.0f, 4.0f, 6.0f, 8.0f });
106 }
107
108 struct IdentityChainFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
109 {
110     IdentityChainFixture()
111     {
112         m_Prototext = "node{ "
113             "  name: \"Placeholder\""
114             "  op: \"Placeholder\""
115             "  attr {"
116             "    key: \"dtype\""
117             "    value {"
118             "      type: DT_FLOAT"
119             "    }"
120             "  }"
121             "  attr {"
122             "    key: \"shape\""
123             "    value {"
124             "      shape {"
125             "        unknown_rank: true"
126             "      }"
127             "    }"
128             "  }"
129             "}"
130             "node {"
131             "  name: \"Identity\""
132             "  op: \"Identity\""
133             "  input: \"Placeholder\""
134             "  attr {"
135             "    key: \"T\""
136             "    value {"
137             "      type: DT_FLOAT"
138             "    }"
139             "  }"
140             "}"
141             "node {"
142             "  name: \"Identity2\""
143             "  op: \"Identity\""
144             "  input: \"Identity\""
145             "  attr {"
146             "    key: \"T\""
147             "    value {"
148             "      type: DT_FLOAT"
149             "    }"
150             "  }"
151             "}";
152         SetupSingleInputSingleOutput({ 4 }, "Placeholder", "Identity2");
153     }
154 };
155
156 BOOST_FIXTURE_TEST_CASE(IdentityChain, IdentityChainFixture)
157 {
158     RunTest<1>({ 1.0f, 2.0f, 3.0f, 4.0f }, { 1.0f, 2.0f, 3.0f, 4.0f });
159 }
160
161 BOOST_AUTO_TEST_SUITE_END()