Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / moco / import / src / Nodes / Mean.test.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "moco/Import/Nodes/Mean.h"
18 #include "TestHelper.h"
19
20 #include <gtest/gtest.h>
21
22 using namespace moco::test;
23
24 namespace
25 {
26
27 // clang-format off
28 const char *mean_true_pbtxtdata = STRING_CONTENT(
29   name: "Mean"
30   op: "Mean"
31   input: "Placeholder"
32   input: "Const"
33   attr {
34     key: "T"
35     value { type: DT_FLOAT }
36   }
37   attr {
38     key: "Tidx"
39     value { type: DT_INT32 }
40   }
41   attr {
42     key: "keep_dims"
43     value { b: true }
44   }
45 );
46 // clang-format on
47
48 } // namespace
49
50 TEST(TensorFlowImport, mean_true)
51 {
52   TFNodeBuildTester tester;
53   moco::MeanGraphBuilder graphbuilder;
54   tensorflow::NodeDef nodedef;
55
56   EXPECT_TRUE(plier::tf::parse_nodedef(mean_true_pbtxtdata, nodedef));
57
58   // what to test:
59   // - there should exist TFMean
60   // - input node should not be nullptr
61   // - reduction_indeces node should not be nullptr
62   // - keep_dims attribute is set same as pbtxt
63
64   tester.inputs({"Placeholder", "Const"});
65   tester.output("Mean");
66   tester.run(nodedef, graphbuilder);
67
68   auto test_node = loco::must_cast<moco::TFMean *>(tester.output());
69   ASSERT_NE(test_node, nullptr);
70   ASSERT_EQ(test_node->keep_dims(), true);
71 }
72
73 namespace
74 {
75
76 // clang-format off
77 const char *mean_false_pbtxtdata = STRING_CONTENT(
78   name: "Mean"
79   op: "Mean"
80   input: "Placeholder"
81   input: "Const"
82   attr {
83     key: "T"
84     value { type: DT_FLOAT }
85   }
86   attr {
87     key: "Tidx"
88     value { type: DT_INT32 }
89   }
90   attr {
91     key: "keep_dims"
92     value { b: false }
93   }
94 );
95 // clang-format on
96
97 } // namespace
98
99 TEST(TensorFlowImport, mean_false)
100 {
101   TFNodeBuildTester tester;
102   moco::MeanGraphBuilder graphbuilder;
103   tensorflow::NodeDef nodedef;
104
105   EXPECT_TRUE(plier::tf::parse_nodedef(mean_false_pbtxtdata, nodedef));
106
107   // what to test:
108   // - there should exist TFMean
109   // - input node should not be nullptr
110   // - reduction_indeces node should not be nullptr
111   // - keep_dims attribute is set same as pbtxt
112
113   tester.inputs({"Placeholder", "Const"});
114   tester.output("Mean");
115   tester.run(nodedef, graphbuilder);
116
117   auto test_node = loco::must_cast<moco::TFMean *>(tester.output());
118   ASSERT_NE(test_node, nullptr);
119   ASSERT_EQ(test_node->keep_dims(), false);
120 }