Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / mir / src / mir_tflite_importer / tflite_op_creator.h
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef MIR_TFLITE_OP_CREATOR_H
18 #define MIR_TFLITE_OP_CREATOR_H
19
20 #include "schema_generated.h"
21
22 #include "mir/Graph.h"
23
24 #include <utility>
25 #include <vector>
26
27 namespace mir_tflite
28 {
29
30 class TFLiteOpCreator
31 {
32 public:
33   explicit TFLiteOpCreator(mir::Graph *g) : _graph(g) {}
34
35   std::vector<mir::Operation::Output *>
36   convertConv2D(const tflite::Conv2DOptionsT *opts,
37                 const std::vector<mir::Operation::Output *> &inputs);
38
39   std::vector<mir::Operation::Output *>
40   convertDepthwiseConv2D(const tflite::DepthwiseConv2DOptionsT *opts,
41                          const std::vector<mir::Operation::Output *> &inputs);
42
43   std::vector<mir::Operation::Output *>
44   convertConcatenation(const tflite::ConcatenationOptionsT *opts,
45                        const std::vector<mir::Operation::Output *> &inputs);
46
47   std::vector<mir::Operation::Output *>
48   convertMaxPool2D(const tflite::Pool2DOptionsT *opts,
49                    const std::vector<mir::Operation::Output *> &inputs);
50
51   std::vector<mir::Operation::Output *>
52   convertAveragePool2D(const tflite::Pool2DOptionsT *opts,
53                        const std::vector<mir::Operation::Output *> &inputs);
54
55   std::vector<mir::Operation::Output *>
56   convertMean(const tflite::ReducerOptionsT *opts,
57               const std::vector<mir::Operation::Output *> &inputs);
58
59   std::vector<mir::Operation::Output *>
60   convertSoftmax(const tflite::SoftmaxOptionsT *opts,
61                  const std::vector<mir::Operation::Output *> &inputs);
62
63   std::vector<mir::Operation::Output *>
64   convertSlice(const tflite::SliceOptionsT *opts,
65                const std::vector<mir::Operation::Output *> &inputs);
66
67   std::vector<mir::Operation::Output *>
68   convertReshape(const tflite::ReshapeOptionsT *opts,
69                  const std::vector<mir::Operation::Output *> &inputs);
70
71   std::vector<mir::Operation::Output *>
72   convertFullyConnected(const tflite::FullyConnectedOptionsT *opts,
73                         const std::vector<mir::Operation::Output *> &inputs);
74
75   std::vector<mir::Operation::Output *>
76   convertResizeNearestNeighbor(const tflite::ResizeNearestNeighborOptionsT *opts,
77                                const std::vector<mir::Operation::Output *> &inputs);
78
79   std::vector<mir::Operation::Output *>
80   convertLogistic(const std::vector<mir::Operation::Output *> &inputs);
81
82   std::vector<mir::Operation::Output *>
83   convertRsqrt(const std::vector<mir::Operation::Output *> &inputs);
84
85   std::vector<mir::Operation::Output *>
86   convertSqrt(const std::vector<mir::Operation::Output *> &inputs);
87
88   std::vector<mir::Operation::Output *>
89   convertSqueeze(const tflite::SqueezeOptionsT *opts,
90                  const std::vector<mir::Operation::Output *> &inputs);
91
92   std::vector<mir::Operation::Output *>
93   convertAdd(const tflite::AddOptionsT *opts, const std::vector<mir::Operation::Output *> &inputs);
94
95   std::vector<mir::Operation::Output *>
96   convertSub(const tflite::SubOptionsT *opts, const std::vector<mir::Operation::Output *> &inputs);
97
98   std::vector<mir::Operation::Output *>
99   convertMul(const tflite::MulOptionsT *opts, const std::vector<mir::Operation::Output *> &inputs);
100
101   std::vector<mir::Operation::Output *>
102   convertDiv(const tflite::DivOptionsT *opts, const std::vector<mir::Operation::Output *> &inputs);
103
104   std::vector<mir::Operation::Output *>
105   convertMax(const std::vector<mir::Operation::Output *> &inputs);
106
107   std::vector<mir::Operation::Output *>
108   convertSquaredDifference(const std::vector<mir::Operation::Output *> &inputs);
109
110   std::vector<mir::Operation::Output *>
111   convertTanh(const std::vector<mir::Operation::Output *> &inputs);
112
113   std::vector<mir::Operation::Output *>
114   convertReLU(const std::vector<mir::Operation::Output *> &inputs);
115
116   std::vector<mir::Operation::Output *>
117   convertReLU6(const std::vector<mir::Operation::Output *> &inputs);
118
119   std::vector<mir::Operation::Output *>
120   convertTransposeConv(const tflite::TransposeConvOptionsT *opts,
121                        const std::vector<mir::Operation::Output *> &inputs);
122
123   std::vector<mir::Operation::Output *>
124   convertPad(const tflite::PadOptionsT *opts, const std::vector<mir::Operation::Output *> &inputs);
125
126   std::vector<mir::Operation::Output *>
127   convertTranspose(const tflite::TransposeOptionsT *opts,
128                    const std::vector<mir::Operation::Output *> &inputs);
129
130   std::vector<mir::Operation::Output *>
131   convertStridedSlice(const tflite::StridedSliceOptionsT *opts,
132                       const std::vector<mir::Operation::Output *> &inputs);
133
134   std::vector<mir::Operation::Output *>
135   convertLeakyReLU(const tflite::LeakyReluOptionsT *opts,
136                    const std::vector<mir::Operation::Output *> &inputs);
137
138   std::vector<mir::Operation::Output *>
139   convertShape(const tflite::ShapeOptionsT *opts,
140                const std::vector<mir::Operation::Output *> &inputs);
141
142   std::vector<mir::Operation::Output *>
143   convertHardSwish(const tflite::HardSwishOptionsT *opts,
144                    const std::vector<mir::Operation::Output *> &inputs);
145
146 private:
147   mir::Graph *_graph;
148
149   mir::Operation::Output *addFusedActivation(mir::Operation::Output *input,
150                                              tflite::ActivationFunctionType activation_type);
151
152   template <typename OpType, typename... Types> mir::Operation *createOp(Types &&... args);
153 };
154
155 template <typename OpType, typename... Types>
156 mir::Operation *TFLiteOpCreator::createOp(Types &&... args)
157 {
158   return _graph->create<OpType>(std::forward<Types>(args)...);
159 }
160
161 } // namespace mir_tflite
162
163 #endif // MIR_TFLITE_OP_CREATOR_H