Release 18.03
[platform/upstream/armnn.git] / src / armnnTfParser / test / TestMultiInputsOutputs.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12 struct MultiInputsOutputsFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
14     MultiInputsOutputsFixture()
15     {
16         // input1 = tf.placeholder(tf.float32, shape=[], name = "input1")
17         // input2 = tf.placeholder(tf.float32, shape = [], name = "input2")
18         // add1 = tf.add(input1, input2, name = "add1")
19         // add2 = tf.add(input1, input2, name = "add2")
20         m_Prototext = R"(
21 node {
22   name: "input1"
23   op: "Placeholder"
24   attr {
25     key: "dtype"
26     value {
27       type: DT_FLOAT
28     }
29   }
30   attr {
31     key: "shape"
32     value {
33       shape {
34       }
35     }
36   }
37 }
38 node {
39   name: "input2"
40   op: "Placeholder"
41   attr {
42     key: "dtype"
43     value {
44       type: DT_FLOAT
45     }
46   }
47   attr {
48     key: "shape"
49     value {
50       shape {
51       }
52     }
53   }
54 }
55 node {
56   name: "add1"
57   op: "Add"
58   input: "input1"
59   input: "input2"
60   attr {
61     key: "T"
62     value {
63       type: DT_FLOAT
64     }
65   }
66 }
67 node {
68   name: "add2"
69   op: "Add"
70   input: "input1"
71   input: "input2"
72   attr {
73     key: "T"
74     value {
75       type: DT_FLOAT
76     }
77   }
78 }
79         )";
80         Setup({ { "input1", { 1 } },
81                 { "input2", { 1 } } },
82               { "add1", "add2" });
83     }
84 };
85
86 BOOST_FIXTURE_TEST_CASE(MultiInputsOutputs, MultiInputsOutputsFixture)
87 {
88     RunTest<1>({ { "input1", {12.0f} }, { "input2", { 13.0f } } },
89                { { "add1", { 25.0f } }, { "add2", { 25.0f } } });
90 }
91
92 BOOST_AUTO_TEST_SUITE_END()