From 5e93a0795b57a22c4c08424fb10fe3ca7e182707 Mon Sep 17 00:00:00 2001
From: Bixia Zheng <bixia@google.com>
Date: Tue, 3 Apr 2018 10:47:37 -0700
Subject: [PATCH] [TF:XLA] Add DT_HALF to the supported data types for the CPU
 and GPU devices.

Add a test case to compilation_passes_test for DT_HALF data type.

PiperOrigin-RevId: 191464288
---
 .../compiler/jit/mark_for_compilation_pass_test.cc | 22 +++++++++++++++++++++-
 tensorflow/compiler/jit/xla_cpu_device.cc          |  4 ++--
 tensorflow/compiler/jit/xla_gpu_device.cc          |  4 ++--
 tensorflow/compiler/tf2xla/xla_op_registry.h       |  8 ++++----
 4 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 381c020..af19192 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -138,7 +138,7 @@ TEST(XlaCompilationTest, CompilableCycles) {
   EXPECT_EQ(clusters["A"], clusters["C"]);
 }
 
-TEST(XlaCompilationTest, UnsupportedTypes) {
+TEST(XlaCompilationTest, Complex128Unsupported) {
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
   GraphDef graphdef;
   {
@@ -158,6 +158,26 @@ TEST(XlaCompilationTest, UnsupportedTypes) {
   EXPECT_TRUE(clusters.empty());
 }
 
+TEST(XlaCompilationTest, HalfSupported) {
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  GraphDef graphdef;
+  {
+    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+    Node* a = ops::SourceOp(
+        "Const", builder.opts()
+                     .WithName("A")
+                     .WithAttr("dtype", DT_HALF)
+                     .WithAttr("value", Tensor(DT_HALF, TensorShape())));
+    Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
+    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
+    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
+  }
+
+  TF_ASSERT_OK(MarkForCompilation(&graph));
+  auto clusters = GetClusters(*graph);
+  EXPECT_FALSE(clusters.empty());
+}
+
 TEST(XlaCompilationTest, ConcatWithConstArg) {
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
   GraphDef graphdef;
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index d2dfdee..bc07dbd 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -62,8 +62,8 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
 
 // Kernel registrations
 
-constexpr std::array<DataType, 6> kAllXlaCpuTypes = {
-    {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
+constexpr std::array<DataType, 7> kAllXlaCpuTypes = {
+    {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
 
 REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
 REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index 5a1db81..84c0d5f 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -62,8 +62,8 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
 
 // Kernel registrations
 
-constexpr std::array<DataType, 6> kAllXlaGpuTypes = {
-    {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
+constexpr std::array<DataType, 7> kAllXlaGpuTypes = {
+    {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
 
 REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
 REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index ff74531..da2a6c3 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -51,12 +51,12 @@ constexpr std::array<DataType, 9> kNumericTypes = {
     {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
      DT_COMPLEX64, DT_BFLOAT16}};
 
-constexpr std::array<DataType, 8> kCpuAllTypes = {
-    {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
+constexpr std::array<DataType, 9> kCpuAllTypes = {
+    {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
      DT_COMPLEX64, DT_BOOL}};
 
-constexpr std::array<DataType, 8> kGpuAllTypes = {
-    {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
+constexpr std::array<DataType, 9> kGpuAllTypes = {
+    {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
      DT_COMPLEX64, DT_BOOL}};
 
 // Class that manages registrations of operators and devices for the XLA JIT.
-- 
2.7.4