94d6ba1a77cdf3496e149f8d8056203943e55fdd
[platform/core/ml/nnfw.git] / runtime / onert / core / include / compiler / StaticShapeInferer.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __ONERT_COMPILER_STATIC_SHAPE_INFERER_H__
18 #define __ONERT_COMPILER_STATIC_SHAPE_INFERER_H__
19
20 #include "ir/OperationVisitor.h"
21 #include "compiler/LoweredGraph.h"
22 #include "ir/Index.h"
23
24 #include <memory>
25 #include <unordered_map>
26
27 namespace onert
28 {
29 namespace compiler
30 {
31 /**
32  * @brief Class that observe and update operands.
33  */
34 class OperandObserver
35 {
36 public:
37   /**
38    * @brief Constructor of OperandObserver
39    *
40    * @param operands Operands to be updated
41    */
42   OperandObserver(const std::vector<ir::Operand *> &operands) : _operands{operands} {}
43   /**
44    * @brief Destructor of OperandObserver
45    */
46   virtual ~OperandObserver() = default;
47
48 public:
49   /**
50    * @brief Update Shape and some OperandInfo of operands
51    *
52    * @param operands Operands to be updated
53    * @param unpredictable Whether runtime can predict shapes of operands in compilation time
54    */
55   void updateShapes(const std::vector<ir::OperandInfo> &changed_operands_info,
56                     bool unpredictable = false);
57
58 private:
59   std::vector<ir::Operand *> _operands;
60 };
61
62 /**
63  * @brief Class to infer shape before running kernels. It does the following:
64  *        - re-calculate and set output shape at compile time (before running kernels)
65  *        - if calculation cannot be done at compile time, mark the outputs to be dynamic, meaning
66  *          shapes of outputs will be calculated during running kernels
67  */
68 class StaticShapeInferer : public ir::OperationVisitor
69 {
70 public:
71   StaticShapeInferer(compiler::LoweredGraph *lowered_subg)
72     : _lowered_subg{lowered_subg}, _subg_input_observers{}, _controlflow_output_observer{nullptr},
73       _child_inferers{}
74   {
75   }
76   virtual ~StaticShapeInferer() = default;
77
78 public:
79   void appendSubgInputObserver(const ir::SubgraphIndex &subg_idx,
80                                std::unique_ptr<OperandObserver> &&subg_input_observer) noexcept
81   {
82     _subg_input_observers[subg_idx] = std::move(subg_input_observer);
83   }
84
85   void setControlflowOutputObserver(std::unique_ptr<OperandObserver> &&output_observer) noexcept
86   {
87     _controlflow_output_observer = std::move(output_observer);
88   }
89
90   void appendChildInferer(const ir::SubgraphIndex &subg_idx, compiler::StaticShapeInferer *inferer)
91   {
92     _child_inferers[subg_idx] = inferer;
93   }
94
95   /**
96    * @brief Infer shape of operands belonging to ops and set the output shape.
97    *        If output shape cannot be known without running op, mark it so that it can be allocated
98    *        when running kernel.
99    */
100   void infer(void);
101
102   void dump();
103
104   /**
105    * @brief     Create a lowered model shape inferer map
106    * @param[in] lowered_subgs lowered model subgraph map
107    * @return    Shape inferer map
108    */
109   static std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>>
110   createStaticShapeInferers(
111     const std::unordered_map<ir::SubgraphIndex, std::unique_ptr<LoweredGraph>> &lowered_subgs);
112
113 private:
114   bool checkDynamicInput(const ir::Operation &op);
115   bool checkDynamicOutput(const ir::Operation &op);
116   void setDynamicOutput(const ir::Operation &op);
117
118 private:
119   // TODO Define visitors for operations. List them in alphabetic order.
120   void visit(const ir::operation::ArgMinMax &op) override;
121   void visit(const ir::operation::BatchMatMul &op) override;
122   void visit(const ir::operation::BCQFullyConnected &op) override;
123   void visit(const ir::operation::BCQGather &op) override;
124   void visit(const ir::operation::BinaryArithmetic &op) override;
125   void visit(const ir::operation::BroadcastTo &op) override;
126   void visit(const ir::operation::Comparison &op) override;
127   void visit(const ir::operation::Concat &op) override;
128   void visit(const ir::operation::Conv2D &op) override;
129   void visit(const ir::operation::ElementwiseActivation &op) override;
130   void visit(const ir::operation::ElementwiseBinary &op) override;
131   void visit(const ir::operation::ElementwiseUnary &op) override;
132   void visit(const ir::operation::ExpandDims &op) override;
133   void visit(const ir::operation::Fill &op) override;
134   void visit(const ir::operation::FullyConnected &op) override;
135   void visit(const ir::operation::FusedBatchNorm &op) override;
136   void visit(const ir::operation::Gather &op) override;
137   void visit(const ir::operation::If &op) override;
138   void visit(const ir::operation::L2Normalization &op) override;
139   void visit(const ir::operation::LSTM &op) override;
140   void visit(const ir::operation::MatrixBandPart &op) override;
141   void visit(const ir::operation::OneHot &op) override;
142   void visit(const ir::operation::Pack &op) override;
143   void visit(const ir::operation::Pad &op) override;
144   void visit(const ir::operation::Permute &op) override;
145   void visit(const ir::operation::Pow &op) override;
146   void visit(const ir::operation::Range &op) override;
147   void visit(const ir::operation::Reduce &op) override;
148   void visit(const ir::operation::Reshape &op) override;
149   void visit(const ir::operation::ResizeBilinear &op) override;
150   void visit(const ir::operation::Reverse &op) override;
151   void visit(const ir::operation::Select &op) override;
152   void visit(const ir::operation::Shape &op) override;
153   void visit(const ir::operation::Slice &op) override;
154   void visit(const ir::operation::Softmax &op) override;
155   void visit(const ir::operation::SpaceToBatchND &op) override;
156   void visit(const ir::operation::Split &op) override;
157   void visit(const ir::operation::Squeeze &op) override;
158   void visit(const ir::operation::StridedSlice &op) override;
159   void visit(const ir::operation::SquaredDifference &op) override;
160   void visit(const ir::operation::Tile &op) override;
161   void visit(const ir::operation::Transpose &op) override;
162   void visit(const ir::operation::Unpack &op) override;
163   void visit(const ir::operation::While &op) override;
164   void visit(const ir::operation::DetectionPostProcess &op) override;
165   void visit(const ir::operation::Bulk &op) override;
166
167 private:
168   /**
169    * @brief Performs shape inference for arithmetic operation
170    */
171   void handleBinaryArithmeticOp(const ir::Operation &op, const ir::OperandIndex lhs_idx,
172                                 const ir::OperandIndex rhs_idx);
173
174   /**
175    * @brief Performs shape inference for unary op whose output shape is
176    *        always same with input shape
177    */
178   void handleSimpleUnaryOp(const ir::Operation &op, const ir::OperandIndex input_idx);
179
180 private:
181   compiler::LoweredGraph *_lowered_subg;
182   std::unordered_map<ir::SubgraphIndex, std::unique_ptr<OperandObserver>>
183     _subg_input_observers;                                       // child subg input
184   std::unique_ptr<OperandObserver> _controlflow_output_observer; // parent controlflow op output
185   std::unordered_map<ir::SubgraphIndex, compiler::StaticShapeInferer *> _child_inferers;
186 };
187
188 } // namespace compiler
189 } // namespace onert
190
191 #endif // __ONERT_COMPILER_STATIC_SHAPE_INFERER_H__