6f6eb77bc4cb00bd244c89a1ccdc545825d4840d
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / ConstantInitializer.cc
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ConstantInitializer.h"
18 #include "Tensor.h"
19
20 namespace onert
21 {
22 namespace backend
23 {
24 namespace cpu
25 {
26
27 ConstantInitializer::ConstantInitializer(const ir::Operands &operands,
28                                          const std::shared_ptr<ITensorRegistry> &tensor_reg)
29     : IConstantInitializer{operands}, _tensor_reg{tensor_reg}
30 {
31   // DO NOTHING
32 }
33
34 void ConstantInitializer::registerDefaultInitializer(const ir::OperandIndex &index,
35                                                      const ir::Operand &obj)
36 {
37   registerExternalInitializer(index, obj);
38 }
39
40 void ConstantInitializer::registerExternalInitializer(const ir::OperandIndex &index,
41                                                       const ir::Operand &obj)
42 {
43   // For only CONSTANTS
44   // TODO Add to check if tensor has been allocated
45   if (!obj.isConstant())
46     return;
47
48   _init_map[index] = [](const onert::ir::Operand &model_obj, onert::backend::ITensor &itensor) {
49     auto data = model_obj.shareData();
50     assert(data && data->base());
51     ExternalTensor &tensor = dynamic_cast<ExternalTensor &>(itensor);
52     tensor.setData(data);
53   };
54 }
55
56 void ConstantInitializer::visit(const ir::operation::Conv2D &node)
57 {
58   const auto &kernel_index = node.getInputs().at(ir::operation::Conv2D::KERNEL);
59   const auto &kernel_obj = _operands.at(kernel_index);
60   registerExternalInitializer(kernel_index, kernel_obj);
61
62   const auto &bias_index = node.getInputs().at(ir::operation::Conv2D::BIAS);
63   const auto &bias_obj = _operands.at(bias_index);
64   registerExternalInitializer(bias_index, bias_obj);
65 }
66
67 void ConstantInitializer::visit(const ir::operation::DepthwiseConv2D &node)
68 {
69   const auto &kernel_index = node.getInputs().at(ir::operation::DepthwiseConv2D::KERNEL);
70   const auto &kernel_obj = _operands.at(kernel_index);
71   registerExternalInitializer(kernel_index, kernel_obj);
72
73   const auto &bias_index = node.getInputs().at(ir::operation::DepthwiseConv2D::BIAS);
74   const auto &bias_obj = _operands.at(bias_index);
75   registerExternalInitializer(bias_index, bias_obj);
76 }
77
78 void ConstantInitializer::visit(const ir::operation::FullyConnected &node)
79 {
80   const auto &weight_index = node.getInputs().at(ir::operation::FullyConnected::WEIGHT);
81   const auto &weight_obj = _operands.at(weight_index);
82   registerExternalInitializer(weight_index, weight_obj);
83
84   const auto &bias_index = node.getInputs().at(ir::operation::FullyConnected::BIAS);
85   if (!bias_index.undefined())
86   {
87     const auto &bias_obj = _operands.at(bias_index);
88     registerExternalInitializer(bias_index, bias_obj);
89   }
90 }
91
92 } // namespace cpu
93 } // namespace backend
94 } // namespace onert