From f02a557a7ff9cbbf173862639c73f566aa663f7d Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=B2=9C=EA=B5=90/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 18 Jun 2019 12:50:17 +0900 Subject: [PATCH] [loco exporter] Refactor export Transpose (#3835) FeatureEncode, FeatureDecode are now exported as TRANSPOSE in tflite when they are not identity(NHWC). In future, FilterEncode will be as well. This commit refactors exporting TRANSPOSE in a form reusable by FilterEncode. Signed-off-by: Cheongyo Bahk --- contrib/loco-exporter/src/OperationExporter.cpp | 46 ++++++++++--------------- 1 file changed, 19 insertions(+), 27 deletions(-) diff --git a/contrib/loco-exporter/src/OperationExporter.cpp b/contrib/loco-exporter/src/OperationExporter.cpp index b49e20d..57e2dc3 100644 --- a/contrib/loco-exporter/src/OperationExporter.cpp +++ b/contrib/loco-exporter/src/OperationExporter.cpp @@ -88,38 +88,18 @@ void exportIdentity(NodeT *node, FlatBufferBuilder &builder, SerializedModelData gd._operators.push_back(op_offset); } -/// @brief Export Feature Codec nodes as TRANSPOSE -void exportFeatureTranspose(loco::Node *node, FlatBufferBuilder &builder, - loco::Permutation *perm, bool inverted, - SerializedModelData &gd) +/// @brief Export loco nodes as TRANSPOSE +void exportAsTranspose(loco::Node *node, FlatBufferBuilder &builder, + std::vector &perm_vec_data, SerializedModelData &gd) { uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_TRANSPOSE); auto options = CreateTransposeOptions(builder); // Create constant tensor with perm vector - auto perm_vec_shape_offset = builder.CreateVector(std::vector{4}); - constexpr int perm_vec_size = 4; - std::vector perm_vec_data(perm_vec_size); - - using loco::FeatureAxis; - - if (!inverted) - { - perm_vec_data[0] = perm->axis(FeatureAxis::Count); - perm_vec_data[1] = perm->axis(FeatureAxis::Height); - perm_vec_data[2] = perm->axis(FeatureAxis::Width); - perm_vec_data[3] = perm->axis(FeatureAxis::Depth); - } - else - { - perm_vec_data[perm->axis(FeatureAxis::Count)] = 0; - perm_vec_data[perm->axis(FeatureAxis::Height)] = 1; - perm_vec_data[perm->axis(FeatureAxis::Width)] = 2; - perm_vec_data[perm->axis(FeatureAxis::Depth)] = 3; - } - + assert(perm_vec_data.size() == perm_vec_size); + auto perm_vec_shape_offset = builder.CreateVector(std::vector{perm_vec_size}); constexpr size_t raw_perm_vec_size = perm_vec_size * sizeof(int32_t); auto perm_vec_offset = @@ -165,7 +145,13 @@ void exportFeatureEncode(loco::FeatureEncode *node, FlatBufferBuilder &builder, } else { - exportFeatureTranspose(node, builder, perm, /*inverted*/ false, gd); + std::vector perm_vec_data(4); + perm_vec_data[0] = perm->axis(loco::FeatureAxis::Count); + perm_vec_data[1] = perm->axis(loco::FeatureAxis::Height); + perm_vec_data[2] = perm->axis(loco::FeatureAxis::Width); + perm_vec_data[3] = perm->axis(loco::FeatureAxis::Depth); + + exportAsTranspose(node, builder, perm_vec_data, gd); } } @@ -182,7 +168,13 @@ void exportFeatureDecode(loco::FeatureDecode *node, FlatBufferBuilder &builder, } else { - exportFeatureTranspose(node, builder, perm, /*inverted*/ true, gd); + std::vector perm_vec_data(4); + perm_vec_data[perm->axis(loco::FeatureAxis::Count)] = 0; + perm_vec_data[perm->axis(loco::FeatureAxis::Height)] = 1; + perm_vec_data[perm->axis(loco::FeatureAxis::Width)] = 2; + perm_vec_data[perm->axis(loco::FeatureAxis::Depth)] = 3; + + exportAsTranspose(node, builder, perm_vec_data, gd); } } -- 2.7.4