Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / include / ir / train / TrainableGraph.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __ONERT_IR_TRAIN_TRAINABLE_GRAPH_H__
18 #define __ONERT_IR_TRAIN_TRAINABLE_GRAPH_H__
19
20 #include <functional>
21 #include <unordered_map>
22
23 #include "ir/Graph.h"
24 #include "ir/train/ITrainableOperation.h"
25
26 namespace onert
27 {
28 namespace ir
29 {
30 namespace train
31 {
32
33 class TrainableGraph : public IGraph
34 {
35 public:
36   /**
37    * @brief Construct a new Trainable Graph object
38    *
39    * @param graph
40    */
41   explicit TrainableGraph();
42   explicit TrainableGraph(const TrainableGraph &tgraph);
43   explicit TrainableGraph(const Graph &graph);
44   ~TrainableGraph() = default;
45
46   // TrainableGraph Building
47 public:
48   OperandIndex addOperand(const Shape &shape, const TypeInfo &type);
49   /**
50    * @brief Add an operand to the graph with the given index and object
51    *
52    * If the given index is available, it succeeds. And @c operand is moved which invalidates the
53    * caller's pointer. If the given index is already taken, it fails. And @c operand will not be
54    * moved so the caller's pointer will be still valid.
55    *
56    * @param[in] index Index to be added
57    * @param[in] operand Operand to be added
58    * @return OperandIndex @c index if successful, UNDEFINED otherwise
59    */
60   OperandIndex addOperand(OperandIndex index, std::unique_ptr<Operand> &&operand);
61   /**
62    * @brief Add a new trainable operation to the graph
63    *
64    * If the given @c operation has at least one invalid operand index, it fails. And @c operation
65    * will not be moved so the caller's pointer will be still valid.
66    *
67    * @param operation Operation to be added
68    * @return OperationIndex @c index if successful, UNDEFINED otherwise
69    */
70   OperationIndex addOperation(std::unique_ptr<ITrainableOperation> &&operation);
71   /**
72    * @brief Replace a trainable operation which the graph already has
73    *
74    * If the given @c index is available, it succeeds. And @c operation is moved which invalidates
75    * the caller's pointer. If the given @c operation has at least one invalid operand index, it
76    * fails. And @c operation will not be moved so the caller's pointer will be still valid.
77    *
78    * No information in the graph is changed except for replacing an operation.
79    *
80    * @param operation Operation to be added
81    * @return OperationIndex @c index if successful, UNDEFINED otherwise
82    */
83   OperationIndex replaceOperation(OperationIndex index,
84                                   std::unique_ptr<ITrainableOperation> &&operation);
85
86   /**
87    * @brief Add a derivative to the graph with the given index and object
88    *
89    * If the given index is available, it succeeds. And @c derivative is moved which invalidates the
90    * caller's pointer. If the given index is already taken, it fails. And @c derivative will not be
91    * moved so the caller's pointer will be still valid.
92    *
93    * @param[in] index      Index to be added
94    * @param[in] derivative Derivative operand to be added
95    * @return OperandIndex @c index if successful, UNDEFINED otherwise
96    */
97   OperandIndex addDerivative(OperandIndex index, std::unique_ptr<Operand> &&derivative);
98
99 public:
100   void changeShape(const OperandIndex &ind, const ir::Shape &new_shape) override;
101   void changeDerivativeShape(const OperandIndex &ind, const ir::Shape &new_shape);
102   void addInput(const OperandIndex &ind, const std::string &name = "");
103   void addOutput(const OperandIndex &ind, const std::string &name = "");
104   void addLoss(const OperandIndex &loss_ind, const IOIndex &pred_io_ind);
105   void verify() const;
106   void removeOperand(const OperandIndex &ind);
107   void setLayout(Layout layout);
108   void setInputs(OperandIndexSequence inputs,
109                  std::unordered_map<std::string, IOIndex> name_to_input);
110   void setOutputs(OperandIndexSequence outputs,
111                   std::unordered_map<std::string, IOIndex> name_to_output);
112
113   // Accessors
114 public:
115   const OperandIndexSequence &getInputs() const override { return _graph.getInputs(); }
116   const OperandIndexSequence &getOutputs() const override { return _graph.getOutputs(); }
117   IOIndex getInputIndex(const std::string &name) const override;
118   IOIndex getOutputIndex(const std::string &name) const override;
119   const Operands &operands() const override { return _graph.operands(); }
120   Operands &operands() { return _graph.operands(); } // TODO Remove this non-const accessor
121   const Operations &operations() const override { return _graph.operations(); }
122   const Operands &derivatives() const { return _derivatives; }
123   OperandIndex getLossIndex(const IOIndex &pred_io_ind) const;
124   Layout layout() const { return _graph.layout(); }
125   const Graph &graph() const { return _graph; }
126
127 public:
128   const ITrainableOperation &operation(OperationIndex index) const;
129
130 public:
131   std::vector<ir::OperationIndex> topolSortOperations() const;
132   // TODO Support topological sort for backwarding
133
134 private:
135   Graph _graph;
136   Operands _derivatives;
137
138   std::unordered_map<IOIndex, OperandIndex> _losses;
139 };
140
141 } // namespace train
142 } // namespace ir
143 } // namespace onert
144
145 #endif // __ONERT_IR_TRAIN_TRAINABLE_GRAPH_H__