From 84b5938aaee991d6909e16e56c66bf88e8843fbb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 11 May 2018 19:31:37 -0700 Subject: [PATCH] Add bool conversion in toco for tflite since bool is supported by tflite. PiperOrigin-RevId: 196339883 --- tensorflow/contrib/lite/toco/tflite/types.cc | 18 ++++++++++++++++++ tensorflow/contrib/lite/toco/tflite/types_test.cc | 15 +++++++++++---- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/contrib/lite/toco/tflite/types.cc index c9c2e9b..4867c3a 100644 --- a/tensorflow/contrib/lite/toco/tflite/types.cc +++ b/tensorflow/contrib/lite/toco/tflite/types.cc @@ -36,6 +36,16 @@ DataBuffer::FlatBufferOffset CopyStringToBuffer( return builder->CreateVector(dst_data.data(), bytes); } +// vector may be implemented using a bit-set, so we can't just +// reinterpret_cast, accesing it data as vector and let flatbuffer +// CreateVector handle it. +// Background: https://isocpp.org/blog/2012/11/on-vectorbool +DataBuffer::FlatBufferOffset CopyBoolToBuffer( + const Array& array, flatbuffers::FlatBufferBuilder* builder) { + const auto& src_data = array.GetBuffer().data; + return builder->CreateVector(src_data); +} + template DataBuffer::FlatBufferOffset CopyBuffer( const Array& array, flatbuffers::FlatBufferBuilder* builder) { @@ -86,6 +96,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) { return ::tflite::TensorType_UINT8; case ArrayDataType::kString: return ::tflite::TensorType_STRING; + case ArrayDataType::kBool: + return ::tflite::TensorType_BOOL; default: // FLOAT32 is filled for unknown data types. // TODO(ycling): Implement type inference in TF Lite interpreter. @@ -105,6 +117,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) { return ArrayDataType::kString; case ::tflite::TensorType_UINT8: return ArrayDataType::kUint8; + case ::tflite::TensorType_BOOL: + return ArrayDataType::kBool; default: LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'."; } @@ -125,6 +139,8 @@ flatbuffers::Offset> DataBuffer::Serialize( return CopyStringToBuffer(array, builder); case ArrayDataType::kUint8: return CopyBuffer(array, builder); + case ArrayDataType::kBool: + return CopyBoolToBuffer(array, builder); default: LOG(FATAL) << "Unhandled array data type."; } @@ -146,6 +162,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor, return CopyStringFromBuffer(buffer, array); case ::tflite::TensorType_UINT8: return CopyBuffer(buffer, array); + case ::tflite::TensorType_BOOL: + return CopyBuffer(buffer, array); default: LOG(FATAL) << "Unhandled tensor type."; } diff --git a/tensorflow/contrib/lite/toco/tflite/types_test.cc b/tensorflow/contrib/lite/toco/tflite/types_test.cc index 29fb0b2..564f303 100644 --- a/tensorflow/contrib/lite/toco/tflite/types_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/types_test.cc @@ -28,8 +28,7 @@ using flatbuffers::Vector; // These are types that exist in TF Mini but don't have a correspondence // in TF Lite. -static const ArrayDataType kUnsupportedTocoTypes[] = {ArrayDataType::kNone, - ArrayDataType::kBool}; +static const ArrayDataType kUnsupportedTocoTypes[] = {ArrayDataType::kNone}; // These are TF Lite types for which there is no correspondence in TF Mini. static const ::tflite::TensorType kUnsupportedTfLiteTypes[] = { @@ -44,7 +43,7 @@ template Array ToFlatBufferAndBack(std::initializer_list<::toco::DataType> items) { // NOTE: This test does not construct the full buffers list. Since // Deserialize normally takes a buffer, we need to synthesize one and provide - // an index that is non-zero so the buffer is not assumed to be emtpy. + // an index that is non-zero so the buffer is not assumed to be empty. Array src; src.data_type = T; src.GetMutableBuffer().data = items; @@ -71,7 +70,8 @@ TEST(DataType, SupportedTypes) { {ArrayDataType::kUint8, ::tflite::TensorType_UINT8}, {ArrayDataType::kInt32, ::tflite::TensorType_INT32}, {ArrayDataType::kInt64, ::tflite::TensorType_INT64}, - {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32}}; + {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32}, + {ArrayDataType::kBool, ::tflite::TensorType_BOOL}}; for (auto x : testdata) { EXPECT_EQ(x.second, DataType::Serialize(x.first)); EXPECT_EQ(x.first, DataType::Deserialize(x.second)); @@ -158,6 +158,13 @@ TEST(DataBuffer, String) { ::testing::ElementsAre("AA", "BBB", "Best. String. Ever.")); } +TEST(DataBuffer, Bool) { + Array recovered = + ToFlatBufferAndBack({true, false, true}); + EXPECT_THAT(recovered.GetBuffer().data, + ::testing::ElementsAre(true, false, true)); +} + TEST(Padding, All) { EXPECT_EQ(::tflite::Padding_SAME, Padding::Serialize(PaddingType::kSame)); EXPECT_EQ(PaddingType::kSame, Padding::Deserialize(::tflite::Padding_SAME)); -- 2.7.4