2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "TrainableExecutors.h"
19 #include "../../backend/builtin/IOTensor.h"
21 #include <misc/polymorphic_downcast.h>
30 void TrainableExecutors::emplace(const ir::ModelIndex &, const ir::SubgraphIndex &subg_index,
31 std::unique_ptr<IExecutor> exec)
33 std::unique_ptr<TrainableExecutor> t_exec{
34 nnfw::misc::polymorphic_downcast<TrainableExecutor *>(exec.release())};
35 _executors.emplace(subg_index, std::move(t_exec));
38 TrainableExecutor *TrainableExecutors::at(const ir::ModelIndex &,
39 const ir::SubgraphIndex &subg_index) const
41 return _executors.at(subg_index).get();
44 uint32_t TrainableExecutors::inputSize() const { return entryExecutor()->getInputTensors().size(); }
46 uint32_t TrainableExecutors::outputSize() const
48 return entryExecutor()->getOutputTensors().size();
51 const ir::OperandInfo &TrainableExecutors::inputInfo(const ir::IOIndex &index) const
53 return entryExecutor()->getInputTensors().at(index.value())->orig_info();
56 const ir::OperandInfo &TrainableExecutors::outputInfo(const ir::IOIndex &index) const
58 return entryExecutor()->getOutputTensors().at(index.value())->orig_info();
61 void TrainableExecutors::execute(const IODescription &desc)
63 if (_executors.size() > 1)
64 throw std::runtime_error("TrainableExecutors does not support multiple executors yet");
65 entryExecutor()->forward(desc, false);
67 // TODO Support multple executors
70 void TrainableExecutors::train(const IODescription &desc, uint32_t training_step)
72 if (_executors.size() > 1)
73 throw std::runtime_error("TrainableExecutors does not support multiple executors yet");
74 entryExecutor()->forward(desc, true);
75 entryExecutor()->backward(desc, training_step);
77 // TODO Support multple executors
80 float TrainableExecutors::getLoss(const ir::IOIndex &index) const
82 if (_executors.size() > 1)
83 throw std::runtime_error("TrainableExecutors does not support multiple executors yet");
84 return entryExecutor()->getLoss(index);