case kTfLiteFloat32:
copyCast(in, out->data.f, num_elements);
break;
+ case kTfLiteBool:
+ copyCast(in, out->data.b, num_elements);
+ break;
default:
// Unsupported type.
return kTfLiteError;
return copyToTensor(input->data.uint8, output, num_elements);
case kTfLiteFloat32:
return copyToTensor(input->data.f, output, num_elements);
+ case kTfLiteBool:
+ return copyToTensor(input->data.b, output, num_elements);
default:
// Unsupported type.
return kTfLiteError;
ElementsAreArray({100, 20, 3, 0, 0, 1}));
}
+TEST(CastOpModel, CastFloatToBool) {
+ CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_BOOL, {3, 2}});
+ m.PopulateTensor<float>(m.input(), {100.f, -1.0f, 0.f, 0.4f, 0.999f, 1.1f});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<bool>(m.output()),
+ ElementsAreArray({true, true, false, true, true, true}));
+}
+
+TEST(CastOpModel, CastBoolToFloat) {
+ CastOpModel m({TensorType_BOOL, {3, 2}}, {TensorType_FLOAT32, {3, 2}});
+ m.PopulateTensor<bool>(m.input(), {true, true, false, true, false, true});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f}));
+}
+
} // namespace
} // namespace tflite
int main(int argc, char** argv) {