From 3c9870524b86fe7e3cff5a49daa692cd52e7f0c4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 12 Apr 2018 19:52:18 -0700 Subject: [PATCH] Add boolean type to tflite in favor of comparison implementations. PiperOrigin-RevId: 192711203 --- tensorflow/contrib/lite/context.h | 2 ++ tensorflow/contrib/lite/interpreter.cc | 8 ++++++-- tensorflow/contrib/lite/interpreter.h | 4 ++++ tensorflow/contrib/lite/kernels/internal/tensor.h | 5 +++++ tensorflow/contrib/lite/model.cc | 3 +++ tensorflow/contrib/lite/optional_debug_tools.cc | 2 ++ .../python/interpreter_wrapper/interpreter_wrapper.cc | 4 ++++ tensorflow/contrib/lite/schema/schema.fbs | 1 + tensorflow/contrib/lite/schema/schema_generated.h | 9 ++++++--- tensorflow/contrib/lite/testing/split.h | 10 ++++++++++ tensorflow/contrib/lite/testing/split_test.cc | 5 +++++ tensorflow/contrib/lite/testing/tflite_driver.cc | 15 +++++++++++++++ 12 files changed, 63 insertions(+), 5 deletions(-) diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index 45184b0..0b38f43 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -137,6 +137,7 @@ typedef enum { kTfLiteUInt8 = 3, kTfLiteInt64 = 4, kTfLiteString = 5, + kTfLiteBool = 6, } TfLiteType; // Parameters for asymmetric quantization. Quantized values can be converted @@ -155,6 +156,7 @@ typedef union { char* raw; const char* raw_const; uint8_t* uint8; + bool* b; } TfLitePtrUnion; // Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 4575fe8..f258654 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -337,9 +337,13 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, case kTfLiteInt64: *bytes = sizeof(int64_t) * count; break; + case kTfLiteBool: + *bytes = sizeof(bool) * count; + break; default: - ReportError(&context_, - "Only float32, int32, int64, uint8 supported currently."); + ReportError( + &context_, + "Only float32, int32, int64, uint8, bool supported currently."); return kTfLiteError; } return kTfLiteOk; diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index a6d582a..df67cce 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -48,6 +48,10 @@ template <> constexpr TfLiteType typeToTfLiteType() { return kTfLiteUInt8; } +template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteBool; +} // Forward declare since NNAPIDelegate uses Interpreter. class NNAPIDelegate; diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h index 4bce2ff..62cea14 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -44,6 +44,11 @@ inline int64_t* GetTensorData(TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.i64 : nullptr; } +template <> +inline bool* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.b : nullptr; +} + inline int RemapDim(int max_dimensions, int d) { return max_dimensions - d - 1; } diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 87af953..0b65884 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -57,6 +57,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, case TensorType_STRING: *type = kTfLiteString; break; + case TensorType_BOOL: + *type = kTfLiteBool; + break; default: error_reporter->Report("Unimplemented data type %s (%d) in tensor\n", EnumNameTensorType(tensor_type), tensor_type); diff --git a/tensorflow/contrib/lite/optional_debug_tools.cc b/tensorflow/contrib/lite/optional_debug_tools.cc index 1f762e6..e136663 100644 --- a/tensorflow/contrib/lite/optional_debug_tools.cc +++ b/tensorflow/contrib/lite/optional_debug_tools.cc @@ -48,6 +48,8 @@ const char* TensorTypeName(TfLiteType type) { return "kTfLiteInt64"; case kTfLiteString: return "kTfLiteString"; + case kTfLiteBool: + return "kTfLiteBool"; } return "(invalid)"; } diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 4b34969..04fc098 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -72,6 +72,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { return NPY_INT64; case kTfLiteString: return NPY_OBJECT; + case kTfLiteBool: + return NPY_BOOL; case kTfLiteNoType: return -1; } @@ -90,6 +92,8 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { return kTfLiteUInt8; case NPY_INT64: return kTfLiteInt64; + case NPY_BOOL: + return kTfLiteBool; case NPY_OBJECT: case NPY_STRING: case NPY_UNICODE: diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 3574937..fa82550 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -33,6 +33,7 @@ enum TensorType : byte { UINT8 = 3, INT64 = 4, STRING = 5, + BOOL = 6, } // Parameters for converting a quantized tensor back to float. Given a diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index c638daf..909c4cc 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -173,18 +173,20 @@ enum TensorType { TensorType_UINT8 = 3, TensorType_INT64 = 4, TensorType_STRING = 5, + TensorType_BOOL = 6, TensorType_MIN = TensorType_FLOAT32, - TensorType_MAX = TensorType_STRING + TensorType_MAX = TensorType_BOOL }; -inline TensorType (&EnumValuesTensorType())[6] { +inline TensorType (&EnumValuesTensorType())[7] { static TensorType values[] = { TensorType_FLOAT32, TensorType_FLOAT16, TensorType_INT32, TensorType_UINT8, TensorType_INT64, - TensorType_STRING + TensorType_STRING, + TensorType_BOOL }; return values; } @@ -197,6 +199,7 @@ inline const char **EnumNamesTensorType() { "UINT8", "INT64", "STRING", + "BOOL", nullptr }; return names; diff --git a/tensorflow/contrib/lite/testing/split.h b/tensorflow/contrib/lite/testing/split.h index 428cfda..896f294 100644 --- a/tensorflow/contrib/lite/testing/split.h +++ b/tensorflow/contrib/lite/testing/split.h @@ -80,6 +80,16 @@ inline std::vector Split(const string& s, const string& delimiter) { return fields; } +template <> +inline std::vector Split(const string& s, const string& delimiter) { + std::vector fields; + for (const auto& p : SplitToPos(s, delimiter)) { + fields.push_back( + static_cast(strtol(s.data() + p.first, nullptr, 10))); + } + return fields; +} + } // namespace testing } // namespace tflite diff --git a/tensorflow/contrib/lite/testing/split_test.cc b/tensorflow/contrib/lite/testing/split_test.cc index 3d1e25d..76b918c 100644 --- a/tensorflow/contrib/lite/testing/split_test.cc +++ b/tensorflow/contrib/lite/testing/split_test.cc @@ -52,6 +52,11 @@ TEST(SplitTest, SplitUint8) { EXPECT_THAT(Split("1,-1,258", ","), ElementsAre(1, 255, 2)); } +TEST(SplitTest, SplitBool) { + EXPECT_THAT(Split("1, 0, 0, 1", ","), + ElementsAre(true, false, false, true)); +} + } // namespace } // namespace testing } // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc index 3764bab..58fe5bd 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.cc +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -42,6 +42,10 @@ template <> uint8_t Value(const TfLitePtrUnion& data, int index) { return data.uint8[index]; } +template <> +bool Value(const TfLitePtrUnion& data, int index) { + return data.b[index]; +} template void SetTensorData(const std::vector& values, TfLitePtrUnion* data) { @@ -79,6 +83,8 @@ class TfLiteDriver::Expectation { return TypedCheck(verbose, tensor); case kTfLiteUInt8: return TypedCheck(verbose, tensor); + case kTfLiteBool: + return TypedCheck(verbose, tensor); default: fprintf(stderr, "Unsupported type %d in Check\n", tensor.type); return false; @@ -203,6 +209,12 @@ void TfLiteDriver::SetInput(int id, const string& csv_values) { SetTensorData(values, &tensor->data); break; } + case kTfLiteBool: { + const auto& values = testing::Split(csv_values, ","); + if (!CheckSizes(tensor->bytes, values.size())) return; + SetTensorData(values, &tensor->data); + break; + } default: fprintf(stderr, "Unsupported type %d in SetInput\n", tensor->type); Invalidate("Unsupported tensor data type"); @@ -231,6 +243,9 @@ void TfLiteDriver::SetExpectation(int id, const string& csv_values) { case kTfLiteUInt8: expected_output_[id]->SetData(csv_values); break; + case kTfLiteBool: + expected_output_[id]->SetData(csv_values); + break; default: fprintf(stderr, "Unsupported type %d in SetExpectation\n", tensor->type); Invalidate("Unsupported tensor data type"); -- 2.7.4