Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / helpers / CreateCircleConst.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __LUCI_PASS_HELPERS_CREATE_CIRCLE_CONST_H__
18 #define __LUCI_PASS_HELPERS_CREATE_CIRCLE_CONST_H__
19
20 #include <luci/IR/CircleNodes.h>
21
22 #include "TypeMapper.h"
23
24 #include <vector>
25
26 namespace luci
27 {
28
29 // Create CircleConst filled with a single value
30 // Never return nullptr
31 // TODO Remove dtype from the argument
32 template <typename T>
33 CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
34                                const std::vector<uint32_t> &shape, const T value)
35 {
36   auto node = g->nodes()->create<CircleConst>();
37   node->dtype(dtype);
38   node->rank(shape.size());
39
40   uint32_t size = 1;
41   for (uint32_t i = 0; i < shape.size(); ++i)
42   {
43     node->dim(i) = shape.at(i);
44     size *= shape.at(i);
45   }
46   node->shape_status(ShapeStatus::VALID);
47
48   node->size<TypeMapper<T>::get()>(size);
49   for (uint32_t i = 0; i < size; i++)
50   {
51     node->at<TypeMapper<T>::get()>(i) = value;
52   }
53
54   return node;
55 }
56
57 // Create CircleConst filled with values
58 // Never return nullptr
59 // TODO Remove dtype from the argument
60 template <typename T>
61 luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
62                                      const std::vector<uint32_t> &shape,
63                                      const std::vector<T> &values)
64 {
65   auto node = g->nodes()->create<luci::CircleConst>();
66   node->dtype(dtype);
67   node->rank(shape.size());
68
69   uint32_t size = 1;
70   for (uint32_t i = 0; i < shape.size(); ++i)
71   {
72     node->dim(i) = shape.at(i);
73     size *= shape.at(i);
74   }
75   node->shape_status(luci::ShapeStatus::VALID);
76
77   node->size<TypeMapper<T>::get()>(size);
78   for (uint32_t i = 0; i < size; i++)
79   {
80     node->at<TypeMapper<T>::get()>(i) = values[i];
81   }
82
83   return node;
84 }
85
86 } // namespace luci
87
88 #endif // __LUCI_PASS_HELPERS_CREATE_CIRCLE_CONST_H__