kTfLiteUInt8 = 3,
kTfLiteInt64 = 4,
kTfLiteString = 5,
+ kTfLiteBool = 6,
} TfLiteType;
// Parameters for asymmetric quantization. Quantized values can be converted
char* raw;
const char* raw_const;
uint8_t* uint8;
+ bool* b;
} TfLitePtrUnion;
// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
case kTfLiteInt64:
*bytes = sizeof(int64_t) * count;
break;
+ case kTfLiteBool:
+ *bytes = sizeof(bool) * count;
+ break;
default:
- ReportError(&context_,
- "Only float32, int32, int64, uint8 supported currently.");
+ ReportError(
+ &context_,
+ "Only float32, int32, int64, uint8, bool supported currently.");
return kTfLiteError;
}
return kTfLiteOk;
constexpr TfLiteType typeToTfLiteType<unsigned char>() {
return kTfLiteUInt8;
}
+// Maps the C++ type `bool` to the runtime type enum kTfLiteBool, mirroring
+// the existing specializations (e.g. unsigned char -> kTfLiteUInt8 above).
+template <>
+constexpr TfLiteType typeToTfLiteType<bool>() {
+ return kTfLiteBool;
+}
// Forward declare since NNAPIDelegate uses Interpreter.
class NNAPIDelegate;
return tensor != nullptr ? tensor->data.i64 : nullptr;
}
+// Returns the bool payload pointer (data.b) of `tensor`, or nullptr when
+// `tensor` is null — same null-safe contract as the int64 specialization
+// directly above.
+template <>
+inline bool* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.b : nullptr;
+}
+
inline int RemapDim(int max_dimensions, int d) {
return max_dimensions - d - 1;
}
case TensorType_STRING:
*type = kTfLiteString;
break;
+ case TensorType_BOOL:
+ *type = kTfLiteBool;
+ break;
default:
error_reporter->Report("Unimplemented data type %s (%d) in tensor\n",
EnumNameTensorType(tensor_type), tensor_type);
return "kTfLiteInt64";
case kTfLiteString:
return "kTfLiteString";
+ case kTfLiteBool:
+ return "kTfLiteBool";
}
return "(invalid)";
}
return NPY_INT64;
case kTfLiteString:
return NPY_OBJECT;
+ case kTfLiteBool:
+ return NPY_BOOL;
case kTfLiteNoType:
return -1;
}
return kTfLiteUInt8;
case NPY_INT64:
return kTfLiteInt64;
+ case NPY_BOOL:
+ return kTfLiteBool;
case NPY_OBJECT:
case NPY_STRING:
case NPY_UNICODE:
UINT8 = 3,
INT64 = 4,
STRING = 5,
+ BOOL = 6,
}
// Parameters for converting a quantized tensor back to float. Given a
TensorType_UINT8 = 3,
TensorType_INT64 = 4,
TensorType_STRING = 5,
+ TensorType_BOOL = 6,
TensorType_MIN = TensorType_FLOAT32,
- TensorType_MAX = TensorType_STRING
+ TensorType_MAX = TensorType_BOOL
};
-inline TensorType (&EnumValuesTensorType())[6] {
+inline TensorType (&EnumValuesTensorType())[7] {
+// NOTE(review): array length bumped 6 -> 7 for TensorType_BOOL; this is
+// flatbuffer-generated code — keep in sync with the TensorType enum and the
+// EnumNamesTensorType() string table.
static TensorType values[] = {
TensorType_FLOAT32,
TensorType_FLOAT16,
TensorType_INT32,
TensorType_UINT8,
TensorType_INT64,
- TensorType_STRING
+ TensorType_STRING,
+ TensorType_BOOL
};
return values;
}
"UINT8",
"INT64",
"STRING",
+ "BOOL",
nullptr
};
return names;
return fields;
}
+// Split() specialization producing bools: each delimiter-separated token is
+// parsed with strtol (base 10), and any nonzero value maps to true. strtol
+// skips leading whitespace, so tokens like " 0" in "1, 0" parse correctly.
+template <>
+inline std::vector<bool> Split(const string& s, const string& delimiter) {
+ std::vector<bool> fields;
+ for (const auto& p : SplitToPos(s, delimiter)) {
+ fields.push_back(
+ static_cast<bool>(strtol(s.data() + p.first, nullptr, 10)));
+ }
+ return fields;
+}
+
} // namespace testing
} // namespace tflite
EXPECT_THAT(Split<uint8_t>("1,-1,258", ","), ElementsAre(1, 255, 2));
}
+// Exercises the bool Split specialization: 1/0 map to true/false, and
+// whitespace after the delimiter is tolerated (strtol skips it).
+TEST(SplitTest, SplitBool) {
+ EXPECT_THAT(Split<bool>("1, 0, 0, 1", ","),
+ ElementsAre(true, false, false, true));
+}
+
} // namespace
} // namespace testing
} // namespace tflite
uint8_t Value(const TfLitePtrUnion& data, int index) {
return data.uint8[index];
}
+// Reads element `index` out of the union's bool buffer (data.b), parallel to
+// the uint8 accessor above. Caller is responsible for index being in range.
+template <>
+bool Value(const TfLitePtrUnion& data, int index) {
+ return data.b[index];
+}
template <typename T>
void SetTensorData(const std::vector<T>& values, TfLitePtrUnion* data) {
return TypedCheck<int64_t>(verbose, tensor);
case kTfLiteUInt8:
return TypedCheck<uint8_t>(verbose, tensor);
+ case kTfLiteBool:
+ return TypedCheck<bool>(verbose, tensor);
default:
fprintf(stderr, "Unsupported type %d in Check\n", tensor.type);
return false;
SetTensorData(values, &tensor->data);
break;
}
+ case kTfLiteBool: {
+ const auto& values = testing::Split<bool>(csv_values, ",");
+ if (!CheckSizes<bool>(tensor->bytes, values.size())) return;
+ SetTensorData(values, &tensor->data);
+ break;
+ }
default:
fprintf(stderr, "Unsupported type %d in SetInput\n", tensor->type);
Invalidate("Unsupported tensor data type");
case kTfLiteUInt8:
expected_output_[id]->SetData<uint8_t>(csv_values);
break;
+ case kTfLiteBool:
+ expected_output_[id]->SetData<bool>(csv_values);
+ break;
default:
fprintf(stderr, "Unsupported type %d in SetExpectation\n", tensor->type);
Invalidate("Unsupported tensor data type");