Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / moco / import / src / Nodes / Concat.test.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "moco/Import/Nodes/Concat.h"
18 #include "TestHelper.h"
19
20 #include <gtest/gtest.h>
21
22 using namespace moco::test;
23
24 namespace
25 {
26
27 // clang-format off
28 const char *concat_01_pbtxtdata = STRING_CONTENT(
29   name: "Concat"
30   op: "ConcatV2"
31   input: "Input01"
32   input: "Input02"
33   input: "Axis"
34   attr {
35     key: "N"
36     value {
37       i: 2
38     }
39   }
40   attr {
41     key: "T"
42     value {
43       type: DT_FLOAT
44     }
45   }
46   attr {
47     key: "Tidx"
48     value {
49       type: DT_INT32
50     }
51   }
52 );
53 // clang-format on
54
55 } // namespace
56
57 TEST(TensorFlowImport, concat_01)
58 {
59   TFNodeBuildTester tester;
60   moco::ConcatV2GraphBuilder graphbuilder;
61   tensorflow::NodeDef nodedef;
62
63   EXPECT_TRUE(plier::tf::parse_nodedef(concat_01_pbtxtdata, nodedef));
64
65   // what to test:
66   // - there should exist TFConcatV2
67   // - there should be two values
68   // - values(idx) should not be nullptr
69   // - axis() should not be nullptr
70
71   tester.inputs({"Input01", "Input02", "Axis"});
72   tester.output("Concat");
73   tester.run(nodedef, graphbuilder);
74
75   auto test_node = loco::must_cast<moco::TFConcatV2 *>(tester.output());
76   ASSERT_NE(test_node, nullptr);
77   ASSERT_EQ(test_node->num_values(), 2);
78 }
79
80 namespace
81 {
82
83 // clang-format off
84 const char *concat_02_pbtxtdata = STRING_CONTENT(
85   name: "Concat"
86   op: "ConcatV2"
87   input: "Input01"
88   input: "Input02"
89   input: "Input03"
90   input: "Axis"
91   attr {
92     key: "N"
93     value {
94       i: 3
95     }
96   }
97   attr {
98     key: "T"
99     value {
100       type: DT_FLOAT
101     }
102   }
103   attr {
104     key: "Tidx"
105     value {
106       type: DT_INT32
107     }
108   }
109 );
110 // clang-format on
111
112 } // namespace
113
114 TEST(TensorFlowImport, concat_02)
115 {
116   TFNodeBuildTester tester;
117   moco::ConcatV2GraphBuilder graphbuilder;
118   tensorflow::NodeDef nodedef;
119
120   EXPECT_TRUE(plier::tf::parse_nodedef(concat_02_pbtxtdata, nodedef));
121
122   // what to test: TFConcatV2 has 3 inputs
123   // - there should exist TFConcatV2
124   // - values(idx) should not be nullptr
125   // - axis() should not be nullptr
126
127   tester.inputs({"Input01", "Input02", "Input03", "Axis"});
128   tester.output("Concat");
129   tester.run(nodedef, graphbuilder);
130
131   auto test_node = loco::must_cast<moco::TFConcatV2 *>(tester.output());
132   ASSERT_NE(test_node, nullptr);
133   ASSERT_EQ(test_node->num_values(), 3);
134 }