From 7053d53ebee323f52d5fa52f1e74d9dace7632b2 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Principal=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Fri, 23 Nov 2018 13:06:01 +0900 Subject: [PATCH] [tflchef] Support int32 tensortype (#2388) This will fix to support int32 tensortype with introducing as_tflite_tensortype() Signed-off-by: SaeHie Park --- contrib/tflchef/core/src/Convert.cpp | 15 +++++++++++++++ contrib/tflchef/core/src/Convert.h | 1 + contrib/tflchef/core/src/ModelChef.cpp | 4 +++- contrib/tflchef/proto/tflchef.proto | 1 + 4 files changed, 20 insertions(+), 1 deletion(-) diff --git a/contrib/tflchef/core/src/Convert.cpp b/contrib/tflchef/core/src/Convert.cpp index 7adddef..554eb36 100644 --- a/contrib/tflchef/core/src/Convert.cpp +++ b/contrib/tflchef/core/src/Convert.cpp @@ -49,3 +49,18 @@ tflite::ActivationFunctionType as_tflite_activation(const tflchef::Activation &v throw std::runtime_error{"Unknown activation"}; } + +tflite::TensorType as_tflite_tensortype(const tflchef::TensorType &value) +{ + switch (value) + { + case tflchef::FLOAT32: + return tflite::TensorType_FLOAT32; + case tflchef::INT32: + return tflite::TensorType_INT32; + default: + break; + } + + throw std::runtime_error{"Unknown tensor type"}; +} diff --git a/contrib/tflchef/core/src/Convert.h b/contrib/tflchef/core/src/Convert.h index b4e763e..1685fcb 100644 --- a/contrib/tflchef/core/src/Convert.h +++ b/contrib/tflchef/core/src/Convert.h @@ -26,5 +26,6 @@ tflite::Padding as_tflite_padding(const tflchef::Padding &value); tflite::ActivationFunctionType as_tflite_activation(const tflchef::Activation &value); +tflite::TensorType as_tflite_tensortype(const tflchef::TensorType &value); #endif // __CONVERT_H__ diff --git a/contrib/tflchef/core/src/ModelChef.cpp b/contrib/tflchef/core/src/ModelChef.cpp index 278f225..c6feba1 100644 --- a/contrib/tflchef/core/src/ModelChef.cpp +++ b/contrib/tflchef/core/src/ModelChef.cpp @@ -17,6 +17,8 @@ #include "tflchef/ModelChef.h" #include "Arguments.h" +#include "Convert.h" + #include "DataChef.h" #include "DataChefs.h" @@ -281,7 +283,7 @@ GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe) tflite::TensorBuilder tensor_builder{*flatbuffer_builder}; tensor_builder.add_shape(shape); - tensor_builder.add_type(tflite::TensorType_FLOAT32); + tensor_builder.add_type(as_tflite_tensortype(operand.type())); tensor_builder.add_buffer(buffer_index); tensor_builder.add_name(name); diff --git a/contrib/tflchef/proto/tflchef.proto b/contrib/tflchef/proto/tflchef.proto index 756af50..2a96f8e 100644 --- a/contrib/tflchef/proto/tflchef.proto +++ b/contrib/tflchef/proto/tflchef.proto @@ -5,6 +5,7 @@ package tflchef; // This enum value corresponds to TensorType in TensorFlow Lite schema enum TensorType { FLOAT32 = 0; + INT32 = 2; } message TensorShape { -- 2.7.4