2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "moco/Import/Nodes/Mean.h"
18 #include "TestHelper.h"
20 #include <gtest/gtest.h>
22 using namespace moco::test;
28 const char *mean_true_pbtxtdata = STRING_CONTENT(
35 value { type: DT_FLOAT }
39 value { type: DT_INT32 }
50 TEST(TensorFlowImport, mean_true)
52 TFNodeBuildTester tester;
53 moco::MeanGraphBuilder graphbuilder;
54 tensorflow::NodeDef nodedef;
56 EXPECT_TRUE(plier::tf::parse_nodedef(mean_true_pbtxtdata, nodedef));
  // - the output node should be a TFMean
  // - its input node should not be nullptr
  // - its reduction_indices node should not be nullptr
  // - its keep_dims attribute should match the value given in the pbtxt
64 tester.inputs({"Placeholder", "Const"});
65 tester.output("Mean");
66 tester.run(nodedef, graphbuilder);
68 auto test_node = loco::must_cast<moco::TFMean *>(tester.output());
69 ASSERT_NE(test_node, nullptr);
70 ASSERT_EQ(test_node->keep_dims(), true);
77 const char *mean_false_pbtxtdata = STRING_CONTENT(
84 value { type: DT_FLOAT }
88 value { type: DT_INT32 }
99 TEST(TensorFlowImport, mean_false)
101 TFNodeBuildTester tester;
102 moco::MeanGraphBuilder graphbuilder;
103 tensorflow::NodeDef nodedef;
105 EXPECT_TRUE(plier::tf::parse_nodedef(mean_false_pbtxtdata, nodedef));
  // - the output node should be a TFMean
  // - its input node should not be nullptr
  // - its reduction_indices node should not be nullptr
  // - its keep_dims attribute should match the value given in the pbtxt
113 tester.inputs({"Placeholder", "Const"});
114 tester.output("Mean");
115 tester.run(nodedef, graphbuilder);
117 auto test_node = loco::must_cast<moco::TFMean *>(tester.output());
118 ASSERT_NE(test_node, nullptr);
119 ASSERT_EQ(test_node->keep_dims(), false);