2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "KernelGenerator.h"
19 #include "ops/ConvolutionLayer.h"
20 #include "ops/ElementwiseActivationLayer.h"
21 #include "ops/FullyConnectedLayer.h"
22 #include "ops/LossLayer.h"
23 #include "ops/GradientApplier.h"
24 #include "ops/PoolLayer.h"
25 #include "ops/ReshapeLayer.h"
27 #include <backend/Backend.h>
28 #include <backend/IConfig.h>
30 #include <util/Utils.h>
31 #include <util/logging.h>
32 #include <exec/DynamicShapeInferer.h>
// Translates an IR elementwise-activation type into the train backend's
// ops::ElementwiseActivationType.
//
// Visible mapping: Type::RELU -> kReLU.
// Throws std::runtime_error on the unsupported path.
// NOTE(review): the selection construct wrapping the case label is not visible in this
// excerpt; presumably a switch on type_ir with the throw as its default branch - confirm.
45 ops::ElementwiseActivationType
46 convertElementwiseActivationType(ir::operation::ElementwiseActivation::Type type_ir)
50 case ir::operation::ElementwiseActivation::Type::RELU:
51 return ops::ElementwiseActivationType::kReLU;
// Unsupported activation types are reported to the caller as an exception.
53 throw std::runtime_error("train KernelGenerator : Not supported operation yet");
// Translates an IR loss type into the train backend's ops::LossType.
//
// Visible mapping: Type::MEAN_SQUARED_ERROR -> kMSE.
// Throws std::runtime_error on the unsupported path.
// NOTE(review): the selection construct wrapping the case label is not visible in this
// excerpt; presumably a switch on type_ir with the throw as its default branch - confirm.
57 ops::LossType convertLossType(ir::operation::Loss::Type type_ir)
61 case ir::operation::Loss::Type::MEAN_SQUARED_ERROR:
62 return ops::LossType::kMSE;
// Unsupported loss types are reported to the caller as an exception.
64 throw std::runtime_error("train KernelGenerator : Not supported operation yet");
// Translates an IR Pool2D pool type into the train backend's ops::PoolType.
//
// Visible mapping: PoolType::MAX -> kMax (AVG is still a TODO, see below).
// Throws std::runtime_error on the unsupported path.
// NOTE(review): the selection construct wrapping the case label is not visible in this
// excerpt; presumably a switch on type_ir with the throw as its default branch - confirm.
68 ops::PoolType convertPoolType(ir::operation::Pool2D::PoolType type_ir)
72 // TODO Implement AVG PoolType
73 case ir::operation::Pool2D::PoolType::MAX:
74 return ops::PoolType::kMax;
// Unsupported pool types are reported to the caller as an exception.
76 throw std::runtime_error("train KernelGenerator : Not supported operation yet");
// Allocates an ops::GradientApplier and configures it with the given optimizer, the
// gradient tensor to read from and the trainable tensor to update.
//
// The declared return type is std::unique_ptr<ops::GradientApplier>.
// NOTE(review): the return statement is not visible in this excerpt; presumably the
// configured update_fn is what gets returned - confirm.
// NOTE(review): optimizer is a const shared_ptr taken by value, so each call copies the
// shared_ptr before handing it to configure().
80 std::unique_ptr<ops::GradientApplier>
81 generateGradientApplier(const std::shared_ptr<exec::train::optimizer::Optimizer> optimizer,
82 const IPortableTensor *gradient, ITrainableTensor *trainable)
84 auto update_fn = std::make_unique<ops::GradientApplier>();
85 update_fn->configure(optimizer, gradient, trainable);
// Assembles the function sequence for the operation identified by idx.
//
// Visible steps: create an empty TrainableFnSequence, look the operation up in _tgraph,
// append the kernel held in _return_fn, append every pending gradient-update function,
// then walk the operation's operands.
// NOTE(review): several statements between the visible lines are missing from this
// excerpt (e.g. whatever populates _return_fn before it is moved, any guards around the
// per-operand work, and the final return) - confirm against the full file.
90 std::unique_ptr<exec::train::TrainableFnSequence> KernelGenerator::generate(ir::OperationIndex idx)
92 auto ret = std::make_unique<exec::train::TrainableFnSequence>();
94 const auto &op = _tgraph.operation(idx);
// Ownership of the generated kernel moves into the sequence.
97 ret->append(std::move(_return_fn));
// Move the update functions queued in _update_funcs into the sequence, after the kernel,
// then empty the queue so the moved-from entries are not seen again.
99 for (auto &&update_fn : _update_funcs)
100 ret->append(std::move(update_fn));
101 _update_funcs.clear();
// Iterate over the operation's inputs (UNDEFINED entries removed) followed by its outputs.
103 for (auto &&ind : (op.getInputs() | ir::Remove::UNDEFINED) + op.getOutputs())
105 auto portable_tensor = _tensor_reg->getPortableTensor(ind);
// Debug-only check: the portable tensor is expected to use the NHWC layout.
108 assert(portable_tensor->layout() == ir::Layout::NHWC);
110 auto tensor = _tensor_reg->getNonConstTensor(ind);
// Bump the tensor's reference count for this use.
113 tensor->increase_ref();
// Constructs the kernel generator for a trainable graph.
//
// Initializes the KernelGeneratorBase with tgraph, caches tgraph.layout() in
// _current_layout, stores the tensor registry, external context and optimizer, and starts
// with an empty _update_funcs list.
// NOTE(review): optimizer is received by value and then copied into _optimizer.
// NOTE(review): the constructor body is not visible in this excerpt.
119 KernelGenerator::KernelGenerator(const ir::train::TrainableGraph &tgraph,
120 const std::shared_ptr<TensorRegistry> &tensor_reg,
121 const std::shared_ptr<ExternalContext> &external_context,
122 std::shared_ptr<exec::train::optimizer::Optimizer> optimizer)
123 : backend::train::KernelGeneratorBase{tgraph}, _current_layout{tgraph.layout()},
124 _tensor_reg{tensor_reg},
125 _external_context(external_context), _optimizer{optimizer}, _update_funcs{}
130 void KernelGenerator::visit(const ir::train::operation::Conv2D &node)
132 // TODO Generate kernel
134 // Generate GradientApplier
135 const auto ker_index{node.getInputs().at(ir::train::operation::Conv2D::Input::KERNEL)};
137 auto grad_tensor = _tensor_reg->getGradientTensor(ker_index);
138 auto ker_tensor = _tensor_reg->getTrainableTensor(ker_index);
140 auto update_fn = std::make_unique<ops::GradientApplier>();
142 update_fn->configure(_optimizer, grad_tensor, ker_tensor);
144 _update_funcs.emplace_back(generateGradientApplier(_optimizer, grad_tensor, ker_tensor));
147 void KernelGenerator::visit(const ir::train::operation::ElementwiseActivation &node)
149 using ir::train::operation::ElementwiseActivation;
151 const auto output_index{node.getOutputs().at(0)};
152 const auto input_index{node.getInputs().at(ElementwiseActivation::Input::INPUT)};
154 auto output_tensor = _tensor_reg->getPortableTensor(output_index);
155 auto input_tensor = _tensor_reg->getPortableTensor(input_index);
157 auto deriv_input_tensor = _tensor_reg->getDerivativeTensor(input_index);
158 auto deriv_output_tensor = _tensor_reg->getDerivativeTensor(output_index);
160 auto fn = std::make_unique<ops::ElementwiseActivationLayer>();
162 fn->configure(input_tensor, output_tensor, deriv_input_tensor, deriv_output_tensor,
163 node.param().alpha, node.param().beta,
164 convertElementwiseActivationType(node.param().op_type));
166 _return_fn = std::move(fn);
// Builds the training kernel for a fully-connected node and queues gradient appliers for
// its bias and weights.
//
// NOTE(review): the configure(...) call below is cut off in this excerpt after
// `weights_format,`; the remaining argument(s) and closing of the call are not visible.
169 void KernelGenerator::visit(const ir::train::operation::FullyConnected &node)
171 using ir::train::operation::FullyConnected;
// Operand indices: single output, then INPUT / WEIGHT / BIAS inputs.
173 const auto out_index{node.getOutputs().at(0)};
174 const auto in_index{node.getInputs().at(FullyConnected::Input::INPUT)};
175 const auto weights_index{node.getInputs().at(FullyConnected::Input::WEIGHT)};
176 const auto bias_index{node.getInputs().at(FullyConnected::Input::BIAS)};
// Forward tensors: weights and bias are fetched as trainable tensors.
178 auto out_tensor = _tensor_reg->getPortableTensor(out_index);
179 auto in_tensor = _tensor_reg->getPortableTensor(in_index);
180 auto weights_tensor = _tensor_reg->getTrainableTensor(weights_index);
181 auto bias_tensor = _tensor_reg->getTrainableTensor(bias_index);
// Backward tensors: derivatives for output/input, gradients for weights/bias.
183 auto out_deriv_tensor = _tensor_reg->getDerivativeTensor(out_index);
184 auto in_deriv_tensor = _tensor_reg->getDerivativeTensor(in_index);
185 auto weights_grad_tensor = _tensor_reg->getGradientTensor(weights_index);
186 auto bias_grad_tensor = _tensor_reg->getGradientTensor(bias_index);
// Node parameters forwarded to the layer.
189 const auto activation = node.param().activation;
190 const auto weights_format = node.param().weights_format;
192 auto fn = std::make_unique<ops::FullyConnectedLayer>();
194 fn->configure(in_tensor, weights_tensor, bias_tensor, out_tensor, in_deriv_tensor,
195 weights_grad_tensor, bias_grad_tensor, out_deriv_tensor, activation, weights_format,
// Publish the configured layer as the kernel for this node.
198 _return_fn = std::move(fn);
200 // Generate GradientAppliers
// The bias applier is queued before the weights applier.
202 _update_funcs.emplace_back(generateGradientApplier(_optimizer, bias_grad_tensor, bias_tensor));
203 _update_funcs.emplace_back(
204 generateGradientApplier(_optimizer, weights_grad_tensor, weights_tensor));
207 void KernelGenerator::visit(const ir::train::operation::Loss &node)
209 using ir::train::operation::Loss;
211 const auto output_index{node.getOutputs().at(0)};
212 const auto y_pred_index{node.getInputs().at(Loss::Y_PRED)};
213 const auto y_true_index{node.getInputs().at(Loss::Y_TRUE)};
215 auto output_tensor = _tensor_reg->getPortableTensor(output_index);
216 auto y_pred_tensor = _tensor_reg->getPortableTensor(y_pred_index);
217 auto y_true_tensor = _tensor_reg->getPortableTensor(y_true_index);
219 auto deriv_y_pred_tensor = _tensor_reg->getDerivativeTensor(y_pred_index);
220 auto fn = std::make_unique<ops::LossLayer>();
222 fn->configure(y_pred_tensor, y_true_tensor, output_tensor, deriv_y_pred_tensor,
223 convertLossType(node.param().op_type));
225 _return_fn = std::move(fn);
227 UNUSED_RELEASE(convertPoolType);
230 void KernelGenerator::visit(const ir::train::operation::Reshape &node)
232 using ir::train::operation::Reshape;
234 const auto output_index{node.getOutputs().at(0)};
235 const auto input_index{node.getInputs().at(ir::operation::Reshape::Input::INPUT)};
237 auto output_tensor = _tensor_reg->getPortableTensor(output_index);
238 auto input_tensor = _tensor_reg->getPortableTensor(input_index);
240 auto output_deriv_tensor = _tensor_reg->getDerivativeTensor(output_index);
241 auto input_deriv_tensor = _tensor_reg->getDerivativeTensor(input_index);
243 // optional 2nd input
244 IPortableTensor *shape_tensor = nullptr;
246 if (node.getInputs().size() == 2)
248 const auto shape_index{node.getInputs().at(ir::operation::Reshape::Input::SHAPE)};
249 shape_tensor = _tensor_reg->getPortableTensor(shape_index);
252 auto fn = std::make_unique<ops::ReshapeLayer>();
254 fn->configure(input_tensor, shape_tensor, output_tensor, input_deriv_tensor, output_deriv_tensor);
255 _return_fn = std::move(fn);
259 } // namespace backend