2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __ONERT_COMPILER_STATIC_SHAPE_INFERER_H__
18 #define __ONERT_COMPILER_STATIC_SHAPE_INFERER_H__
20 #include "ir/OperationVisitor.h"
21 #include "compiler/LoweredGraph.h"
25 #include <unordered_map>
32 * @brief Class that observes and updates operands.
38 * @brief Constructor of OperandObserver
40 * @param operands Operands to be updated
// Stores a copy of the given operand pointer list in _operands.
// NOTE(review): this single-argument constructor is not marked `explicit`, so a
// std::vector<ir::Operand *> converts implicitly to OperandObserver - confirm that is intended.
42 OperandObserver(const std::vector<ir::Operand *> &operands) : _operands{operands} {}
44 * @brief Destructor of OperandObserver
// Virtual and defaulted, so deleting through an OperandObserver pointer runs the
// destructor of the most-derived type.
46 virtual ~OperandObserver() = default;
50 * @brief Update Shape and some OperandInfo of operands
52 * @param changed_operands_info Changed operand info used to update the operands
53 * @param unpredictable Whether the runtime is unable to predict the shapes of operands at compilation time
// Declaration only; the definition is not part of this header.
// `changed_operands_info` is taken by const reference; `unpredictable` defaults to false
// when the caller omits it.
55 void updateShapes(const std::vector<ir::OperandInfo> &changed_operands_info,
56 bool unpredictable = false);
// Operand pointers copied from the constructor argument.
// NOTE(review): raw pointers - presumably non-owning; confirm who owns the operands.
59 std::vector<ir::Operand *> _operands;
63 * @brief Class to infer shape before running kernels. It does the following:
64 * - re-calculate and set output shape at compile time (before running kernels)
65 * - if calculation cannot be done at compile time, mark the outputs to be dynamic, meaning
66 * shapes of outputs will be calculated during running kernels
68 class StaticShapeInferer : public ir::OperationVisitor
// Stores `lowered_subg` in _lowered_subg; the subgraph-input observer map starts empty and
// the control-flow output observer starts as nullptr.
// NOTE(review): single-argument constructor is not marked `explicit` - confirm that implicit
// conversion from compiler::LoweredGraph * is intended.
71 StaticShapeInferer(compiler::LoweredGraph *lowered_subg)
72 : _lowered_subg{lowered_subg}, _subg_input_observers{}, _controlflow_output_observer{nullptr},
// Virtual, defaulted destructor: members are released by their own destructors; the raw
// pointers held by this class are not deleted here.
76 virtual ~StaticShapeInferer() = default;
// Takes ownership of `subg_input_observer` and stores it in _subg_input_observers under
// `subg_idx`. Assigning through operator[] replaces any observer already stored for that index.
// NOTE(review): declared noexcept, but std::unordered_map::operator[] may allocate and
// therefore throw (e.g. std::bad_alloc), which would end in std::terminate - confirm that
// this is acceptable.
79 void appendSubgInputObserver(const ir::SubgraphIndex &subg_idx,
80 std::unique_ptr<OperandObserver> &&subg_input_observer) noexcept
82 _subg_input_observers[subg_idx] = std::move(subg_input_observer);
// Takes ownership of `output_observer` and stores it in _controlflow_output_observer;
// move-assigning the unique_ptr destroys any observer held previously.
85 void setControlflowOutputObserver(std::unique_ptr<OperandObserver> &&output_observer) noexcept
87 _controlflow_output_observer = std::move(output_observer);
// Records `inferer` in _child_inferers under `subg_idx`, overwriting any previous entry.
// NOTE(review): stored as a raw pointer; presumably the caller keeps the child inferer
// alive for as long as this object uses it - confirm.
90 void appendChildInferer(const ir::SubgraphIndex &subg_idx, compiler::StaticShapeInferer *inferer)
92 _child_inferers[subg_idx] = inferer;
96 * @brief Infer shape of operands belonging to ops and set the output shape.
97 * If output shape cannot be known without running op, mark it so that it can be allocated
98 * when running kernel.
105 * @brief Create a lowered model shape inferer map
106 * @param[in] lowered_subgs lowered model subgraph map
107 * @return Shape inferer map
// Static factory (declaration only). Both the input map and the returned map are keyed by
// ir::SubgraphIndex; the returned map owns its inferers through std::unique_ptr.
109 static std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>>
110 createStaticShapeInferers(
111 const std::unordered_map<ir::SubgraphIndex, std::unique_ptr<LoweredGraph>> &lowered_subgs);
// Helpers taking the operation being processed (declarations only; definitions are not in
// this header).
// NOTE(review): judging by the names, the first two report whether an input / output of `op`
// is dynamic and the third marks the outputs of `op` as dynamic - confirm against the
// definitions.
114 bool checkDynamicInput(const ir::Operation &op);
115 bool checkDynamicOutput(const ir::Operation &op);
116 void setDynamicOutput(const ir::Operation &op);
// One visit() override per handled ir::operation type (declarations only).
// NOTE(review): the list is not strictly alphabetical as the TODO below asks -
// SquaredDifference follows StridedSlice, and DetectionPostProcess and Bulk are appended
// after While.
119 // TODO Define visitors for operations. List them in alphabetic order.
120 void visit(const ir::operation::ArgMinMax &op) override;
121 void visit(const ir::operation::BatchMatMul &op) override;
122 void visit(const ir::operation::BCQFullyConnected &op) override;
123 void visit(const ir::operation::BCQGather &op) override;
124 void visit(const ir::operation::BinaryArithmetic &op) override;
125 void visit(const ir::operation::BroadcastTo &op) override;
126 void visit(const ir::operation::Comparison &op) override;
127 void visit(const ir::operation::Concat &op) override;
128 void visit(const ir::operation::Conv2D &op) override;
129 void visit(const ir::operation::ElementwiseActivation &op) override;
130 void visit(const ir::operation::ElementwiseBinary &op) override;
131 void visit(const ir::operation::ElementwiseUnary &op) override;
132 void visit(const ir::operation::ExpandDims &op) override;
133 void visit(const ir::operation::Fill &op) override;
134 void visit(const ir::operation::FullyConnected &op) override;
135 void visit(const ir::operation::FusedBatchNorm &op) override;
136 void visit(const ir::operation::Gather &op) override;
137 void visit(const ir::operation::If &op) override;
138 void visit(const ir::operation::L2Normalization &op) override;
139 void visit(const ir::operation::LSTM &op) override;
140 void visit(const ir::operation::MatrixBandPart &op) override;
141 void visit(const ir::operation::OneHot &op) override;
142 void visit(const ir::operation::Pack &op) override;
143 void visit(const ir::operation::Pad &op) override;
144 void visit(const ir::operation::Permute &op) override;
145 void visit(const ir::operation::Pow &op) override;
146 void visit(const ir::operation::Range &op) override;
147 void visit(const ir::operation::Reduce &op) override;
148 void visit(const ir::operation::Reshape &op) override;
149 void visit(const ir::operation::ResizeBilinear &op) override;
150 void visit(const ir::operation::Reverse &op) override;
151 void visit(const ir::operation::Select &op) override;
152 void visit(const ir::operation::Shape &op) override;
153 void visit(const ir::operation::Slice &op) override;
154 void visit(const ir::operation::Softmax &op) override;
155 void visit(const ir::operation::SpaceToBatchND &op) override;
156 void visit(const ir::operation::Split &op) override;
157 void visit(const ir::operation::Squeeze &op) override;
158 void visit(const ir::operation::StridedSlice &op) override;
159 void visit(const ir::operation::SquaredDifference &op) override;
160 void visit(const ir::operation::Tile &op) override;
161 void visit(const ir::operation::Transpose &op) override;
162 void visit(const ir::operation::Unpack &op) override;
163 void visit(const ir::operation::While &op) override;
164 void visit(const ir::operation::DetectionPostProcess &op) override;
165 void visit(const ir::operation::Bulk &op) override;
169 * @brief Performs shape inference for arithmetic operation
// Declaration only. `lhs_idx` and `rhs_idx` are passed by value as ir::OperandIndex.
// NOTE(review): presumably the indices of the left- and right-hand input operands of `op` -
// confirm against the definition.
171 void handleBinaryArithmeticOp(const ir::Operation &op, const ir::OperandIndex lhs_idx,
172 const ir::OperandIndex rhs_idx);
175 * @brief Performs shape inference for a unary op whose output shape is
176 * always the same as its input shape
// Declaration only. `input_idx` is passed by value as ir::OperandIndex.
// NOTE(review): presumably the index of the single input operand of `op` - confirm against
// the definition.
178 void handleSimpleUnaryOp(const ir::Operation &op, const ir::OperandIndex input_idx);
// Lowered subgraph pointer passed to the constructor (raw pointer).
181 compiler::LoweredGraph *_lowered_subg;
// Observers keyed by subgraph index; held through std::unique_ptr, so owned by this object.
182 std::unordered_map<ir::SubgraphIndex, std::unique_ptr<OperandObserver>>
183 _subg_input_observers; // child subg input
// Single observer held through std::unique_ptr; set by setControlflowOutputObserver().
184 std::unique_ptr<OperandObserver> _controlflow_output_observer; // parent controlflow op output
// Child inferers keyed by subgraph index; raw pointers filled in by appendChildInferer().
185 std::unordered_map<ir::SubgraphIndex, compiler::StaticShapeInferer *> _child_inferers;
188 } // namespace compiler
191 #endif // __ONERT_COMPILER_STATIC_SHAPE_INFERER_H__