From 228afd1fca442b89755262b3832b34e483a94539 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=ED=98=84=EC=8B=9D/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 30 Sep 2019 18:07:15 +0900 Subject: [PATCH] [exo-tflite] converting loco::TensorTranspose to locoex::TFLTranspose (#7824) * [exo-tflite] converter to convert loco::TensorTranspose to locoex::TFLTranspose A converter to convert loco::TensorTranspose to locoex::TFLTranspose is introduced. Signed-off-by: Hyun Sik Yoon * adding more validation check + using std::is_permutation(..) --- .../src/Conversion/TensorTransposeConverter.cpp | 100 +++++++++++++++++++++ .../src/Conversion/TensorTransposeConverter.h | 41 +++++++++ 2 files changed, 141 insertions(+) create mode 100644 compiler/exo-tflite/src/Conversion/TensorTransposeConverter.cpp create mode 100644 compiler/exo-tflite/src/Conversion/TensorTransposeConverter.h diff --git a/compiler/exo-tflite/src/Conversion/TensorTransposeConverter.cpp b/compiler/exo-tflite/src/Conversion/TensorTransposeConverter.cpp new file mode 100644 index 0000000..40b5502 --- /dev/null +++ b/compiler/exo-tflite/src/Conversion/TensorTransposeConverter.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TensorTransposeConverter.h" + +#include "Dialect/IR/TFLNodes.h" +#include "Check.h" + +#include +#include + +#include +#include +#include + +namespace +{ + +void validate_perm(loco::TensorTranspose *origin) +{ + // check perm values are correct + std::vector base_perms; // such as {0, 1, 2, 3, ... } + std::vector perms; // perm values in TensorTranspose + + base_perms.resize(origin->perm()->size()); + perms.resize(origin->perm()->size()); + for (loco::TensorAxis x = 0; x < origin->perm()->size(); x++) + { + base_perms[x] = x; + perms[x] = origin->perm()->axis(x); + } + + if (!std::is_permutation(base_perms.begin(), base_perms.end(), perms.begin())) + EXO_THROW("wrong perm value"); +} + +} // namespace + +namespace exo +{ +/** + * @brief Converts loco::TensorTranspose to locoex::TFLTranspose + */ +bool TensorTransposeConverter::convert(loco::TensorTranspose *origin) +{ + auto *graph = origin->graph(); + + auto tfl_transpose = graph->nodes()->create(); + { + // validation + { + assert(origin->input() != nullptr); + + auto input_rank = loco::shape_get(origin->input()).as().rank(); + if (input_rank != origin->perm()->size()) + EXO_THROW("perm size should be same with input rank"); + + validate_perm(origin); + } + + tfl_transpose->a(origin->input()); + + // perm : set TFLConst + auto perm_const = graph->nodes()->create(); + { + perm_const->dtype(loco::DataType::S32); + perm_const->rank(1); + perm_const->dim(0) = origin->perm()->size(); + perm_const->size(origin->perm()->size()); + + // add perm values into perm TFLConst + for (loco::TensorAxis x = 0; x < origin->perm()->size(); x++) + { + perm_const->at(x) = origin->perm()->axis(x); + } + } + tfl_transpose->perm(perm_const); + } + + // replace canonical node + loco::replace(origin).with(tfl_transpose); + origin->input(nullptr); + + return true; +} + +} // namespace exo diff --git a/compiler/exo-tflite/src/Conversion/TensorTransposeConverter.h b/compiler/exo-tflite/src/Conversion/TensorTransposeConverter.h new file mode 100644 index 0000000..9b61ff3 --- /dev/null +++ b/compiler/exo-tflite/src/Conversion/TensorTransposeConverter.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CONVERSION_TENSORTRANSPOSE_CONVERTER__ +#define __CONVERSION_TENSORTRANSPOSE_CONVERTER__ + +#include "CanonicalNodeConverter.h" + +#include + +namespace exo +{ + +/** + * @brief Convert loco::TensorTranspose to locoex::TFLTranspose + */ +class TensorTransposeConverter : public CanonicalNodeConverter +{ +public: + const char *name(void) const final { return "exo::TensorTransposeConverter"; } + +public: + bool convert(loco::TensorTranspose *origin) final; +}; + +} // namespace exo + +#endif // __CONVERSION_TENSORTRANSPOSE_CONVERTER__ -- 2.7.4