Add boolean type to tflite in favor of comparison implementations.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 13 Apr 2018 02:52:18 +0000 (19:52 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 13 Apr 2018 02:54:58 +0000 (19:54 -0700)
PiperOrigin-RevId: 192711203

12 files changed:
tensorflow/contrib/lite/context.h
tensorflow/contrib/lite/interpreter.cc
tensorflow/contrib/lite/interpreter.h
tensorflow/contrib/lite/kernels/internal/tensor.h
tensorflow/contrib/lite/model.cc
tensorflow/contrib/lite/optional_debug_tools.cc
tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
tensorflow/contrib/lite/schema/schema.fbs
tensorflow/contrib/lite/schema/schema_generated.h
tensorflow/contrib/lite/testing/split.h
tensorflow/contrib/lite/testing/split_test.cc
tensorflow/contrib/lite/testing/tflite_driver.cc

index 45184b0..0b38f43 100644 (file)
@@ -137,6 +137,7 @@ typedef enum {
   kTfLiteUInt8 = 3,
   kTfLiteInt64 = 4,
   kTfLiteString = 5,
+  kTfLiteBool = 6,
 } TfLiteType;
 
 // Parameters for asymmetric quantization. Quantized values can be converted
@@ -155,6 +156,7 @@ typedef union {
   char* raw;
   const char* raw_const;
   uint8_t* uint8;
+  bool* b;
 } TfLitePtrUnion;
 
 // Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
index 4575fe8..f258654 100644 (file)
@@ -337,9 +337,13 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims,
     case kTfLiteInt64:
       *bytes = sizeof(int64_t) * count;
       break;
+    case kTfLiteBool:
+      *bytes = sizeof(bool) * count;
+      break;
     default:
-      ReportError(&context_,
-                  "Only float32, int32, int64, uint8 supported currently.");
+      ReportError(
+          &context_,
+          "Only float32, int32, int64, uint8, bool supported currently.");
       return kTfLiteError;
   }
   return kTfLiteOk;
index a6d582a..df67cce 100644 (file)
@@ -48,6 +48,10 @@ template <>
 constexpr TfLiteType typeToTfLiteType<unsigned char>() {
   return kTfLiteUInt8;
 }
+template <>
+constexpr TfLiteType typeToTfLiteType<bool>() {
+  return kTfLiteBool;
+}
 
 // Forward declare since NNAPIDelegate uses Interpreter.
 class NNAPIDelegate;
index 4bce2ff..62cea14 100644 (file)
@@ -44,6 +44,11 @@ inline int64_t* GetTensorData(TfLiteTensor* tensor) {
   return tensor != nullptr ? tensor->data.i64 : nullptr;
 }
 
+template <>
+inline bool* GetTensorData(TfLiteTensor* tensor) {
+  return tensor != nullptr ? tensor->data.b : nullptr;
+}
+
 inline int RemapDim(int max_dimensions, int d) {
   return max_dimensions - d - 1;
 }
index 87af953..0b65884 100644 (file)
@@ -57,6 +57,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
     case TensorType_STRING:
       *type = kTfLiteString;
       break;
+    case TensorType_BOOL:
+      *type = kTfLiteBool;
+      break;
     default:
       error_reporter->Report("Unimplemented data type %s (%d) in tensor\n",
                              EnumNameTensorType(tensor_type), tensor_type);
index 1f762e6..e136663 100644 (file)
@@ -48,6 +48,8 @@ const char* TensorTypeName(TfLiteType type) {
       return "kTfLiteInt64";
     case kTfLiteString:
       return "kTfLiteString";
+    case kTfLiteBool:
+      return "kTfLiteBool";
   }
   return "(invalid)";
 }
index 4b34969..04fc098 100644 (file)
@@ -72,6 +72,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
       return NPY_INT64;
     case kTfLiteString:
       return NPY_OBJECT;
+    case kTfLiteBool:
+      return NPY_BOOL;
     case kTfLiteNoType:
       return -1;
   }
@@ -90,6 +92,8 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
       return kTfLiteUInt8;
     case NPY_INT64:
       return kTfLiteInt64;
+    case NPY_BOOL:
+      return kTfLiteBool;
     case NPY_OBJECT:
     case NPY_STRING:
     case NPY_UNICODE:
index 3574937..fa82550 100644 (file)
@@ -33,6 +33,7 @@ enum TensorType : byte {
   UINT8 = 3,
   INT64 = 4,
   STRING = 5,
+  BOOL = 6,
 }
 
 // Parameters for converting a quantized tensor back to float. Given a
index c638daf..909c4cc 100755 (executable)
@@ -173,18 +173,20 @@ enum TensorType {
   TensorType_UINT8 = 3,
   TensorType_INT64 = 4,
   TensorType_STRING = 5,
+  TensorType_BOOL = 6,
   TensorType_MIN = TensorType_FLOAT32,
-  TensorType_MAX = TensorType_STRING
+  TensorType_MAX = TensorType_BOOL
 };
 
-inline TensorType (&EnumValuesTensorType())[6] {
+inline TensorType (&EnumValuesTensorType())[7] {
   static TensorType values[] = {
     TensorType_FLOAT32,
     TensorType_FLOAT16,
     TensorType_INT32,
     TensorType_UINT8,
     TensorType_INT64,
-    TensorType_STRING
+    TensorType_STRING,
+    TensorType_BOOL
   };
   return values;
 }
@@ -197,6 +199,7 @@ inline const char **EnumNamesTensorType() {
     "UINT8",
     "INT64",
     "STRING",
+    "BOOL",
     nullptr
   };
   return names;
index 428cfda..896f294 100644 (file)
@@ -80,6 +80,16 @@ inline std::vector<uint8_t> Split(const string& s, const string& delimiter) {
   return fields;
 }
 
+template <>
+inline std::vector<bool> Split(const string& s, const string& delimiter) {
+  std::vector<bool> fields;
+  for (const auto& p : SplitToPos(s, delimiter)) {
+    fields.push_back(
+        static_cast<bool>(strtol(s.data() + p.first, nullptr, 10)));
+  }
+  return fields;
+}
+
 }  // namespace testing
 }  // namespace tflite
 
index 3d1e25d..76b918c 100644 (file)
@@ -52,6 +52,11 @@ TEST(SplitTest, SplitUint8) {
   EXPECT_THAT(Split<uint8_t>("1,-1,258", ","), ElementsAre(1, 255, 2));
 }
 
+TEST(SplitTest, SplitBool) {
+  EXPECT_THAT(Split<bool>("1, 0, 0, 1", ","),
+              ElementsAre(true, false, false, true));
+}
+
 }  // namespace
 }  // namespace testing
 }  // namespace tflite
index 3764bab..58fe5bd 100644 (file)
@@ -42,6 +42,10 @@ template <>
 uint8_t Value(const TfLitePtrUnion& data, int index) {
   return data.uint8[index];
 }
+template <>
+bool Value(const TfLitePtrUnion& data, int index) {
+  return data.b[index];
+}
 
 template <typename T>
 void SetTensorData(const std::vector<T>& values, TfLitePtrUnion* data) {
@@ -79,6 +83,8 @@ class TfLiteDriver::Expectation {
         return TypedCheck<int64_t>(verbose, tensor);
       case kTfLiteUInt8:
         return TypedCheck<uint8_t>(verbose, tensor);
+      case kTfLiteBool:
+        return TypedCheck<bool>(verbose, tensor);
       default:
         fprintf(stderr, "Unsupported type %d in Check\n", tensor.type);
         return false;
@@ -203,6 +209,12 @@ void TfLiteDriver::SetInput(int id, const string& csv_values) {
       SetTensorData(values, &tensor->data);
       break;
     }
+    case kTfLiteBool: {
+      const auto& values = testing::Split<bool>(csv_values, ",");
+      if (!CheckSizes<bool>(tensor->bytes, values.size())) return;
+      SetTensorData(values, &tensor->data);
+      break;
+    }
     default:
       fprintf(stderr, "Unsupported type %d in SetInput\n", tensor->type);
       Invalidate("Unsupported tensor data type");
@@ -231,6 +243,9 @@ void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
     case kTfLiteUInt8:
       expected_output_[id]->SetData<uint8_t>(csv_values);
       break;
+    case kTfLiteBool:
+      expected_output_[id]->SetData<bool>(csv_values);
+      break;
     default:
       fprintf(stderr, "Unsupported type %d in SetExpectation\n", tensor->type);
       Invalidate("Unsupported tensor data type");