From 7fcb86028c8911a7c46e1fc6ce2fb2266527d3ca Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Principal=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Tue, 27 Nov 2018 13:05:44 +0900 Subject: [PATCH] [tflchef/rev] Write reshape explicit values (#2418) This will change tflchef-reverse to provide explicit values of shape information for the operand used by reshape operator Signed-off-by: SaeHie Park --- contrib/tflchef/tests/CMakeLists.txt | 1 + contrib/tflchef/tflite/src/Op/Reshape.cpp | 12 +++++++++++- contrib/tflchef/tflite/src/RecipeChef.cpp | 13 +++++++++++++ contrib/tflchef/tflite/src/TFliteImport.h | 23 +++++++++++++++++++++++ 4 files changed, 48 insertions(+), 1 deletion(-) diff --git a/contrib/tflchef/tests/CMakeLists.txt b/contrib/tflchef/tests/CMakeLists.txt index c20be34..f6a9ffc 100644 --- a/contrib/tflchef/tests/CMakeLists.txt +++ b/contrib/tflchef/tests/CMakeLists.txt @@ -41,6 +41,7 @@ list(APPEND GEN_TFLITEFILES "maxpool2d/test.recipe") list(APPEND GEN_TFLITEFILES "relu/test.recipe") list(APPEND GEN_TFLITEFILES "relu6/test.recipe") list(APPEND GEN_TFLITEFILES "reshape/test.recipe") +list(APPEND GEN_TFLITEFILES "explicit_datachef/test.recipe") # TODO if all operates are added we can just use one line of "*/test.recipe" foreach(TFLITEFILE IN ITEMS ${GEN_TFLITEFILES}) diff --git a/contrib/tflchef/tflite/src/Op/Reshape.cpp b/contrib/tflchef/tflite/src/Op/Reshape.cpp index 5bb403e..663ab3e 100644 --- a/contrib/tflchef/tflite/src/Op/Reshape.cpp +++ b/contrib/tflchef/tflite/src/Op/Reshape.cpp @@ -24,7 +24,17 @@ namespace tflchef void TFliteOpReshape::filler(const tflite::Operator *op, TFliteImport *import, tflchef::ModelRecipe *model_recipe) const { - // Nothing to do with filler + const std::vector &inputs = as_index_vector(op->inputs()); + + bool hasShape = (inputs.size() == 2); + assert(inputs.size() == 1 || hasShape); + + if (hasShape) + { + auto op_params = op->builtin_options_as_ReshapeOptions(); + std::vector new_shape = as_index_vector(op_params->new_shape()); + import->set_tensor_filler(inputs.at(1), new_shape); + } } tflchef::Operation *TFliteOpReshape::build(const tflite::Operator *op, TFliteImport *import, diff --git a/contrib/tflchef/tflite/src/RecipeChef.cpp b/contrib/tflchef/tflite/src/RecipeChef.cpp index 4c6553f..836c12e 100644 --- a/contrib/tflchef/tflite/src/RecipeChef.cpp +++ b/contrib/tflchef/tflite/src/RecipeChef.cpp @@ -23,6 +23,7 @@ #include "TFliteOpRegistry.h" #include +#include namespace tflchef { @@ -113,6 +114,7 @@ std::unique_ptr generate_recipe(const tflite::Model *model) } // filler for weights, bias and so on + std::vector expvalues; if (tflite_import.get_tensor_filler(i)) { tflchef::TensorFiller *filler = operand->mutable_filler(); @@ -121,6 +123,17 @@ std::unique_ptr generate_recipe(const tflite::Model *model) filler->add_arg("0.0"); // average filler->add_arg("0.1"); // standard deviation } + else if (tflite_import.get_tensor_filler(i, expvalues)) + { + tflchef::TensorFiller *filler = operand->mutable_filler(); + filler->set_tag("explicit"); + for (auto value : expvalues) + { + std::ostringstream ss; + ss << value; + filler->add_arg(ss.str()); + } + } } // add all operators diff --git a/contrib/tflchef/tflite/src/TFliteImport.h b/contrib/tflchef/tflite/src/TFliteImport.h index 6d55517..ade8fc8 100644 --- a/contrib/tflchef/tflite/src/TFliteImport.h +++ b/contrib/tflchef/tflite/src/TFliteImport.h @@ -70,6 +70,14 @@ public: void set_tensor_filler(uint32_t tensor_index) { _tensor_filler[tensor_index] = true; } /** + * @brief This will store int32 filler values such as reshape information for the tensor + */ + void set_tensor_filler(uint32_t tensor_index, std::vector &expvalues) + { + _tensor_filler_vint32[tensor_index] = expvalues; + } + + /** * @brief This will return true if the tensor by index, needs a filler option. */ bool get_tensor_filler(uint32_t tensor_index) @@ -82,6 +90,20 @@ public: return false; } + /** + * @brief This will return true if the tensor by index, needs a int array filler option. + */ + bool get_tensor_filler(uint32_t tensor_index, std::vector &expvalues) + { + auto it = _tensor_filler_vint32.find(tensor_index); + if (it != _tensor_filler_vint32.end()) + { + expvalues = it->second; + return true; + } + return false; + } + private: const TFliteSubGraphs_t *_subgraphs; const TFliteBuffers_t *_buffers; @@ -93,6 +115,7 @@ private: std::vector _outputs; std::map _tensor_filler; + std::map> _tensor_filler_vint32; }; } // namespace tflchef -- 2.7.4