Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / train / StaticDerivativeShapeInferer.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __ONERT_COMPILER_STATIC_DERIVATIVE_SHAPE_INFERER_H__
18 #define __ONERT_COMPILER_STATIC_DERIVATIVE_SHAPE_INFERER_H__
19
20 #include "ir/train/TrainableOperationVisitor.h"
21
22 #include "compiler/train/LoweredTrainableGraph.h"
23 #include "ir/Index.h"
24
25 #include <memory>
26 #include <unordered_map>
27
28 namespace onert
29 {
30 namespace compiler
31 {
32 namespace train
33 {
34
35 /**
36  * @brief Class to infer shape before running kernels. It does the following:
37  *        - re-calculate and set output shape at compile time (before running kernels)
38  *        - if calculation cannot be done at compile time, mark the outputs to be dynamic, meaning
39  *          shapes of outputs will be calculated during running kernels
40  */
41 class StaticDerivativeShapeInferer : public ir::train::TrainableOperationVisitor
42 {
43 public:
44   StaticDerivativeShapeInferer(compiler::train::LoweredTrainableGraph *lowered_subg)
45     : _lowered_subg{lowered_subg}
46   {
47   }
48
49   /**
50    * @brief Infer shape of operands belonging to ops and set the output shape.
51    *        If output shape cannot be known without running op, mark it so that it can be allocated
52    *        when running kernel.
53    */
54   void infer(void);
55
56   void dump();
57
58 private:
59   bool checkDynamicInput(const ir::IOperation &op);
60   void checkOutput(const ir::IOperation &op);
61   void setShape(const ir::OperandIndex &index, const ir::Shape &shape);
62
63 private:
64   void visit(const ir::train::operation::Conv2D &op) override;
65   void visit(const ir::train::operation::ElementwiseActivation &op) override;
66   void visit(const ir::train::operation::Loss &op) override;
67   void visit(const ir::train::operation::Permute &op) override;
68   void visit(const ir::train::operation::Pool2D &op) override;
69   void visit(const ir::train::operation::Reshape &op) override;
70   void visit(const ir::train::operation::Softmax &op) override;
71
72 private:
73   compiler::train::LoweredTrainableGraph *_lowered_subg;
74 };
75
76 } // namespace train
77 } // namespace compiler
78 } // namespace onert
79
80 #endif // __ONERT_COMPILER_STATIC_DERIVATIVE_SHAPE_INFERER_H__