Imported Upstream version 1.4.0
[platform/core/ml/nnfw.git] / compiler / mir-caffe2-importer / caffe2_op_creator.h
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef MIR_CAFFE2_OP_CREATOR_H
18 #define MIR_CAFFE2_OP_CREATOR_H
19
20 #include <set>
21 #include <unordered_map>
22 #include <vector>
23 #include <memory>
24
25 #include "mir/Graph.h"
26 #include "mir/Operation.h"
27 #include "mir/TensorVariant.h"
28 #include "mir/Shape.h"
29
30 #include "caffe2/proto/caffe2.pb.h"
31
32 namespace mir_caffe2
33 {
34
35 using mir::Operation;
36 using mir::Shape;
37
38 class Caffe2OpCreator
39 {
40 public:
41   explicit Caffe2OpCreator(mir::Graph *g) : _graph(g) {}
42
43   std::vector<mir::Operation::Output *>
44   convertConstant(const std::vector<mir::Operation::Output *> &inputs,
45                   const ::caffe2::OperatorDef &op);
46
47   std::vector<mir::Operation::Output *>
48   convertAdd(const std::vector<mir::Operation::Output *> &inputs, const ::caffe2::OperatorDef &op);
49
50   std::vector<mir::Operation::Output *>
51   convertAveragePool(const std::vector<mir::Operation::Output *> &inputs,
52                      const ::caffe2::OperatorDef &op);
53
54   std::vector<mir::Operation::Output *>
55   convertConv(const std::vector<mir::Operation::Output *> &inputs, const ::caffe2::OperatorDef &op);
56
57   std::vector<mir::Operation::Output *>
58   convertConcat(const std::vector<mir::Operation::Output *> &inputs,
59                 const ::caffe2::OperatorDef &op);
60
61   std::vector<mir::Operation::Output *>
62   convertDropout(const std::vector<mir::Operation::Output *> &inputs,
63                  const ::caffe2::OperatorDef &op);
64
65   std::vector<mir::Operation::Output *>
66   convertFC(const std::vector<mir::Operation::Output *> &inputs, const ::caffe2::OperatorDef &op);
67
68   std::vector<mir::Operation::Output *>
69   convertMaxPool(const std::vector<mir::Operation::Output *> &inputs,
70                  const ::caffe2::OperatorDef &op);
71
72   std::vector<mir::Operation::Output *>
73   convertMul(const std::vector<mir::Operation::Output *> &inputs, const ::caffe2::OperatorDef &op);
74
75   std::vector<mir::Operation::Output *>
76   convertRelu(const std::vector<mir::Operation::Output *> &inputs);
77
78   std::vector<mir::Operation::Output *>
79   convertResizeNearest(const std::vector<mir::Operation::Output *> &inputs,
80                        const ::caffe2::OperatorDef &op);
81
82   std::vector<mir::Operation::Output *>
83   convertSigmoid(const std::vector<mir::Operation::Output *> &inputs);
84
85   std::vector<mir::Operation::Output *>
86   convertSoftmax(const std::vector<mir::Operation::Output *> &inputs,
87                  const ::caffe2::OperatorDef &op);
88
89   std::vector<mir::Operation::Output *>
90   convertSpatialBN(const std::vector<mir::Operation::Output *> &inputs,
91                    const ::caffe2::OperatorDef &op);
92
93   std::vector<mir::Operation::Output *>
94   convertSum(const std::vector<mir::Operation::Output *> &inputs);
95
96   std::vector<mir::Operation::Output *>
97   convertClip(const std::vector<mir::Operation::Output *> &inputs, const ::caffe2::OperatorDef &op);
98
99   std::vector<mir::Operation::Output *>
100   convertReshape(const std::vector<mir::Operation::Output *> &inputs,
101                  const ::caffe2::OperatorDef &op);
102
103 private:
104   mir::Graph *_graph = nullptr;
105
106   template <typename OpType, typename... Types> mir::Operation *createOp(Types &&... args);
107 };
108
109 template <typename OpType, typename... Types>
110 mir::Operation *Caffe2OpCreator::createOp(Types &&... args)
111 {
112   return _graph->create<OpType>(std::forward<Types>(args)...);
113 }
114
115 } // namespace mir_caffe2
116
117 #endif // MIR_CAFFE2_OP_CREATOR_H