2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __ONERT_IR_TRAIN_TRAINABLE_GRAPH_H__
18 #define __ONERT_IR_TRAIN_TRAINABLE_GRAPH_H__
21 #include <unordered_map>
24 #include "ir/train/ITrainableOperation.h"
33 class TrainableGraph : public IGraph
37 * @brief Construct a new Trainable Graph object
41 explicit TrainableGraph();
42 explicit TrainableGraph(const TrainableGraph &tgraph);
43 explicit TrainableGraph(const Graph &graph);
44 ~TrainableGraph() = default;
46 // TrainableGraph Building
48 OperandIndex addOperand(const Shape &shape, const TypeInfo &type);
50 * @brief Add an operand to the graph with the given index and object
52 * If the given index is available, it succeeds. And @c operand is moved which invalidates the
53 * caller's pointer. If the given index is already taken, it fails. And @c operand will not be
54 * moved so the caller's pointer will be still valid.
56 * @param[in] index Index to be added
57 * @param[in] operand Operand to be added
58 * @return OperandIndex @c index if successful, UNDEFINED otherwise
60 OperandIndex addOperand(OperandIndex index, std::unique_ptr<Operand> &&operand);
62 * @brief Add a new trainable operation to the graph
64 * If the given @c operation has at least one invalid operand index, it fails. And @c operation
65 * will not be moved so the caller's pointer will be still valid.
67 * @param operation Operation to be added
68 * @return OperationIndex @c index if successful, UNDEFINED otherwise
70 OperationIndex addOperation(std::unique_ptr<ITrainableOperation> &&operation);
72 * @brief Replace a trainable operation which the graph already has
74 * If the given @c index is available, it succeeds. And @c operation is moved which invalidates
75 * the caller's pointer. If the given @c operation has at least one invalid operand index, it
76 * fails. And @c operation will not be moved so the caller's pointer will be still valid.
78 * No information in the graph is changed except for replacing an operation.
80 * @param operation Operation to be added
81 * @return OperationIndex @c index if successful, UNDEFINED otherwise
83 OperationIndex replaceOperation(OperationIndex index,
84 std::unique_ptr<ITrainableOperation> &&operation);
87 * @brief Add a derivative to the graph with the given index and object
89 * If the given index is available, it succeeds. And @c derivative is moved which invalidates the
90 * caller's pointer. If the given index is already taken, it fails. And @c derivative will not be
91 * moved so the caller's pointer will be still valid.
93 * @param[in] index Index to be added
94 * @param[in] derivative Derivative operand to be added
95 * @return OperandIndex @c index if successful, UNDEFINED otherwise
97 OperandIndex addDerivative(OperandIndex index, std::unique_ptr<Operand> &&derivative);
100 void changeShape(const OperandIndex &ind, const ir::Shape &new_shape) override;
101 void changeDerivativeShape(const OperandIndex &ind, const ir::Shape &new_shape);
102 void addInput(const OperandIndex &ind, const std::string &name = "");
103 void addOutput(const OperandIndex &ind, const std::string &name = "");
104 void addLoss(const OperandIndex &loss_ind, const IOIndex &pred_io_ind);
106 void removeOperand(const OperandIndex &ind);
107 void setLayout(Layout layout);
108 void setInputs(OperandIndexSequence inputs,
109 std::unordered_map<std::string, IOIndex> name_to_input);
110 void setOutputs(OperandIndexSequence outputs,
111 std::unordered_map<std::string, IOIndex> name_to_output);
115 const OperandIndexSequence &getInputs() const override { return _graph.getInputs(); }
116 const OperandIndexSequence &getOutputs() const override { return _graph.getOutputs(); }
117 IOIndex getInputIndex(const std::string &name) const override;
118 IOIndex getOutputIndex(const std::string &name) const override;
119 const Operands &operands() const override { return _graph.operands(); }
120 Operands &operands() { return _graph.operands(); } // TODO Remove this non-const accessor
121 const Operations &operations() const override { return _graph.operations(); }
122 const Operands &derivatives() const { return _derivatives; }
123 OperandIndex getLossIndex(const IOIndex &pred_io_ind) const;
124 Layout layout() const { return _graph.layout(); }
125 const Graph &graph() const { return _graph; }
128 const ITrainableOperation &operation(OperationIndex index) const;
131 std::vector<ir::OperationIndex> topolSortOperations() const;
132 // TODO Support topological sort for backwarding
136 Operands _derivatives;
138 std::unordered_map<IOIndex, OperandIndex> _losses;
145 #endif // __ONERT_IR_TRAIN_TRAINABLE_GRAPH_H__