2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "BackendContext.h"
19 #include "backend/basic/train/TrainableBackendContextHelpers.h"
20 #include "exec/FunctionSequence.h"
31 backend::ITensorRegistry *BackendContext::genTensors()
33 // For now, there is no need to generate tensors for forwarding.
34 // builtin train backend handles 3 operators: `Permute`, `IF`, `WHILE`.
35 // `Permute`: Tensor generation is not required.
36 // `IF`, `WHILE`: Not supported yet
37 return tensor_registry().get();
40 backend::train::ITensorRegistry *BackendContext::genTrainingTensors()
42 // For now, there is no need to generate tensors for backwarding.
43 return tensor_registry().get();
46 backend::train::FunctionMap BackendContext::genKernels()
48 backend::train::FunctionMap ret;
50 for (auto &&op_ind : _tdata->op_order)
52 auto tn_seq = kernel_gen->generate(op_ind);
53 ret.emplace_back(op_ind, std::move(tn_seq));
56 trainable_graph()->operands().iterate(
57 [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
58 if (!external_operands().contains(ind) && operand.isConstant())
60 throw std::runtime_error(
61 "BackendContext: builtin backend does not support updatable weights yet");
65 // TODO Enable prepare()
66 // for (auto &&it : ret)
68 // auto &fn_seq = it.second;
69 // fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
76 } // namespace builtin
77 } // namespace backend