Support Bool in Cast (TFLite)
authorYu-Cheng Ling <ycling@google.com>
Thu, 17 May 2018 21:58:04 +0000 (14:58 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 17 May 2018 22:00:50 +0000 (15:00 -0700)
PiperOrigin-RevId: 197056978

tensorflow/contrib/lite/kernels/cast.cc
tensorflow/contrib/lite/kernels/cast_test.cc

index 673eedc..60770ca 100644 (file)
@@ -69,6 +69,9 @@ TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out,
     case kTfLiteFloat32:
       copyCast(in, out->data.f, num_elements);
       break;
+    case kTfLiteBool:
+      copyCast(in, out->data.b, num_elements);
+      break;
     default:
       // Unsupported type.
       return kTfLiteError;
@@ -90,6 +93,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       return copyToTensor(input->data.uint8, output, num_elements);
     case kTfLiteFloat32:
       return copyToTensor(input->data.f, output, num_elements);
+    case kTfLiteBool:
+      return copyToTensor(input->data.b, output, num_elements);
     default:
       // Unsupported type.
       return kTfLiteError;
index 4e56482..53e2000 100644 (file)
@@ -57,6 +57,22 @@ TEST(CastOpModel, CastFloatToInt) {
               ElementsAreArray({100, 20, 3, 0, 0, 1}));
 }
 
+TEST(CastOpModel, CastFloatToBool) {
+  CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_BOOL, {3, 2}});
+  m.PopulateTensor<float>(m.input(), {100.f, -1.0f, 0.f, 0.4f, 0.999f, 1.1f});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<bool>(m.output()),
+              ElementsAreArray({true, true, false, true, true, true}));
+}
+
+TEST(CastOpModel, CastBoolToFloat) {
+  CastOpModel m({TensorType_BOOL, {3, 2}}, {TensorType_FLOAT32, {3, 2}});
+  m.PopulateTensor<bool>(m.input(), {true, true, false, true, false, true});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<float>(m.output()),
+              ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f}));
+}
+
 }  // namespace
 }  // namespace tflite
 int main(int argc, char** argv) {