2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTOR_H_
18 #define __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTOR_H_
20 #include "exec/IExecutor.h"
22 #include "../ExecutionObservee.h"
23 #include "../../compiler/train/TensorRegistries.h"
25 #include "backend/train/TrainableBackendContext.h"
26 #include "compiler/train/TrainableCodeMap.h"
27 #include "compiler/train/LoweredTrainableGraph.h"
29 #include "util/TracingCtx.h"
38 class TrainableExecutor : public IExecutor
42 * @brief Construct a new TrainableExecutor object
43 * @param lowered_graph LoweredTrainableGraph object
44 * @param tensor_builders Tensor builders that are currently used
45 * @param code_map @c ir::Operation and its code map
47 TrainableExecutor(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
48 backend::train::TrainableBackendContexts &&backend_contexts,
49 const compiler::train::TensorRegistries &tensor_regs,
50 compiler::train::TrainableCodeMap &&code_map,
51 const std::vector<ir::OperationIndex> &order,
52 const util::TracingCtx *tracing_ctx);
55 const ir::Graph &graph() const final { return _trainable_graph.graph(); }
57 void execute(const IODescription &desc) override { forward(desc, false); };
59 void execute(const std::vector<backend::IPortableTensor *> &inputs,
60 const std::vector<backend::IPortableTensor *> &outputs) override;
62 void forward(const IODescription &desc, bool training);
63 void backward(const IODescription &desc, uint32_t training_step);
65 // Used only in Dataflow and Parallel Executors
66 void setIndexedRanks(std::shared_ptr<ir::OperationIndexMap<int64_t>> ranks) final
68 _indexed_ranks = std::move(ranks);
71 void addObserver(std::unique_ptr<IExecutionObserver> ref) { _subject.add(std::move(ref)); };
73 const std::vector<backend::builtin::IOTensor *> &getInputTensors() const override
75 return _input_tensors;
78 const std::vector<backend::builtin::IOTensor *> &getOutputTensors() const override
80 return _output_tensors;
83 float getLoss(const ir::IOIndex &pred_io_ind) const;
85 backend::train::TrainableBackendContexts &getBackendContexts() { return _backend_contexts; }
88 void forwardImpl(bool training);
89 void backwardImpl(uint32_t training_step);
92 std::vector<compiler::train::TrainableCodeAndInfo> _code;
93 ExecutionObservee _subject;
94 std::shared_ptr<ir::OperationIndexMap<int64_t>> _indexed_ranks;
95 std::unique_ptr<compiler::train::LoweredTrainableGraph> _lowered_graph;
96 backend::train::TrainableBackendContexts _backend_contexts;
97 const ir::train::TrainableGraph &_trainable_graph;
98 compiler::train::TensorRegistries _tensor_regs;
99 std::vector<backend::builtin::IOTensor *> _input_tensors;
100 std::vector<backend::builtin::IOTensor *> _output_tensors;
102 const util::TracingCtx *_tracing_ctx;
109 #endif // __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTOR_H_