/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "luci/Service/CircleNodeClone.h"
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/Service/CircleShapeInference.h>
21 #include <luci/Service/CircleTypeInference.h>
23 #include <loco/IR/TensorShape.h>
25 #include <gtest/gtest.h>
27 TEST(ShapeRuleTest, simple_hardswish)
29 luci::CircleInput input;
30 luci::CircleHardSwish hard_swish;
33 input.shape_status(luci::ShapeStatus::VALID);
35 hard_swish.features(&input);
37 loco::TensorShape shape;
38 luci::sinf::Rule shape_inf_rule;
40 ASSERT_TRUE(shape_inf_rule.infer(&hard_swish, shape));
41 ASSERT_EQ(2, shape.rank());
42 ASSERT_EQ(3, shape.dim(0).value());
43 ASSERT_EQ(4, shape.dim(1).value());
46 TEST(DataTypeRuleTest, simple_hardswish)
48 luci::CircleInput input;
49 luci::CircleHardSwish hard_swish;
51 input.dtype(loco::DataType::S32);
53 hard_swish.features(&input);
56 luci::tinf::Rule type_inf_rule;
58 ASSERT_TRUE(type_inf_rule.infer(&hard_swish, dtype));
59 ASSERT_EQ(loco::DataType::S32, dtype);
62 TEST(CloneNodeTest, clone_HardSwish)
64 auto g = loco::make_graph();
65 auto node_hardswish = g->nodes()->create<luci::CircleHardSwish>();
67 auto gc = loco::make_graph();
68 auto cloned = luci::clone_node(node_hardswish, gc.get());
69 ASSERT_NE(nullptr, cloned);
70 ASSERT_EQ(gc.get(), cloned->graph());
72 auto cloned_hardswish = dynamic_cast<luci::CircleHardSwish *>(cloned);
73 ASSERT_NE(nullptr, cloned_hardswish);