/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
19 #include <luci/IR/CircleNodes.h>
21 #include <gtest/gtest.h>
23 void create_fc_net(loco::Graph *g)
27 const uint32_t ROW = 16;
28 const uint32_t COL = 2;
29 const uint32_t elements_num = ROW * COL;
32 auto input = g->nodes()->create<luci::CircleInput>();
33 auto graph_input = g->inputs()->create();
34 input->index(graph_input->index());
37 auto weights = g->nodes()->create<luci::CircleConst>();
38 weights->dtype(loco::DataType::FLOAT32);
39 weights->size<loco::DataType::FLOAT32>(elements_num);
41 weights->dim(0).set(ROW);
42 weights->dim(1).set(COL);
43 for (uint32_t idx = 0; idx < elements_num; idx++)
45 weights->at<loco::DataType::FLOAT32>(idx) = idx;
49 auto fc = g->nodes()->create<luci::CircleFullyConnected>();
50 fc->dtype(loco::DataType::FLOAT32);
55 auto output = g->nodes()->create<luci::CircleOutput>();
57 auto graph_output = g->outputs()->create();
58 output->index(graph_output->index());
61 TEST(ShuffleWeightTo16x1Float32PassTest, SimpleTest1)
63 auto graph = loco::make_graph();
64 create_fc_net(graph.get());
66 luci::CircleFullyConnected *fc_node = nullptr;
67 for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
69 auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
76 ASSERT_NE(fc_node, nullptr);
77 auto weights = loco::must_cast<luci::CircleConst *>(fc_node->weights());
79 ASSERT_EQ(0, weights->at<loco::DataType::FLOAT32>(0));
80 ASSERT_EQ(1, weights->at<loco::DataType::FLOAT32>(1));
81 ASSERT_EQ(2, weights->at<loco::DataType::FLOAT32>(2));
82 ASSERT_EQ(3, weights->at<loco::DataType::FLOAT32>(3));
83 ASSERT_EQ(4, weights->at<loco::DataType::FLOAT32>(4));
84 ASSERT_EQ(5, weights->at<loco::DataType::FLOAT32>(5));
85 ASSERT_EQ(6, weights->at<loco::DataType::FLOAT32>(6));
86 ASSERT_EQ(7, weights->at<loco::DataType::FLOAT32>(7));
87 ASSERT_EQ(8, weights->at<loco::DataType::FLOAT32>(8));
88 ASSERT_EQ(9, weights->at<loco::DataType::FLOAT32>(9));
89 ASSERT_EQ(10, weights->at<loco::DataType::FLOAT32>(10));
90 ASSERT_EQ(11, weights->at<loco::DataType::FLOAT32>(11));
91 ASSERT_EQ(12, weights->at<loco::DataType::FLOAT32>(12));
92 ASSERT_EQ(13, weights->at<loco::DataType::FLOAT32>(13));
93 ASSERT_EQ(14, weights->at<loco::DataType::FLOAT32>(14));
94 ASSERT_EQ(15, weights->at<loco::DataType::FLOAT32>(15));
96 luci::ShuffleWeightTo16x1Float32Pass pass;
97 while (pass.run(graph.get()))
100 weights = loco::must_cast<luci::CircleConst *>(fc_node->weights());
102 ASSERT_EQ(0, weights->at<loco::DataType::FLOAT32>(0));
103 ASSERT_EQ(2, weights->at<loco::DataType::FLOAT32>(1));
104 ASSERT_EQ(4, weights->at<loco::DataType::FLOAT32>(2));
105 ASSERT_EQ(6, weights->at<loco::DataType::FLOAT32>(3));
106 ASSERT_EQ(8, weights->at<loco::DataType::FLOAT32>(4));
107 ASSERT_EQ(10, weights->at<loco::DataType::FLOAT32>(5));
108 ASSERT_EQ(12, weights->at<loco::DataType::FLOAT32>(6));
109 ASSERT_EQ(14, weights->at<loco::DataType::FLOAT32>(7));
110 ASSERT_EQ(16, weights->at<loco::DataType::FLOAT32>(8));
111 ASSERT_EQ(18, weights->at<loco::DataType::FLOAT32>(9));
112 ASSERT_EQ(20, weights->at<loco::DataType::FLOAT32>(10));
113 ASSERT_EQ(22, weights->at<loco::DataType::FLOAT32>(11));
114 ASSERT_EQ(24, weights->at<loco::DataType::FLOAT32>(12));
115 ASSERT_EQ(26, weights->at<loco::DataType::FLOAT32>(13));
116 ASSERT_EQ(28, weights->at<loco::DataType::FLOAT32>(14));
117 ASSERT_EQ(30, weights->at<loco::DataType::FLOAT32>(15));