7131dc115d82dc91ce532988ddc3e22cf3e691d3
[platform/core/ml/nnfw.git] / compiler / luci / import / src / Nodes / CircleConst.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Import/Nodes/CircleConst.h"
18
19 #include <luci/IR/Nodes/CircleConst.h>
20 #include <luci/Log.h>
21
22 #include <loco.h>
23 #include <oops/UserExn.h>
24
25 #include <cassert>
26
27 namespace
28 {
29
30 std::ostream &operator<<(std::ostream &os, const std::vector<int32_t> &vect)
31 {
32   uint32_t seq = 0;
33   for (auto &v : vect)
34   {
35     if (seq)
36       os << ", ";
37     os << v;
38     seq++;
39   }
40   return os;
41 }
42
43 } // namespace
44
45 namespace luci
46 {
47
48 template <loco::DataType DT>
49 static void copy_data(const std::vector<uint8_t> &raw_data, uint32_t num_elements,
50                       CircleConst *const_node)
51 {
52   using T = typename loco::DataTypeImpl<DT>::Type;
53
54   assert(raw_data.size() == num_elements * sizeof(T));
55   const auto *data = reinterpret_cast<const T *>(raw_data.data());
56
57   const_node->size<DT>(num_elements);
58   for (uint32_t i = 0; i < num_elements; ++i)
59   {
60     const_node->at<DT>(i) = data[i];
61   }
62 }
63
64 //
65 // circleconst_from_tensor() ?
66 //
67 CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_index)
68 {
69   LOGGER(l);
70
71   auto graph = context->graph();
72   auto reader = context->reader();
73   const auto &tensors = reader->tensors();
74   const circle::TensorT &const_tensor = *tensors[tensor_index];
75
76   const std::vector<uint8_t> &buffer = reader->buffers()[const_tensor.buffer]->data;
77   std::vector<int32_t> const_dims = const_tensor.shape; // in NHWC
78   if (const_dims.size() == 0 && buffer.empty())
79   {
80     // unknown shape tensor
81     return nullptr;
82   }
83
84   // if tensor_index is used as output to some other operator, this is not a constant
85   auto tensoroutputs = context->tensoroutputs();
86   if (tensoroutputs->find(tensor_index))
87   {
88     // other operator output tensor
89     return nullptr;
90   }
91
92   uint32_t num_elements = 1;
93   for (uint32_t r = 0; r < const_dims.size(); ++r)
94   {
95     num_elements = num_elements * const_dims[r];
96   }
97
98   if (buffer.empty() && num_elements > 0)
99   {
100     // normal empty tensor
101     return nullptr;
102   }
103
104   auto const_node = graph->nodes()->create<CircleConst>();
105   copy_tensor_attributes(const_tensor, const_node);
106   const_node->shape_status(luci::ShapeStatus::VALID);
107   INFO(l) << "[luci] NodeFinder const_node(" << tensor_index << ") -> " << const_node << " "
108           << const_dims << std::endl;
109   if (num_elements > 0)
110   {
111     switch (luci_datatype(const_tensor.type))
112     {
113       case loco::DataType::FLOAT32:
114         copy_data<loco::DataType::FLOAT32>(buffer, num_elements, const_node);
115         break;
116
117       case loco::DataType::U8:
118         copy_data<loco::DataType::U8>(buffer, num_elements, const_node);
119         break;
120
121       case loco::DataType::S16:
122         copy_data<loco::DataType::S16>(buffer, num_elements, const_node);
123         break;
124
125       case loco::DataType::S32:
126         copy_data<loco::DataType::S32>(buffer, num_elements, const_node);
127         break;
128
129       case loco::DataType::S64:
130         copy_data<loco::DataType::S64>(buffer, num_elements, const_node);
131         break;
132
133       case loco::DataType::BOOL:
134         copy_data<loco::DataType::BOOL>(buffer, num_elements, const_node);
135         break;
136
137       default:
138         throw oops::UserExn("Unsupported tensor type",
139                             circle::EnumNameTensorType(const_tensor.type));
140     }
141   }
142
143   return const_node;
144 }
145
146 } // namespace luci