From facd8f50733a398cc0ee08dfe76ad6b4f9e61817 Mon Sep 17 00:00:00 2001 From: Yu-Cheng Ling Date: Thu, 17 May 2018 14:58:04 -0700 Subject: [PATCH] Support Bool in Cast (TFLite) PiperOrigin-RevId: 197056978 --- tensorflow/contrib/lite/kernels/cast.cc | 5 +++++ tensorflow/contrib/lite/kernels/cast_test.cc | 16 ++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc index 673eedc..60770ca 100644 --- a/tensorflow/contrib/lite/kernels/cast.cc +++ b/tensorflow/contrib/lite/kernels/cast.cc @@ -69,6 +69,9 @@ TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out, case kTfLiteFloat32: copyCast(in, out->data.f, num_elements); break; + case kTfLiteBool: + copyCast(in, out->data.b, num_elements); + break; default: // Unsupported type. return kTfLiteError; @@ -90,6 +93,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return copyToTensor(input->data.uint8, output, num_elements); case kTfLiteFloat32: return copyToTensor(input->data.f, output, num_elements); + case kTfLiteBool: + return copyToTensor(input->data.b, output, num_elements); default: // Unsupported type. return kTfLiteError; diff --git a/tensorflow/contrib/lite/kernels/cast_test.cc b/tensorflow/contrib/lite/kernels/cast_test.cc index 4e56482..53e2000 100644 --- a/tensorflow/contrib/lite/kernels/cast_test.cc +++ b/tensorflow/contrib/lite/kernels/cast_test.cc @@ -57,6 +57,22 @@ TEST(CastOpModel, CastFloatToInt) { ElementsAreArray({100, 20, 3, 0, 0, 1})); } +TEST(CastOpModel, CastFloatToBool) { + CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_BOOL, {3, 2}}); + m.PopulateTensor(m.input(), {100.f, -1.0f, 0.f, 0.4f, 0.999f, 1.1f}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray({true, true, false, true, true, true})); +} + +TEST(CastOpModel, CastBoolToFloat) { + CastOpModel m({TensorType_BOOL, {3, 2}}, {TensorType_FLOAT32, {3, 2}}); + m.PopulateTensor(m.input(), {true, true, false, true, false, true}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f})); +} + } // namespace } // namespace tflite int main(int argc, char** argv) { -- 2.7.4