From 220b2e716b0b2eac79af0a9aa582695faa434b9d Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=ED=98=84=EC=8B=9D/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 2 Sep 2019 18:06:39 +0900 Subject: [PATCH] [exo-tflite] functions to build node (#7080) This adds functions to build featureEncode and featureDecode, which are frequently used. Signed-off-by: Hyun Sik Yoon --- compiler/exo-tflite/src/Conversion/GraphBlock.cpp | 86 +++++++++++++++++++++++ compiler/exo-tflite/src/Conversion/GraphBlock.h | 39 ++++++++++ 2 files changed, 125 insertions(+) create mode 100644 compiler/exo-tflite/src/Conversion/GraphBlock.cpp create mode 100644 compiler/exo-tflite/src/Conversion/GraphBlock.h diff --git a/compiler/exo-tflite/src/Conversion/GraphBlock.cpp b/compiler/exo-tflite/src/Conversion/GraphBlock.cpp new file mode 100644 index 0000000..e38b1d1 --- /dev/null +++ b/compiler/exo-tflite/src/Conversion/GraphBlock.cpp @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "GraphBlock.h" + +#include "Check.h" + +#include +#include + +namespace +{ + +template loco::Permutation perm(); + +template <> loco::Permutation perm() +{ + // Make NHWC permutation for encoder and decoder + loco::Permutation NHWC; + + NHWC.axis(loco::FeatureAxis::Count) = 0; + NHWC.axis(loco::FeatureAxis::Height) = 1; + NHWC.axis(loco::FeatureAxis::Width) = 2; + NHWC.axis(loco::FeatureAxis::Depth) = 3; + + return NHWC; +} + +} // namespace + +namespace exo +{ + +template loco::FeatureEncode *make_feature_encode(loco::Node *input_for_encode) +{ + EXO_ASSERT(input_for_encode != nullptr, "input should not be nullptr"); + loco::Graph *g = input_for_encode->graph(); + + auto encoder = stdex::make_unique>(); + + encoder->perm(perm()); + + auto enc = g->nodes()->create(); + enc->input(input_for_encode); + enc->encoder(std::move(encoder)); + + return enc; +} + +template loco::FeatureDecode *make_feature_decode(loco::Node *input_for_decode) +{ + EXO_ASSERT(input_for_decode != nullptr, "input should not be nullptr"); + loco::Graph *g = input_for_decode->graph(); + + auto decoder = stdex::make_unique>(); + + decoder->perm(perm()); + + auto dec = g->nodes()->create(); + dec->input(input_for_decode); + dec->decoder(std::move(decoder)); + + return dec; +} + +// template instantiation +template loco::FeatureEncode * +make_feature_encode(loco::Node *input_for_encode); + +template loco::FeatureDecode * +make_feature_decode(loco::Node *input_for_encode); + +} // namespace exo diff --git a/compiler/exo-tflite/src/Conversion/GraphBlock.h b/compiler/exo-tflite/src/Conversion/GraphBlock.h new file mode 100644 index 0000000..d6dd575 --- /dev/null +++ b/compiler/exo-tflite/src/Conversion/GraphBlock.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CONVERSION_GRAPH_BLOCK_H__ +#define __CONVERSION_GRAPH_BLOCK_H__ + +#include + +namespace exo +{ + +/// @brief default layout of TFLITE file +enum class DefaultLayout +{ + NHWC, +}; + +/// @brief Creates a loco::FeatureEncode of default layout (NHWC for tflite) and add it to graph. +template loco::FeatureEncode *make_feature_encode(loco::Node *input_for_encode); + +/// @brief Create a loco::FeatureDecode of default layout (NHWC for tflite) and add it to graph. +template loco::FeatureDecode *make_feature_decode(loco::Node *input_for_decode); + +} // namespace exo + +#endif //__CONVERSION_GRAPH_BLOCK_H__ -- 2.7.4