1783cdca066053936faee0035c40f3ac93f63685
[platform/core/ml/nnfw.git] / runtime / onert / core / include / ir / Graph.h
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __ONERT_IR_GRAPH_H__
18 #define __ONERT_IR_GRAPH_H__
19
20 #include <functional>
21 #include <unordered_map>
22
23 #include "ir/Model.h"
24 #include "ir/Operands.h"
25 #include "ir/Operations.h"
26
27 namespace onert
28 {
29 namespace backend
30 {
31 namespace custom
32 {
33 class IKernelBuilder;
34 } // namespace custom
35 } // namespace backend
36 } // namespace onert
37
38 namespace onert
39 {
40 namespace ir
41 {
42
43 class Graph
44 {
45 private:
46   enum class Phase
47   {
48     BUILDING,
49     MODEL
50   };
51
52 public:
53   explicit Graph(void);
54   explicit Graph(const Graph &);
55
56   ~Graph(void);
57
58   // Graph Building
59 public:
60   OperandIndex addOperand(const Shape &shape, const TypeInfo &type);
61   /**
62    * @brief Add an operand to the graph with the given index and object
63    *
64    * If the given index is available, it succeeds. And @c operand is moved which invalidates the
65    * caller's pointer. If the given index is already taken, it fails. And @c operand will not be
66    * moved so the caller's pointer will be still valid.
67    *
68    * @param[in] index Index to be added
69    * @param[in] operand Operand to be added
70    * @return OperandIndex @c index if successful, Undefined otherwise
71    */
72   OperandIndex addOperand(OperandIndex index, std::unique_ptr<Operand> &&operand);
73   OperationIndex addOperation(std::unique_ptr<Operation> &&node);
74   /**
75    * @brief Add an operation to the graph with the given index and object
76    *
77    * If the given index is available, it succeeds. And @c operation is moved which invalidates the
78    * caller's pointer. If the given index is already taken, it fails. And @c operation will not be
79    * moved so the caller's pointer will be still valid.
80    *
81    * @param index Index to be added
82    * @param operation Operation to be added
83    * @return OperandIndex @c index if successful, Undefined otherwise
84    */
85   OperationIndex addOperation(OperationIndex index, std::unique_ptr<Operation> &&operation);
86   void setOperandValue(const OperandIndex &ind, std::shared_ptr<Data> data);
87   void addInput(const OperandIndex &ind, const std::string &name = "");
88   void addOutput(const OperandIndex &ind, const std::string &name = "");
89   void verify(void);
90   void removeOperand(const OperandIndex &ind) { _operands.remove(ind); }
91   void setLayout(Layout layout) { _layout = layout; }
92
93 private:
94   bool checkOperandsForOperation(const Operation &operation);
95   void linkOperandToOperation(OperationIndex index, const Operation &operation);
96   void initializeUseDef();
97   // TODO Rename to `sweepUnusedOperands`
98   // TODO Make this public
99   void sweepGarbageOperands();
100
101   // Custom operations support
102 public:
103   void
104   bindKernelBuilder(const std::shared_ptr<onert::backend::custom::IKernelBuilder> &kernel_builder)
105   {
106     _kernel_builder = kernel_builder;
107   }
108
109   const std::shared_ptr<backend::custom::IKernelBuilder> &getKernelBuilder() const
110   {
111     return _kernel_builder;
112   }
113
114 private:
115   std::shared_ptr<backend::custom::IKernelBuilder> _kernel_builder;
116
117   // Accessors
118 public:
119   const OperandIndexSequence &getInputs() const { return _inputs; }
120   OperandIndexSequence &getInputs() { return _inputs; }
121   const OperandIndexSequence &getOutputs() const { return _outputs; }
122   OperandIndexSequence &getOutputs() { return _outputs; }
123   IOIndex getInputIndex(const std::string &name) const;
124   IOIndex getOutputIndex(const std::string &name) const;
125   const Operands &operands() const { return _operands; }
126   Operands &operands() { return _operands; } // TODO Remove this non-const accessor
127   const Operations &operations() const { return _operations; }
128   Operations &operations() { return _operations; }
129   Layout layout() const { return _layout; }
130
131   // Topological sort
132 public:
133   std::vector<ir::OperationIndex> topolSortOperations() const;
134
135 private:
136   Operations _operations;
137   Operands _operands;
138   OperandIndexSequence _inputs;
139   OperandIndexSequence _outputs;
140   std::unordered_map<std::string, IOIndex> _name_to_input;
141   std::unordered_map<std::string, IOIndex> _name_to_output;
142   // TFLite and circle's default layout is NHWC;
143   Layout _layout{Layout::NHWC};
144 };
145
146 } // namespace ir
147 } // namespace onert
148
149 #endif // __ONERT_IR_GRAPH_H__