From 66386e06a3746a944ed279b0b1688facd17a224a Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=B2=9C=EA=B5=90/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 17 Oct 2019 15:26:13 +0900 Subject: [PATCH] [exo] Export TFLTransposeConv (#8260) This commit introduces export stage of TFLTransposeConv operator Signed-off-by: Cheongyo Bahk --- compiler/exo/src/TFLite/TFLOperationExporter.cpp | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/compiler/exo/src/TFLite/TFLOperationExporter.cpp b/compiler/exo/src/TFLite/TFLOperationExporter.cpp index 27a1dfd..23b2695 100644 --- a/compiler/exo/src/TFLite/TFLOperationExporter.cpp +++ b/compiler/exo/src/TFLite/TFLOperationExporter.cpp @@ -68,6 +68,7 @@ public: void visit(locoex::TFLSub *) final; // TODO TFLTanh void visit(locoex::TFLTranspose *) final; + void visit(locoex::TFLTransposeConv *) final; // FOR canonical nodes. These will be removed later void visit(loco::ReLU *) final; @@ -247,6 +248,27 @@ void OperationExporter::visit(locoex::TFLTranspose *node) gd._operators.push_back(op_offset); } +void OperationExporter::visit(locoex::TFLTransposeConv *node) +{ + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_TRANSPOSE_CONV); + + // Make input, output and options for operator + std::vector inputs_vec{get_tensor_index(node->inputSizes()), + get_tensor_index(node->filter()), + get_tensor_index(node->outBackprop())}; + std::vector outputs_vec{get_tensor_index(static_cast(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + tflite::Padding padding = getOpPadding(node->padding()); + auto options = + CreateTransposeConvOptions(builder, padding, node->stride()->w(), node->stride()->h()); + + // Make TRANSPOSE_CONV operator + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_TransposeConvOptions, options.Union()); + gd._operators.push_back(op_offset); +} + template void OperationExporter::export_pool_2d(TFLPool2D *node, tflite::BuiltinOperator builtin_op) { -- 2.7.4