From ebacb1c85b4bbff789cf1fd246b6fdeabb3db9c8 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=ED=98=84=EC=8B=9D/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 16 Oct 2019 15:51:09 +0900 Subject: [PATCH] [exo] Relu6Converter and TFL/CircleOperationExporter for TFLRelu6 (#8173) This adds Relu6Converter and TFL/CircleOperationExporter for TFLRelu6. Signed-off-by: Hyun Sik Yoon --- .../exo/src/Circle/CircleOperationExporter.cpp | 13 ++++- compiler/exo/src/Conversion/Relu6Converter.cpp | 65 ++++++++++++++++++++++ compiler/exo/src/Conversion/Relu6Converter.h | 41 ++++++++++++++ compiler/exo/src/TFLite/TFLOperationExporter.cpp | 13 ++++- 4 files changed, 128 insertions(+), 4 deletions(-) create mode 100644 compiler/exo/src/Conversion/Relu6Converter.cpp create mode 100644 compiler/exo/src/Conversion/Relu6Converter.h diff --git a/compiler/exo/src/Circle/CircleOperationExporter.cpp b/compiler/exo/src/Circle/CircleOperationExporter.cpp index f116443..79f61e1 100644 --- a/compiler/exo/src/Circle/CircleOperationExporter.cpp +++ b/compiler/exo/src/Circle/CircleOperationExporter.cpp @@ -61,7 +61,7 @@ public: void visit(locoex::TFLMaxPool2D *) final; void visit(locoex::TFLMul *) final; void visit(locoex::TFLRelu *) final; - // TODO TFLRelu6 + void visit(locoex::TFLRelu6 *) final; // TODO TFLReshape // TODO TFLSoftmax // TODO TFLSqrt @@ -199,7 +199,16 @@ void OperationExporter::visit(locoex::TFLRelu *node) gd._operators.push_back(op_offset); } -// TODO TFLRelu6 +void OperationExporter::visit(locoex::TFLRelu6 *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(circle::BuiltinOperator_RELU6); + std::vector inputs_vec{get_tensor_index(node->features())}; + std::vector outputs_vec{get_tensor_index(static_cast(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} // TODO TFLReshape diff --git a/compiler/exo/src/Conversion/Relu6Converter.cpp b/compiler/exo/src/Conversion/Relu6Converter.cpp new file mode 100644 index 0000000..8691ddb --- /dev/null +++ b/compiler/exo/src/Conversion/Relu6Converter.cpp @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Relu6Converter.h" + +#include "GraphBlock.h" +#include "Check.h" + +#include "Dialect/IR/TFLNodes.h" + +#include + +namespace exo +{ + +bool Relu6Converter::convert(loco::ReLU6 *origin) +{ + if (!loco::shape_known(origin)) + { + return false; + } + + if (loco::shape_get(origin).domain() == loco::Domain::Tensor) + { + auto tfl_relu6 = origin->graph()->nodes()->create(); + tfl_relu6->features(origin->input()); + + loco::replace(origin).with(tfl_relu6); + origin->input(nullptr); + + return true; + } + else if (loco::shape_get(origin).domain() == loco::Domain::Feature) + { + auto graph = origin->graph(); + auto dec = make_feature_decode(origin->input()); + auto tfl_relu6 = graph->nodes()->create(); + { + tfl_relu6->features(dec); + } + auto enc = make_feature_encode(tfl_relu6); + + loco::replace(origin).with(enc); + origin->input(nullptr); + + return true; + } + else + EXO_THROW("Not yet supported loco::Domain"); +} + +} // namespace exo diff --git a/compiler/exo/src/Conversion/Relu6Converter.h b/compiler/exo/src/Conversion/Relu6Converter.h new file mode 100644 index 0000000..d987b42 --- /dev/null +++ b/compiler/exo/src/Conversion/Relu6Converter.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __CONVERSION_RELU6_CONVERTER_H__ +#define __CONVERSION_RELU6_CONVERTER_H__ + +#include "CanonicalNodeConverter.h" + +#include + +namespace exo +{ + +/** + * @brief Convert loco::Relu6 to TFLRelu6 + */ +class Relu6Converter : public CanonicalNodeConverter +{ +public: + const char *name(void) const final { return "exo::Relu6Converter"; } + +public: + bool convert(loco::ReLU6 *origin) final; +}; + +} // namespace exo + +#endif // __CONVERSION_RELU6_CONVERTER_H__ diff --git a/compiler/exo/src/TFLite/TFLOperationExporter.cpp b/compiler/exo/src/TFLite/TFLOperationExporter.cpp index bc4b564..27a1dfd 100644 --- a/compiler/exo/src/TFLite/TFLOperationExporter.cpp +++ b/compiler/exo/src/TFLite/TFLOperationExporter.cpp @@ -61,7 +61,7 @@ public: void visit(locoex::TFLMaxPool2D *) final; void visit(locoex::TFLMul *) final; void visit(locoex::TFLRelu *) final; - // TODO TFLRelu6 + void visit(locoex::TFLRelu6 *) final; // TODO TFLReshape // TODO TFLSoftmax // TODO TFLSqrt @@ -199,7 +199,16 @@ void OperationExporter::visit(locoex::TFLRelu *node) gd._operators.push_back(op_offset); } -// TODO TFLRelu6 +void OperationExporter::visit(locoex::TFLRelu6 *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RELU6); + std::vector inputs_vec{get_tensor_index(node->features())}; + std::vector outputs_vec{get_tensor_index(static_cast(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs); + gd._operators.push_back(op_offset); +} // TODO TFLReshape -- 2.7.4