2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #ifndef __GRAPH_BLOCK_H__
18 #define __GRAPH_BLOCK_H__
21 #include <loco/Service/ShapeInference.h>
23 #include <oops/InternalExn.h>
27 // TODO Change all Canonical nodes to Circle nodes
// Only FeatureLayout::NHWC is accepted by DomainConverter::convert (enforced by a static_assert there).
32 /// @brief Feature layout used by TFLite/Circle files.
33 enum class FeatureLayout
38 /// @brief Creates a loco::FeatureEncode with layout T (NHWC for TFLite) and adds it to the graph.
/// @param input_for_encode node to be fed into the new FeatureEncode
/// @return the created FeatureEncode node
39 template <FeatureLayout T> loco::FeatureEncode *make_feature_encode(loco::Node *input_for_encode);
41 /// @brief Creates a loco::FeatureDecode with layout T (NHWC for TFLite) and adds it to the graph.
/// @param input_for_decode node to be fed into the new FeatureDecode
/// @return the created FeatureDecode node
42 template <FeatureLayout T> loco::FeatureDecode *make_feature_decode(loco::Node *input_for_decode);
/// @brief Filter layouts selectable for make_filter_encode / make_filter_decode.
44 enum class FilterLayout
46 OHWI, // a.k.a. NHWC; TensorFlow Lite uses this layout for filters
47 HWIO, // a.k.a. HWCN; TensorFlow uses this layout for filters
50 /// @brief Creates a loco::FilterEncode with the given layout T.
/// @param input_for_encode node to be fed into the new FilterEncode
/// @return the created FilterEncode node
51 template <FilterLayout T> loco::FilterEncode *make_filter_encode(loco::Node *input_for_encode);
53 /// @brief Creates a loco::FilterDecode with the given layout T.
/// @param input_for_decode node to be fed into the new FilterDecode
/// @return the created FilterDecode node
54 template <FilterLayout T> loco::FilterDecode *make_filter_decode(loco::Node *input_for_decode);
/// @brief Depthwise filter layouts selectable for make_dw_filter_decode.
/// (Only a decode helper is declared for depthwise filters.)
56 enum class DepthwiseFilterLayout
61 /// @brief Creates a loco::DepthwiseFilterDecode with the given layout T.
/// @param input_for_decode node to be fed into the new DepthwiseFilterDecode
/// @return the created DepthwiseFilterDecode node
62 template <DepthwiseFilterLayout T>
63 loco::DepthwiseFilterDecode *make_dw_filter_decode(loco::Node *input_for_decode);
/// @brief Matrix layouts selectable for make_matrix_encode / make_matrix_decode.
65 enum class MatrixLayout
71 /// @brief Creates a loco::MatrixEncode with the given layout T.
/// @param input_for_encode node to be fed into the new MatrixEncode
/// @return the created MatrixEncode node
72 template <MatrixLayout T> loco::MatrixEncode *make_matrix_encode(loco::Node *input_for_encode);
74 /// @brief Creates a loco::MatrixDecode with the given layout T.
/// @param input_for_decode node to be fed into the new MatrixDecode
/// @return the created MatrixDecode node
75 template <MatrixLayout T> loco::MatrixDecode *make_matrix_decode(loco::Node *input_for_decode);
84 * Some canonical nodes can have inputs of various loco::Domain, e.g., loco::Domain::Tensor,
85 * loco::Domain::Feature, etc. However, a TFL node accepts only loco::Domain::Tensor.
86 * So, when converting such a canonical node to a TFL node and the input(s) of the canonical node are not
87 * loco::Domain::Tensor, additional nodes need to be inserted.
89 * The following two classes help with this insertion.
91 * For example, in the case of loco::ReLU conversion,
95 * A (output: feature) -- loco::ReLU --- B (input:feature)
99 * A -- loco::FeatureDecode -- locoex::TFLRelu -- loco::FeatureEncode --- B
101 * loco::ReLU (dead node)
108 * @brief Handles input(s) while converting a canonical node to TFL node(s).
109 * This class tells DomainConverter how to handle the inputs of a specific canonical node.
 * All members are pure virtual; callers pass a concrete subclass to DomainConverter::convert.
 * NOTE(review): no virtual destructor is visible here; add one if handlers are ever
 * deleted through an InputHandler pointer.
111 template <class CanonicalT, class TFLT> class InputHandler
115 * @brief Assign origin's inputs to replacer's inputs.
116 * (This is called when origin belongs to the Tensor domain.)
118 virtual void handover(CanonicalT *origin, TFLT *replacer) = 0;
121 * @brief Returns the inputs of origin that must each be wrapped by a FeatureDecode
 * (each returned node becomes the input of a newly created FeatureDecode).
122 * (This is called when origin belongs to the Feature domain.)
124 virtual std::vector<loco::Node *> getInputsToConvert(CanonicalT *origin) = 0;
126 /// @brief Set the inputs of replacer to new_inputs
/// (In the Feature-domain path, new_inputs holds the FeatureDecode nodes, in the same
/// order as the nodes returned by getInputsToConvert.)
127 virtual void set(TFLT *replacer, std::vector<loco::Node *> &new_inputs) = 0;
129 /// @brief Set the inputs of origin to nullptr
/// (Called on origin after it has been replaced, in both the Tensor and Feature paths.)
130 virtual void nullify(CanonicalT *origin) = 0;
134 * @brief Class to handle domain conversion while converting a canonical node to TFL node(s)
 * @tparam CanonicalT canonical loco node type being replaced
 * @tparam TFLT TFL node type created in its place
136 template <class CanonicalT, class TFLT> class DomainConverter
/// @brief Replace origin with a new TFLT node, inserting FeatureDecode/FeatureEncode nodes
/// when origin is in the Feature domain. FeatureLayoutT must be FeatureLayout::NHWC.
139 template <FeatureLayout FeatureLayoutT>
140 TFLT *convert(CanonicalT *origin, InputHandler<CanonicalT, TFLT> &input_handler);
144 * @brief Performs domain conversion
 *
 * Creates a new TFLT node in origin's graph, then:
146 * 1. if origin belongs to loco::Domain::Tensor, replace origin with the TFL node.
147 * 2. if origin belongs to loco::Domain::Feature, insert a loco::FeatureDecode for each input
148 * and a loco::FeatureEncode for the output, then replace origin with that FeatureEncode.
 * Any other domain is reported via INTERNAL_EXN_V.
 * In both supported cases input_handler.nullify(origin) is called after the replacement.
150 * @return new TFL node; nullptr if shape of origin cannot be known
 *
 * NOTE(review): the TFL node is created before the domain is checked and does not appear
 * to be removed on the unsupported-domain path — confirm whether this matters.
152 template <class CanonicalT, class TFLT>
153 template <FeatureLayout FeatureLayoutT>
154 TFLT *DomainConverter<CanonicalT, TFLT>::convert(CanonicalT *origin,
155 InputHandler<CanonicalT, TFLT> &input_handler)
// Only the NHWC feature layout is supported.
157 static_assert(FeatureLayoutT == FeatureLayout::NHWC, "Feature layout should be NHWC");
// The domain is read from origin's shape below, so the shape must be known first.
159 if (!loco::shape_known(origin))
164 auto tfl_node = origin->graph()->nodes()->template create<TFLT>();
166 // When origin's shape is in the Tensor domain, just replace the canonical node with the TFL node.
167 if (loco::shape_get(origin).domain() == loco::Domain::Tensor)
169 input_handler.handover(origin, tfl_node);
171 loco::replace(origin).with(tfl_node);
172 input_handler.nullify(origin);
176 else if (loco::shape_get(origin).domain() == loco::Domain::Feature)
178 std::vector<loco::Node *> feature_decodes;
// Decode each Feature-domain input to Tensor so the TFL node can consume it.
180 for (auto input : input_handler.getInputsToConvert(origin))
182 auto dec = make_feature_decode<FeatureLayoutT>(input);
183 feature_decodes.emplace_back(dec);
186 input_handler.set(tfl_node, feature_decodes);
// Re-encode the TFL node's output so origin's users still receive Feature-domain data.
188 auto enc = make_feature_encode<FeatureLayoutT>(tfl_node);
190 loco::replace(origin).with(enc);
191 input_handler.nullify(origin);
196 INTERNAL_EXN_V("Unsupported loco::Domain", oops::to_uint32(loco::shape_get(origin).domain()));
201 #endif //__GRAPH_BLOCK_H__