Imported Upstream version 1.4.0
[platform/core/ml/nnfw.git] / compiler / luci / service / src / GraphBlock.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __GRAPH_BLOCK_H__
18 #define __GRAPH_BLOCK_H__
19
20 #include <loco.h>
21 #include <loco/Service/ShapeInference.h>
22
23 #include <oops/InternalExn.h>
24
25 #include <functional>
26
27 // TODO Change all Canonical nodes to Circle nodes
28
29 namespace luci
30 {
31
32 /// @brief feature layout of TFlite/Circle file
33 enum class FeatureLayout
34 {
35   NHWC,
36 };
37
38 /// @brief Creates a loco::FeatureEncode with T layout (NHWC for tflite) and add it to graph.
39 template <FeatureLayout T> loco::FeatureEncode *make_feature_encode(loco::Node *input_for_encode);
40
41 /// @brief Creates a loco::FeatureDecode with T layout (NHWC for tflite) and add it to graph.
42 template <FeatureLayout T> loco::FeatureDecode *make_feature_decode(loco::Node *input_for_decode);
43
44 enum class FilterLayout
45 {
46   OHWI, // a.k.a., NHWC, Tensorflow Lite uses this layout for filter
47   HWIO, // a.k.a., HWCN, Tensorflow uses this layout for filter
48 };
49
50 /// @brief Create a loco::FilterEncode of given layout
51 template <FilterLayout T> loco::FilterEncode *make_filter_encode(loco::Node *input_for_encode);
52
53 /// @brief Create a loco::FilterDecode of given layout
54 template <FilterLayout T> loco::FilterDecode *make_filter_decode(loco::Node *input_for_decode);
55
56 enum class DepthwiseFilterLayout
57 {
58   HWCM,
59 };
60
61 /// @brief Create a loco::DepthwiseFilterDecode of given layout
62 template <DepthwiseFilterLayout T>
63 loco::DepthwiseFilterDecode *make_dw_filter_decode(loco::Node *input_for_decode);
64
65 enum class MatrixLayout
66 {
67   HW,
68   WH
69 };
70
71 /// @brief Create a loco::MatrixEncode of given layout
72 template <MatrixLayout T> loco::MatrixEncode *make_matrix_encode(loco::Node *input_for_encode);
73
74 /// @brief Create a loco::MatrixDecode of given layout
75 template <MatrixLayout T> loco::MatrixDecode *make_matrix_decode(loco::Node *input_for_decode);
76
77 } // luci
78
79 //
80 // DomainConverter
81 //
82
83 /**
84  * Some canonical nodes can have input of various loco::Domain, e.g., loco::Domain::Tensor,
85  * loco::Domain::Feature, etc. However, TFL node accepts only loco::Domain::Tensor.
86  * So, When converting such canonical node to TFL node and input(s) of a canonical node are not
87  * loco::Domain::Tensor, additional nodes need to be inserted.
88  *
89  * The following two classes helps this insertion.
90  *
91  * For example, in case of loco::Relu conversion,
92  *
93  * Before:
94  *
95  *    A (output: feature) -- loco::ReLU --- B (input:feature)
96  *
97  * After:
98  *
99  *    A -- loco::FeatureDecode -- locoex::TFLRelu -- loco::FeatureEncode --- B
100  *
101  *                  loco::ReLU (dead node)
102  */
103
104 namespace luci
105 {
106
107 /**
108  * @brief Handles input(s) while converting a canonical node to TFL node(s).
109  *        This class informs DomainConverter how to handle inputs of a specific canonical node.
110  */
111 template <class CanonicalT, class TFLT> class InputHandler
112 {
113 public:
114   /**
115    * @brief Assign origin's inputs to replacer's inputs.
116    *        (This is called when origin belongs in Tensor domain.)
117    */
118   virtual void handover(CanonicalT *origin, TFLT *replacer) = 0;
119
120   /**
121    * @brief Returns the list of inputs that needs to have FeatureDecode as its input.
122    *        (This is called when origin belongs in Feature domain.)
123    */
124   virtual std::vector<loco::Node *> getInputsToConvert(CanonicalT *origin) = 0;
125
126   /// @brief Set the inputs of replacer to new_inputs
127   virtual void set(TFLT *replacer, std::vector<loco::Node *> &new_inputs) = 0;
128
129   /// @brief Set the inputs to nullptr
130   virtual void nullify(CanonicalT *origin) = 0;
131 };
132
133 /**
134  * @brief Class to handle domain conversion while converting a canonical node to TFL node(s)
135  */
136 template <class CanonicalT, class TFLT> class DomainConverter
137 {
138 public:
139   template <FeatureLayout FeatureLayoutT>
140   TFLT *convert(CanonicalT *origin, InputHandler<CanonicalT, TFLT> &input_handler);
141 };
142
143 /**
144  * @brief Performs domain conversion
145  *
146  * 1. if origin belong to loco::Domain::Tensor, and replace origin to a TFL node.
147  * 2. if origin belong to loco::Domain::Feature, insert loco::FeatureDecode for input(s) and
148  *    insert loco::FeatureEncode for output. Then replace origin to a TFL node.
149  *
150  * @return new TFL node; nullptr if shape of origin cannot be known
151  */
152 template <class CanonicalT, class TFLT>
153 template <FeatureLayout FeatureLayoutT>
154 TFLT *DomainConverter<CanonicalT, TFLT>::convert(CanonicalT *origin,
155                                                  InputHandler<CanonicalT, TFLT> &input_handler)
156 {
157   static_assert(FeatureLayoutT == FeatureLayout::NHWC, "Feature layout should be NHWC");
158
159   if (!loco::shape_known(origin))
160   {
161     return nullptr;
162   }
163
164   auto tfl_node = origin->graph()->nodes()->template create<TFLT>();
165
166   // when the input is Tensor, just replace canonical node to TFL node.
167   if (loco::shape_get(origin).domain() == loco::Domain::Tensor)
168   {
169     input_handler.handover(origin, tfl_node);
170
171     loco::replace(origin).with(tfl_node);
172     input_handler.nullify(origin);
173
174     return tfl_node;
175   }
176   else if (loco::shape_get(origin).domain() == loco::Domain::Feature)
177   {
178     std::vector<loco::Node *> feature_decodes;
179
180     for (auto input : input_handler.getInputsToConvert(origin))
181     {
182       auto dec = make_feature_decode<FeatureLayoutT>(input);
183       feature_decodes.emplace_back(dec);
184     }
185
186     input_handler.set(tfl_node, feature_decodes);
187
188     auto enc = make_feature_encode<FeatureLayoutT>(tfl_node);
189
190     loco::replace(origin).with(enc);
191     input_handler.nullify(origin);
192
193     return tfl_node;
194   }
195   else
196     INTERNAL_EXN_V("Unsupported loco::Domain", oops::to_uint32(loco::shape_get(origin).domain()));
197 }
198
199 } // namespace luci
200
201 #endif //__GRAPH_BLOCK_H__