Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / service / src / Nodes / CircleHardSwish.test.cpp
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Service/CircleNodeClone.h"
18
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/Service/CircleShapeInference.h>
21 #include <luci/Service/CircleTypeInference.h>
22
23 #include <loco/IR/TensorShape.h>
24
25 #include <gtest/gtest.h>
26
27 TEST(ShapeRuleTest, simple_hardswish)
28 {
29   luci::CircleInput input;
30   luci::CircleHardSwish hard_swish;
31
32   input.shape({3, 4});
33   input.shape_status(luci::ShapeStatus::VALID);
34
35   hard_swish.features(&input);
36
37   loco::TensorShape shape;
38   luci::sinf::Rule shape_inf_rule;
39
40   ASSERT_TRUE(shape_inf_rule.infer(&hard_swish, shape));
41   ASSERT_EQ(2, shape.rank());
42   ASSERT_EQ(3, shape.dim(0).value());
43   ASSERT_EQ(4, shape.dim(1).value());
44 }
45
46 TEST(DataTypeRuleTest, simple_hardswish)
47 {
48   luci::CircleInput input;
49   luci::CircleHardSwish hard_swish;
50
51   input.dtype(loco::DataType::S32);
52
53   hard_swish.features(&input);
54
55   loco::DataType dtype;
56   luci::tinf::Rule type_inf_rule;
57
58   ASSERT_TRUE(type_inf_rule.infer(&hard_swish, dtype));
59   ASSERT_EQ(loco::DataType::S32, dtype);
60 }
61
62 TEST(CloneNodeTest, clone_HardSwish)
63 {
64   auto g = loco::make_graph();
65   auto node_hardswish = g->nodes()->create<luci::CircleHardSwish>();
66
67   auto gc = loco::make_graph();
68   auto cloned = luci::clone_node(node_hardswish, gc.get());
69   ASSERT_NE(nullptr, cloned);
70   ASSERT_EQ(gc.get(), cloned->graph());
71
72   auto cloned_hardswish = dynamic_cast<luci::CircleHardSwish *>(cloned);
73   ASSERT_NE(nullptr, cloned_hardswish);
74 }