Enable half precision convolution for the CPU and GPU backends.
authorBixia Zheng <bixia@google.com>
Thu, 15 Feb 2018 18:39:04 +0000 (10:39 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 15 Feb 2018 18:42:55 +0000 (10:42 -0800)
Enhance the CPU IR emitter to support F16 dot operation and convolution
operation.
Add a CPU runtime implementation for F16 convolution.
Enhance the GPU backend to handle F16 convolution thunk.
Convert some F32 xla convolution tests to support both F32 and F16 and disable
the tests for the CPU backend due to b/72509305.

PiperOrigin-RevId: 185862438

25 files changed:
tensorflow/compiler/xla/array.h
tensorflow/compiler/xla/array2d.h
tensorflow/compiler/xla/array2d_test.cc
tensorflow/compiler/xla/array3d.h
tensorflow/compiler/xla/array3d_test.cc
tensorflow/compiler/xla/array4d.h
tensorflow/compiler/xla/array4d_test.cc
tensorflow/compiler/xla/array_test.cc
tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
tensorflow/compiler/xla/service/cpu/cpu_runtime.h
tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
tensorflow/compiler/xla/service/cpu/ir_emitter.cc
tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc
tensorflow/compiler/xla/service/cpu/runtime_conv2d.h
tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h
tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc
tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h
tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
tensorflow/compiler/xla/service/hlo_evaluator.cc
tensorflow/compiler/xla/tests/convolution_test.cc
tensorflow/compiler/xla/tests/test_macros.h

index 71aa057cd3a1c273c0e851497a78f94ba37c778e..46ee4e64c9ae7ca111d9d04bedcb74ff02a42386 100644 (file)
@@ -121,6 +121,23 @@ class Array {
     CHECK(idx == num_elements());
   }
 
+  // Creates a 2D array of Eigen::half from the given nested initializer list of
+  // float values.
+  template <typename T2, typename = typename std::enable_if<
+                             std::is_same<T, Eigen::half>::value &&
+                             std::is_same<T2, float>::value>::type>
+  Array(std::initializer_list<std::initializer_list<T2>> values)
+      : Array(ToInt64Vector({values.size(), values.begin()->size()})) {
+    int64 idx = 0;
+    for (const auto& it1 : values) {
+      for (const auto& it2 : it1) {
+        values_[idx] = static_cast<T>(it2);
+        ++idx;
+      }
+    }
+    CHECK(idx == num_elements());
+  }
+
   // Creates a 3D array from the given nested initializer list. The outer
   // initializer list is the first dimension, and so on.
   Array(InitializerList3D values)
@@ -138,6 +155,27 @@ class Array {
     CHECK(idx == num_elements());
   }
 
+  // Creates a 3D array of Eigen::half from the given nested initializer list of
+  // float values.
+  template <typename T2, typename = typename std::enable_if<
+                             std::is_same<T, Eigen::half>::value &&
+                             std::is_same<T2, float>::value>::type>
+  Array(std::initializer_list<std::initializer_list<std::initializer_list<T2>>>
+            values)
+      : Array(ToInt64Vector({values.size(), values.begin()->size(),
+                             values.begin()->begin()->size()})) {
+    int64 idx = 0;
+    for (const auto& it1 : values) {
+      for (const auto& it2 : it1) {
+        for (const auto& it3 : it2) {
+          values_[idx] = static_cast<T>(it3);
+          ++idx;
+        }
+      }
+    }
+    CHECK(idx == num_elements());
+  }
+
   // Creates a 4D array from the given nested initializer list. The outer
   // initializer list is the first dimension, and so on.
   Array(InitializerList4D values)
@@ -158,6 +196,31 @@ class Array {
     CHECK(idx == num_elements());
   }
 
+  // Creates a 4D array of Eigen::half from the given nested initializer list of
+  // float values.
+  template <typename T2, typename = typename std::enable_if<
+                             std::is_same<T, Eigen::half>::value &&
+                             std::is_same<T2, float>::value>::type>
+  Array(std::initializer_list<
+        std::initializer_list<std::initializer_list<std::initializer_list<T2>>>>
+            values)
+      : Array(ToInt64Vector({values.size(), values.begin()->size(),
+                             values.begin()->begin()->size(),
+                             values.begin()->begin()->begin()->size()})) {
+    int64 idx = 0;
+    for (const auto& it1 : values) {
+      for (const auto& it2 : it1) {
+        for (const auto& it3 : it2) {
+          for (const auto& it4 : it3) {
+            values_[idx] = static_cast<T>(it4);
+            ++idx;
+          }
+        }
+      }
+    }
+    CHECK(idx == num_elements());
+  }
+
   Array(const Array<T>& other)
       : sizes_(other.sizes_), values_(new T[num_elements()]) {
     std::copy(&other.values_[0], &other.values_[0] + num_elements(),
@@ -185,7 +248,7 @@ class Array {
   // Fills the array with the sequence i*multiplier for i=0,1,...
   void FillWithMultiples(const T& multiplier) {
     for (int64 i = 0; i < num_elements(); ++i) {
-      values_[i] = i * multiplier;
+      values_[i] = static_cast<T>(i) * multiplier;
     }
   }
 
index bb85fbee9b97fd6b9b0bf7223a9b820989dcbfa7..41f563486d21e42e88dcf6c751ce4a64da5e3213 100644 (file)
@@ -52,6 +52,14 @@ class Array2D : public Array<T> {
   Array2D(std::initializer_list<std::initializer_list<T>> values)
       : Array<T>(values) {}
 
+  // Creates an array of Eigen::half from the given nested initializer list of
+  // float values.
+  template <typename T2, typename = typename std::enable_if<
+                             std::is_same<T, Eigen::half>::value &&
+                             std::is_same<T2, float>::value>::type>
+  Array2D(std::initializer_list<std::initializer_list<T2>> values)
+      : Array<T>(values) {}
+
   Array2D(const Array2D<T>& other) : Array<T>(other) {}
 
   int64 n1() const { return this->dim(0); }
index c08e42c20ee684dfad8268aa8223440fbfad8a33..93034a719bfbd6724c007059715754677f3f1e62 100644 (file)
@@ -63,6 +63,20 @@ TEST(Array2dTest, InitializerListCtor) {
   EXPECT_EQ(arr(1, 2), 6);
 }
 
+TEST(Array2dTest, InitializerListCtorHalf) {
+  Array2D<Eigen::half> arr = {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}};
+
+  EXPECT_EQ(arr.n1(), 2);
+  EXPECT_EQ(arr.n2(), 3);
+
+  EXPECT_EQ(arr(0, 0), static_cast<Eigen::half>(1));
+  EXPECT_EQ(arr(0, 1), static_cast<Eigen::half>(2));
+  EXPECT_EQ(arr(0, 2), static_cast<Eigen::half>(3));
+  EXPECT_EQ(arr(1, 0), static_cast<Eigen::half>(4));
+  EXPECT_EQ(arr(1, 1), static_cast<Eigen::half>(5));
+  EXPECT_EQ(arr(1, 2), static_cast<Eigen::half>(6));
+}
+
 TEST(Array2dTest, Accessors) {
   Array2D<int> arr = {{1, 2, 3}, {4, 5, 6}};
 
index a1c5840a5f3874e27043c821ed4684da2fa6c542..e5eb235d45d160d486d1499db665ed14a8509043 100644 (file)
@@ -57,6 +57,16 @@ class Array3D : public Array<T> {
               values)
       : Array<T>(values) {}
 
+  // Creates an array of Eigen::half from the given nested initializer list of
+  // float values.
+  template <typename T2, typename = typename std::enable_if<
+                             std::is_same<T, Eigen::half>::value &&
+                             std::is_same<T2, float>::value>::type>
+  Array3D(
+      std::initializer_list<std::initializer_list<std::initializer_list<T2>>>
+          values)
+      : Array<T>(values) {}
+
   int64 n1() const { return this->dim(0); }
   int64 n2() const { return this->dim(1); }
   int64 n3() const { return this->dim(2); }
index 6b5f4b343b2113652758bbd5ce0fc803239c1266..691ff6c03594a98a12e0fdd2151c4c2a2c9c128a 100644 (file)
@@ -69,6 +69,29 @@ TEST(Array3dTest, InitializerListCtor) {
   EXPECT_EQ(arr(2, 3, 1), 24);
 }
 
+TEST(Array3dTest, InitializerListCtorHalf) {
+  Array3D<Eigen::half> arr = {
+      {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}, {7.0f, 8.0f}},
+      {{9.0f, 10.0f}, {11.0f, 12.0f}, {13.0f, 14.0f}, {15.0f, 16.0f}},
+      {{17.0f, 18.0f}, {19.0f, 20.0f}, {21.0f, 22.0f}, {23.0f, 24.0f}}};
+
+  EXPECT_EQ(arr.n1(), 3);
+  EXPECT_EQ(arr.n2(), 4);
+  EXPECT_EQ(arr.n3(), 2);
+  EXPECT_EQ(arr.num_elements(), 24);
+
+  EXPECT_EQ(arr(0, 0, 0), static_cast<Eigen::half>(1));
+  EXPECT_EQ(arr(0, 0, 1), static_cast<Eigen::half>(2));
+  EXPECT_EQ(arr(0, 1, 0), static_cast<Eigen::half>(3));
+  EXPECT_EQ(arr(0, 3, 1), static_cast<Eigen::half>(8));
+  EXPECT_EQ(arr(1, 0, 0), static_cast<Eigen::half>(9));
+  EXPECT_EQ(arr(1, 1, 1), static_cast<Eigen::half>(12));
+  EXPECT_EQ(arr(2, 0, 0), static_cast<Eigen::half>(17));
+  EXPECT_EQ(arr(2, 1, 1), static_cast<Eigen::half>(20));
+  EXPECT_EQ(arr(2, 2, 0), static_cast<Eigen::half>(21));
+  EXPECT_EQ(arr(2, 3, 1), static_cast<Eigen::half>(24));
+}
+
 TEST(Array3dTest, Fill) {
   Array3D<int> fullof7(2, 3, 4, 7);
   for (int64 n1 = 0; n1 < fullof7.n1(); ++n1) {
index f8b2b2afe5fed9c465c2a1f39308b7f44311b16a..cff70e54bad0116bdd08674b626b3bf99dc89e1f 100644 (file)
@@ -82,6 +82,16 @@ class Array4D : public Array<T> {
               values)
       : Array<T>(values) {}
 
+  // Creates an array of Eigen::half from the given nested initializer list of
+  // float values.
+  template <typename T2, typename = typename std::enable_if<
+                             std::is_same<T, Eigen::half>::value &&
+                             std::is_same<T2, float>::value>::type>
+  Array4D(std::initializer_list<std::initializer_list<
+              std::initializer_list<std::initializer_list<T2>>>>
+              values)
+      : Array<T>(values) {}
+
   // Numerically-named aliases for the various dimensions. This matches the
   // dimension names used in array3d.
   int64 n4() const { return this->dim(3); }
index 3bc8148c911df0aeade364e4ac2e2ee828bacb53..927733ea1eab43feff643c35535cc6d9ea59ba5a 100644 (file)
@@ -97,6 +97,36 @@ TEST(Array3dTest, InitializerListCtor) {
   EXPECT_EQ(arr(2, 3, 1, 0), 24);
 }
 
+TEST(Array3dTest, InitializerListCtorHalf) {
+  Array4D<Eigen::half> arr = {
+      {{{1.0f}, {2.0f}}, {{3.0f}, {4.0f}}, {{5.0f}, {6.0f}}, {{7.0f}, {8.0f}}},
+      {{{9.0f}, {10.0f}},
+       {{11.0f}, {12.0f}},
+       {{13.0f}, {14.0f}},
+       {{15.0f}, {16.0f}}},
+      {{{17.0f}, {18.0f}},
+       {{19.0f}, {20.0f}},
+       {{21.0f}, {22.0f}},
+       {{23.0f}, {24.0f}}}};
+
+  EXPECT_EQ(arr.n1(), 3);
+  EXPECT_EQ(arr.n2(), 4);
+  EXPECT_EQ(arr.n3(), 2);
+  EXPECT_EQ(arr.n4(), 1);
+  EXPECT_EQ(arr.num_elements(), 24);
+
+  EXPECT_EQ(arr(0, 0, 0, 0), static_cast<Eigen::half>(1));
+  EXPECT_EQ(arr(0, 0, 1, 0), static_cast<Eigen::half>(2));
+  EXPECT_EQ(arr(0, 1, 0, 0), static_cast<Eigen::half>(3));
+  EXPECT_EQ(arr(0, 3, 1, 0), static_cast<Eigen::half>(8));
+  EXPECT_EQ(arr(1, 0, 0, 0), static_cast<Eigen::half>(9));
+  EXPECT_EQ(arr(1, 1, 1, 0), static_cast<Eigen::half>(12));
+  EXPECT_EQ(arr(2, 0, 0, 0), static_cast<Eigen::half>(17));
+  EXPECT_EQ(arr(2, 1, 1, 0), static_cast<Eigen::half>(20));
+  EXPECT_EQ(arr(2, 2, 0, 0), static_cast<Eigen::half>(21));
+  EXPECT_EQ(arr(2, 3, 1, 0), static_cast<Eigen::half>(24));
+}
+
 TEST(Array4dTest, Fill) {
   Array4D<int> fullof7(2, 3, 4, 5, 7);
   fullof7.Each([](tensorflow::gtl::ArraySlice<int64> idx, int* cell) {
index 8b9419477479d952126fd831eb44899e7649ca71..e8356c9832d34135f5ffb1a5c7a9d6db6db3a051 100644 (file)
@@ -60,6 +60,25 @@ TEST(ArrayTest, InitializerListCtor) {
   EXPECT_EQ(arr(1, 2), 6);
 }
 
+TEST(ArrayTest, InitializerListCtorHalf) {
+  Array<Eigen::half> d2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
+  EXPECT_EQ(d2.dim(0), 2);
+  EXPECT_EQ(d2.dim(1), 3);
+
+  Array<Eigen::half> d3({{{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}});
+  EXPECT_EQ(d3.dim(0), 3);
+  EXPECT_EQ(d3.dim(1), 2);
+  EXPECT_EQ(d3.dim(2), 1);
+
+  Array<Eigen::half> d4(
+      {{{{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}},
+       {{{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}}});
+  EXPECT_EQ(d4.dim(0), 2);
+  EXPECT_EQ(d4.dim(1), 3);
+  EXPECT_EQ(d4.dim(2), 2);
+  EXPECT_EQ(d4.dim(3), 1);
+}
+
 TEST(ArrayTest, IndexingReadWrite) {
   Array<int> arr({2, 3});
 
index 1ef45dbec39a0880ebb123ba3fcd1fd6c89eb39a..40ace963270e8cead47cc731cc326351178dff7d 100644 (file)
@@ -35,6 +35,8 @@ extern const char* const kEigenMatMulF32SymbolName =
     "__xla_cpu_runtime_EigenMatMulF32";
 extern const char* const kEigenMatMulF64SymbolName =
     "__xla_cpu_runtime_EigenMatMulF64";
+extern const char* const kEigenConvF16SymbolName =
+    "__xla_cpu_runtime_EigenConvF16";
 extern const char* const kEigenConvF32SymbolName =
     "__xla_cpu_runtime_EigenConvF32";
 extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft";
@@ -42,6 +44,8 @@ extern const char* const kEigenSingleThreadedMatMulF32SymbolName =
     "__xla_cpu_runtime_EigenSingleThreadedMatMulF32";
 extern const char* const kEigenSingleThreadedMatMulF64SymbolName =
     "__xla_cpu_runtime_EigenSingleThreadedMatMulF64";
+extern const char* const kEigenSingleThreadedConvF16SymbolName =
+    "__xla_cpu_runtime_EigenSingleThreadedConvF16";
 extern const char* const kEigenSingleThreadedConvF32SymbolName =
     "__xla_cpu_runtime_EigenSingleThreadedConvF32";
 extern const char* const kAcquireInfeedBufferForDequeueSymbolName =
index 3e1f08071119c938619d02777513e5b834077118..2141dfe1cedd6f9674acc348152574b4fd30895b 100644 (file)
@@ -43,10 +43,12 @@ namespace runtime {
 //    because it is a symbol in the cpu_runtime library.
 extern const char* const kEigenMatMulF32SymbolName;
 extern const char* const kEigenMatMulF64SymbolName;
+extern const char* const kEigenConvF16SymbolName;
 extern const char* const kEigenConvF32SymbolName;
 extern const char* const kEigenFftSymbolName;
 extern const char* const kEigenSingleThreadedMatMulF32SymbolName;
 extern const char* const kEigenSingleThreadedMatMulF64SymbolName;
+extern const char* const kEigenSingleThreadedConvF16SymbolName;
 extern const char* const kEigenSingleThreadedConvF32SymbolName;
 extern const char* const kAcquireInfeedBufferForDequeueSymbolName;
 extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName;
index c9fc586b9a4c06eb9e1f111d8f9bd2f717990aab..cfe7c9c3af0be109ac8a86753e880e2bcbceba41 100644 (file)
@@ -549,7 +549,7 @@ DotOpEmitter::DotOpEmitter(
     const HloModuleConfig& hlo_module_config,
     const TargetMachineFeatures& target_machine_features) {
   PrimitiveType type = target_array.GetShape().element_type();
-  TF_RET_CHECK(F32 == type || F64 == type || C64 == type);
+  TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type);
   DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array,
                            lhs_array, rhs_array, addend_array,
                            executable_run_options_value, ir_builder,
index 0b2d3d47463b745049807e9afa55360434ad522b..496aea051cda1059ed0a5db7652814d7391b2c4e 100644 (file)
@@ -801,7 +801,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
   auto rhs = dot->operand(1);
   TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
       /*instruction=*/*dot, /*operands=*/{lhs, rhs},
-      /*supported_types=*/{F32, F64, C64}));
+      /*supported_types=*/{F16, F32, F64, C64}));
   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
   if (dnums.lhs_batch_dimensions_size() > 0 ||
       dnums.rhs_batch_dimensions_size() > 0) {
@@ -849,7 +849,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
   const auto& window = convolution->window();
   TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
       /*instruction=*/*convolution, /*operands=*/{lhs, rhs},
-      /*supported_types=*/{F32, C64}));
+      /*supported_types=*/{F16, F32, C64}));
 
   const ConvolutionDimensionNumbers& dnums =
       convolution->convolution_dimension_numbers();
@@ -928,25 +928,30 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
       int64 rhs_col_dilation =
           one_dim_convolution ? 1 : window.dimensions(1).window_dilation();
 
-      // Args have been computed, make the call.
-      llvm::Type* float_ptr_type = ir_builder_.getFloatTy()->getPointerTo();
+      PrimitiveType primitive_type = lhs->shape().element_type();
+      llvm::Type* ir_ptr_type = primitive_type == F16
+                                    ? ir_builder_.getHalfTy()->getPointerTo()
+                                    : ir_builder_.getFloatTy()->getPointerTo();
       llvm::Type* int64_type = ir_builder_.getInt64Ty();
       llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo();
       llvm::FunctionType* conv_type = llvm::FunctionType::get(
           ir_builder_.getVoidTy(),
-          {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
-           int64_type,    int64_type,     int64_type,     int64_type,
-           int64_type,    int64_type,     int64_type,     int64_type,
-           int64_type,    int64_type,     int64_type,     int64_type,
-           int64_type,    int64_type,     int64_type,     int64_type,
-           int64_type,    int64_type,     int64_type,     int64_type},
+          {int8_ptr_type, ir_ptr_type, ir_ptr_type, ir_ptr_type, int64_type,
+           int64_type,    int64_type,  int64_type,  int64_type,  int64_type,
+           int64_type,    int64_type,  int64_type,  int64_type,  int64_type,
+           int64_type,    int64_type,  int64_type,  int64_type,  int64_type,
+           int64_type,    int64_type,  int64_type,  int64_type},
           /*isVarArg=*/false);
       bool multi_threaded_eigen =
           hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
       const char* fn_name =
-          (multi_threaded_eigen
-               ? runtime::kEigenConvF32SymbolName
-               : runtime::kEigenSingleThreadedConvF32SymbolName);
+          primitive_type == F16
+              ? (multi_threaded_eigen
+                     ? runtime::kEigenConvF16SymbolName
+                     : runtime::kEigenSingleThreadedConvF16SymbolName)
+              : (multi_threaded_eigen
+                     ? runtime::kEigenConvF32SymbolName
+                     : runtime::kEigenSingleThreadedConvF32SymbolName);
       llvm::Function* conv_func = llvm::cast<llvm::Function>(
           module_->getOrInsertFunction(fn_name, conv_type));
       conv_func->setCallingConv(llvm::CallingConv::C);
@@ -956,9 +961,9 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
           conv_func, {
                          GetExecutableRunOptionsArgument(),
                          ir_builder_.CreateBitCast(
-                             GetEmittedValueFor(convolution), float_ptr_type),
-                         ir_builder_.CreateBitCast(lhs_address, float_ptr_type),
-                         ir_builder_.CreateBitCast(rhs_address, float_ptr_type),
+                             GetEmittedValueFor(convolution), ir_ptr_type),
+                         ir_builder_.CreateBitCast(lhs_address, ir_ptr_type),
+                         ir_builder_.CreateBitCast(rhs_address, ir_ptr_type),
                          ir_builder_.getInt64(input_batch),
                          ir_builder_.getInt64(input_rows),
                          ir_builder_.getInt64(input_cols),
index c2f64eb27a554d17ebe2a94dba334fe378bd7254..3905e7ff2a14d25813e345399e692f9e0f4bd0af 100644 (file)
@@ -34,7 +34,26 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConvF32(
     int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) {
   const xla::ExecutableRunOptions* run_options =
       static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
-  tensorflow::xla::EigenConvF32Impl(
+  tensorflow::xla::EigenConvImpl(
+      *run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch,
+      input_rows, input_cols, input_channels, kernel_rows, kernel_cols,
+      kernel_channels, kernel_filters, output_rows, output_cols, row_stride,
+      col_stride, padding_top, padding_bottom, padding_left, padding_right,
+      lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConvF16(
+    const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs,
+    Eigen::half* rhs, int64 input_batch, int64 input_rows, int64 input_cols,
+    int64 input_channels, int64 kernel_rows, int64 kernel_cols,
+    int64 kernel_channels, int64 kernel_filters, int64 output_rows,
+    int64 output_cols, int64 row_stride, int64 col_stride, int64 padding_top,
+    int64 padding_bottom, int64 padding_left, int64 padding_right,
+    int64 lhs_row_dilation, int64 lhs_col_dilation, int64 rhs_row_dilation,
+    int64 rhs_col_dilation) {
+  const xla::ExecutableRunOptions* run_options =
+      static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
+  tensorflow::xla::EigenConvImpl(
       *run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch,
       input_rows, input_cols, input_channels, kernel_rows, kernel_cols,
       kernel_channels, kernel_filters, output_rows, output_cols, row_stride,
index 05ae094691fd9a7ca83b902145c0750fafdc529a..39e20ed45639040110b99ddb52eb6f6dab26dfaa 100644 (file)
@@ -34,6 +34,20 @@ extern void __xla_cpu_runtime_EigenConvF32(
     tensorflow::int64 lhs_col_dilation, tensorflow::int64 rhs_row_dilation,
     tensorflow::int64 rhs_col_dilation);
 
+extern void __xla_cpu_runtime_EigenConvF16(
+    const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
+    Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs,
+    tensorflow::int64 input_batch, tensorflow::int64 input_rows,
+    tensorflow::int64 input_cols, tensorflow::int64 input_channels,
+    tensorflow::int64 kernel_rows, tensorflow::int64 kernel_cols,
+    tensorflow::int64 kernel_channels, tensorflow::int64 kernel_filters,
+    tensorflow::int64 output_rows, tensorflow::int64 output_cols,
+    tensorflow::int64 row_stride, tensorflow::int64 col_stride,
+    tensorflow::int64 padding_top, tensorflow::int64 padding_bottom,
+    tensorflow::int64 padding_left, tensorflow::int64 padding_right,
+    tensorflow::int64 lhs_row_dilation, tensorflow::int64 lhs_col_dilation,
+    tensorflow::int64 rhs_row_dilation, tensorflow::int64 rhs_col_dilation);
+
 }  // extern "C"
 
 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_H_
index 02f45fee0f1b8cd1125ec6a97f01e0028137bb69..85af63bb032ce33bdd188d6e5bcd78a726d5d9fa 100644 (file)
@@ -24,26 +24,27 @@ limitations under the License.
 namespace tensorflow {
 namespace xla {
 
-template <typename EigenDevice>
-void EigenConvF32Impl(const EigenDevice& device, float* out, float* lhs,
-                      float* rhs, int64 input_batch, int64 input_rows,
-                      int64 input_cols, int64 input_channels, int64 kernel_rows,
-                      int64 kernel_cols, int64 kernel_channels,
-                      int64 kernel_filters, int64 output_rows,
-                      int64 output_cols, int64 row_stride, int64 col_stride,
-                      int64 padding_top, int64 padding_bottom,
-                      int64 padding_left, int64 padding_right,
-                      int64 lhs_row_dilation, int64 lhs_col_dilation,
-                      int64 rhs_row_dilation, int64 rhs_col_dilation) {
-  const Eigen::TensorMap<Eigen::Tensor<const float, 4, Eigen::RowMajor>,
+template <typename EigenDevice, typename ScalarType>
+void EigenConvImpl(const EigenDevice& device, ScalarType* out, ScalarType* lhs,
+                   ScalarType* rhs, int64 input_batch, int64 input_rows,
+                   int64 input_cols, int64 input_channels, int64 kernel_rows,
+                   int64 kernel_cols, int64 kernel_channels,
+                   int64 kernel_filters, int64 output_rows, int64 output_cols,
+                   int64 row_stride, int64 col_stride, int64 padding_top,
+                   int64 padding_bottom, int64 padding_left,
+                   int64 padding_right, int64 lhs_row_dilation,
+                   int64 lhs_col_dilation, int64 rhs_row_dilation,
+                   int64 rhs_col_dilation) {
+  const Eigen::TensorMap<Eigen::Tensor<const ScalarType, 4, Eigen::RowMajor>,
                          Eigen::Aligned>
       input(lhs, input_batch, input_rows, input_cols, input_channels);
 
-  const Eigen::TensorMap<Eigen::Tensor<const float, 4, Eigen::RowMajor>,
+  const Eigen::TensorMap<Eigen::Tensor<const ScalarType, 4, Eigen::RowMajor>,
                          Eigen::Aligned>
       kernel(rhs, kernel_rows, kernel_cols, kernel_channels, kernel_filters);
 
-  Eigen::TensorMap<Eigen::Tensor<float, 4, Eigen::RowMajor>, Eigen::Aligned>
+  Eigen::TensorMap<Eigen::Tensor<ScalarType, 4, Eigen::RowMajor>,
+                   Eigen::Aligned>
       output(out, input_batch, output_rows, output_cols, kernel_filters);
 
   Eigen::array<Eigen::IndexPair<int64>, 1> contract_dims;
@@ -75,7 +76,7 @@ void EigenConvF32Impl(const EigenDevice& device, float* out, float* lhs,
                                  row_stride, rhs_col_dilation, rhs_row_dilation,
                                  lhs_col_dilation, lhs_row_dilation,
                                  padding_left, padding_right, padding_top,
-                                 padding_bottom, 0.0f)
+                                 padding_bottom, static_cast<ScalarType>(0.0f))
           .reshape(pre_contract_dims)
           .contract(kernel.reshape(kernel_dims), contract_dims)
           .reshape(post_contract_dims);
index d0b0e11ac0f9fd06e384c2bb5e6296edd0825f5c..5afccc6a86e2df468e3e3e874cf0f4d4e1342a88 100644 (file)
@@ -21,6 +21,24 @@ limitations under the License.
 
 using tensorflow::int64;
 
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
+__xla_cpu_runtime_EigenSingleThreadedConvF16(
+    const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs,
+    Eigen::half* rhs, int64 input_batch, int64 input_rows, int64 input_cols,
+    int64 input_channels, int64 kernel_rows, int64 kernel_cols,
+    int64 kernel_channels, int64 kernel_filters, int64 output_rows,
+    int64 output_cols, int64 row_stride, int64 col_stride, int64 padding_top,
+    int64 padding_bottom, int64 padding_left, int64 padding_right,
+    int64 lhs_row_dilation, int64 lhs_col_dilation, int64 rhs_row_dilation,
+    int64 rhs_col_dilation) {
+  tensorflow::xla::EigenConvImpl(
+      Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_rows,
+      input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels,
+      kernel_filters, output_rows, output_cols, row_stride, col_stride,
+      padding_top, padding_bottom, padding_left, padding_right,
+      lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation);
+}
+
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
 __xla_cpu_runtime_EigenSingleThreadedConvF32(
     const void* run_options_ptr, float* out, float* lhs, float* rhs,
@@ -30,7 +48,7 @@ __xla_cpu_runtime_EigenSingleThreadedConvF32(
     int64 row_stride, int64 col_stride, int64 padding_top, int64 padding_bottom,
     int64 padding_left, int64 padding_right, int64 lhs_row_dilation,
     int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) {
-  tensorflow::xla::EigenConvF32Impl(
+  tensorflow::xla::EigenConvImpl(
       Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_rows,
       input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels,
       kernel_filters, output_rows, output_cols, row_stride, col_stride,
index 8ae1a42149bde26ca2f510ad47e76ae47f34a977..f216bd0152aa93b8753d881938c63a9cabea899b 100644 (file)
@@ -20,6 +20,20 @@ limitations under the License.
 
 extern "C" {
 
+extern void __xla_cpu_runtime_EigenSingleThreadedConvF16(
+    const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
+    Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs,
+    tensorflow::int64 input_batch, tensorflow::int64 input_rows,
+    tensorflow::int64 input_cols, tensorflow::int64 input_channels,
+    tensorflow::int64 kernel_rows, tensorflow::int64 kernel_cols,
+    tensorflow::int64 kernel_channels, tensorflow::int64 kernel_filters,
+    tensorflow::int64 output_rows, tensorflow::int64 output_cols,
+    tensorflow::int64 row_stride, tensorflow::int64 col_stride,
+    tensorflow::int64 padding_top, tensorflow::int64 padding_bottom,
+    tensorflow::int64 padding_left, tensorflow::int64 padding_right,
+    tensorflow::int64 lhs_row_dilation, tensorflow::int64 lhs_col_dilation,
+    tensorflow::int64 rhs_row_dilation, tensorflow::int64 rhs_col_dilation);
+
 extern void __xla_cpu_runtime_EigenSingleThreadedConvF32(
     const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
     float* lhs, float* rhs, tensorflow::int64 input_batch,
index 64d3a51f41676bbb4b59c9d272d22f52a87a0559..f19cb86cc42f03543c77ff17b6ae4a8a69bbb140 100644 (file)
@@ -208,10 +208,12 @@ bool RegisterKnownJITSymbols() {
 
   REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue);
   REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
+  REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenFft);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64);
+  REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64);
index 15bba49b73bce8eb4a18175f8874f05049119458..461747b699b542ae0c8735aea34cc9e57c1fb387 100644 (file)
@@ -63,12 +63,12 @@ ConvolutionThunk::ConvolutionThunk(
 
 Status ConvolutionThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream) {
-  se::DeviceMemory<float> input_data(
-      buffer_allocations.GetDeviceAddress(input_buffer_));
-  se::DeviceMemory<float> filter_data(
-      buffer_allocations.GetDeviceAddress(filter_buffer_));
-  se::DeviceMemory<float> output_data(
-      buffer_allocations.GetDeviceAddress(output_buffer_));
+  se::DeviceMemoryBase input_data =
+      buffer_allocations.GetDeviceAddress(input_buffer_);
+  se::DeviceMemoryBase filter_data =
+      buffer_allocations.GetDeviceAddress(filter_buffer_);
+  se::DeviceMemoryBase output_data =
+      buffer_allocations.GetDeviceAddress(output_buffer_);
   se::DeviceMemoryBase scratch =
       buffer_allocations.GetDeviceAddress(scratch_buffer_);
 
@@ -80,8 +80,8 @@ Status ConvolutionThunk::ExecuteOnStream(
       filter_data, output_data, scratch, window_, dim_nums_, algorithm_config,
       stream));
 
-  // Figure out which of output/input/filter is the result produced by this op,
-  // and write the result tuple.
+  // Figure out which of output/input/filter is the result produced by
+  // this op, and write the result tuple.
   void* result_ptr = [&] {
     switch (convolution_kind_) {
       case CudnnConvKind::kForward:
index c29aa31d4ee31c88ec6d315480d4258b190bbcff..1792893ae401bf16d2dd9e861607e8f3821a505e 100644 (file)
@@ -135,15 +135,6 @@ std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
       break;
   }
 
-  // Remove any algorithms with tensor math enabled.  These have lower precision
-  // than regular algorithms, and we don't yet have a way to turn this on/off in
-  // XLA.
-  algorithms.erase(std::remove_if(algorithms.begin(), algorithms.end(),
-                                  [&](const AlgorithmDesc& a) {
-                                    return a.tensor_ops_enabled();
-                                  }),
-                   algorithms.end());
-
   return algorithms;
 }
 
@@ -222,6 +213,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
       ShouldIncludeWinogradNonfusedAlgo(input_shape, output_shape, dnums);
   se::dnn::ProfileResult best_result;
   int64 best_result_bytes_used = 0;
+
   for (const AlgorithmDesc& alg :
        GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) {
     ScratchAllocator scratch_allocator(device_ordinal, allocator);
@@ -229,14 +221,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
     VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
             << instr->ToString();
 
-    bool launch_ok =
-        RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
-                            se::DeviceMemory<float>(input_buf.ValueOrDie()),
-                            se::DeviceMemory<float>(filter_buf.ValueOrDie()),
-                            se::DeviceMemory<float>(output_buf.ValueOrDie()),
-                            &scratch_allocator, window, dnums,
-                            AlgorithmConfig(alg), &stream, &profile_result)
-            .ok();
+    bool launch_ok = RunCudnnConvolution(
+                         kind, input_shape, filter_shape, output_shape,
+                         input_buf.ValueOrDie(), filter_buf.ValueOrDie(),
+                         output_buf.ValueOrDie(), &scratch_allocator, window,
+                         dnums, AlgorithmConfig(alg), &stream, &profile_result)
+                         .ok();
 
     if (launch_ok && profile_result.is_valid()) {
       int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
index 81695a6c326b922904330f33bc88260729ff67ee..e4ae839e1dd4cb3a744a3f6a3329cabdaeb3f38d 100644 (file)
@@ -70,39 +70,11 @@ class ScratchBufAllocator : public se::ScratchAllocator {
   bool allocated_ = false;
 };
 
-}  // anonymous namespace
-
-string CudnnConvKindToString(CudnnConvKind kind) {
-  switch (kind) {
-    case CudnnConvKind::kForward:
-      return "forward";
-    case CudnnConvKind::kBackwardFilter:
-      return "backward_filter";
-    case CudnnConvKind::kBackwardInput:
-      return "backward_input";
-  }
-}
-
-Status RunCudnnConvolution(CudnnConvKind kind, const Shape& input_shape,
-                           const Shape& filter_shape, const Shape& output_shape,
-                           DeviceMemory<float> input_buf,
-                           DeviceMemory<float> filter_buf,
-                           DeviceMemory<float> output_buf,
-                           DeviceMemoryBase scratch_buf, const Window& window,
-                           const ConvolutionDimensionNumbers& dnums,
-                           AlgorithmConfig algorithm, Stream* stream,
-                           ProfileResult* profile_result /*= nullptr*/) {
-  ScratchBufAllocator scratch_allocator(scratch_buf);
-  return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
-                             input_buf, filter_buf, output_buf,
-                             &scratch_allocator, window, dnums, algorithm,
-                             stream, profile_result);
-}
-
+template <typename T>
 Status RunCudnnConvolution(
     CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
-    const Shape& output_shape, DeviceMemory<float> input_buf,
-    DeviceMemory<float> filter_buf, DeviceMemory<float> output_buf,
+    const Shape& output_shape, DeviceMemory<T> input_buf,
+    DeviceMemory<T> filter_buf, DeviceMemory<T> output_buf,
     se::ScratchAllocator* scratch_allocator, const Window& window,
     const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm,
     Stream* stream, ProfileResult* profile_result /*= nullptr*/) {
@@ -124,8 +96,16 @@ Status RunCudnnConvolution(
   // tensorflow/python/ops/nn_ops.py).
   const int effective_num_dimensions = std::max(2, num_dimensions);
 
-  CHECK_EQ(F32, output_shape.element_type())
-      << ShapeUtil::HumanString(output_shape);
+  if (std::is_same<T, float>::value) {
+    CHECK_EQ(F32, output_shape.element_type())
+        << ShapeUtil::HumanString(output_shape);
+  } else if (std::is_same<T, Eigen::half>::value) {
+    CHECK_EQ(F16, output_shape.element_type())
+        << ShapeUtil::HumanString(output_shape);
+  } else {
+    LOG(FATAL) << ShapeUtil::HumanString(output_shape);
+  }
+
   CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size());
   CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size());
   CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size());
@@ -220,5 +200,63 @@ Status RunCudnnConvolution(
   return Status::OK();
 }
 
+}  // anonymous namespace
+
+string CudnnConvKindToString(CudnnConvKind kind) {
+  switch (kind) {
+    case CudnnConvKind::kForward:
+      return "forward";
+    case CudnnConvKind::kBackwardFilter:
+      return "backward_filter";
+    case CudnnConvKind::kBackwardInput:
+      return "backward_input";
+  }
+}
+
+Status RunCudnnConvolution(
+    CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
+    const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf,
+    perftools::gputools::DeviceMemoryBase filter_buf,
+    perftools::gputools::DeviceMemoryBase output_buf,
+    perftools::gputools::DeviceMemoryBase scratch_buf, const Window& window,
+    const ConvolutionDimensionNumbers& dnums,
+    perftools::gputools::dnn::AlgorithmConfig algorithm,
+    perftools::gputools::Stream* stream,
+    perftools::gputools::dnn::ProfileResult* profile_result) {
+  ScratchBufAllocator scratch_allocator(scratch_buf);
+  return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
+                             input_buf, filter_buf, output_buf,
+                             &scratch_allocator, window, dnums, algorithm,
+                             stream, profile_result);
+}
+
+Status RunCudnnConvolution(
+    CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
+    const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf,
+    perftools::gputools::DeviceMemoryBase filter_buf,
+    perftools::gputools::DeviceMemoryBase output_buf,
+    perftools::gputools::ScratchAllocator* scratch_allocator,
+    const Window& window, const ConvolutionDimensionNumbers& dnums,
+    perftools::gputools::dnn::AlgorithmConfig algorithm,
+    perftools::gputools::Stream* stream,
+    perftools::gputools::dnn::ProfileResult* profile_result) {
+  PrimitiveType output_primitive_type = output_shape.element_type();
+  CHECK(output_primitive_type == F32 || output_primitive_type == F16)
+      << ShapeUtil::HumanString(output_shape);
+  if (output_primitive_type == F32) {
+    return RunCudnnConvolution(
+        kind, input_shape, filter_shape, output_shape,
+        se::DeviceMemory<float>(input_buf), se::DeviceMemory<float>(filter_buf),
+        se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums,
+        algorithm, stream, profile_result);
+  }
+  return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
+                             se::DeviceMemory<Eigen::half>(input_buf),
+                             se::DeviceMemory<Eigen::half>(filter_buf),
+                             se::DeviceMemory<Eigen::half>(output_buf),
+                             scratch_allocator, window, dnums, algorithm,
+                             stream, profile_result);
+}
+
 }  // namespace gpu
 }  // namespace xla
index b101f76510c129fd22b246e5f0348848192ecbba..3dbfa2730da359d3c7937140508017c4a7b02d6c 100644 (file)
@@ -55,7 +55,10 @@ string CudnnConvKindToString(CudnnConvKind kind);
 // Note that depending on the value of CudnnConvKind, the result of this call
 // may be written into input_buf, filter_buf, or output_buf!
 //
-// At the moment we only support cudnn convolutions over floats.
+// At the moment we only support cudnn convolutions over float and half, and
+// convolution with half data type is implemented with cudnn PSEUDO_HALF
+// configuration, that is, the input values are half and the internal
+// computation type is float.
 //
 // We provide one overload which takes a scratch buffer, and another which takes
 // an allocator which is responsible for allocating the scratch space.  In
@@ -69,10 +72,9 @@ string CudnnConvKindToString(CudnnConvKind kind);
 // that size, if you like.
 Status RunCudnnConvolution(
     CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
-    const Shape& output_shape,
-    perftools::gputools::DeviceMemory<float> input_buf,
-    perftools::gputools::DeviceMemory<float> filter_buf,
-    perftools::gputools::DeviceMemory<float> output_buf,
+    const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf,
+    perftools::gputools::DeviceMemoryBase filter_buf,
+    perftools::gputools::DeviceMemoryBase output_buf,
     perftools::gputools::DeviceMemoryBase scratch_buf, const Window& window,
     const ConvolutionDimensionNumbers& dnums,
     perftools::gputools::dnn::AlgorithmConfig algorithm,
@@ -81,10 +83,9 @@ Status RunCudnnConvolution(
 
 Status RunCudnnConvolution(
     CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
-    const Shape& output_shape,
-    perftools::gputools::DeviceMemory<float> input_buf,
-    perftools::gputools::DeviceMemory<float> filter_buf,
-    perftools::gputools::DeviceMemory<float> output_buf,
+    const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf,
+    perftools::gputools::DeviceMemoryBase filter_buf,
+    perftools::gputools::DeviceMemoryBase output_buf,
     perftools::gputools::ScratchAllocator* scratch_allocator,
     const Window& window, const ConvolutionDimensionNumbers& dnums,
     perftools::gputools::dnn::AlgorithmConfig algorithm,
index 81212cda4266ec820230d0d84fc2a395edaf411e..eb6e9feb7c0d815c324c16ea2c6d704a9307c774 100644 (file)
@@ -1403,6 +1403,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int64>(map));
         break;
       }
+      case F16: {
+        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map],
+                            MapImpl<Eigen::half>(map));
+        break;
+      }
       case F32: {
         TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<float>(map));
         break;
@@ -2041,9 +2046,7 @@ HloEvaluator::HloEvaluator() {
   });
   typed_visitors_[S32] = MakeUnique<TypedVisitor<int32>>(this);
   typed_visitors_[S64] = MakeUnique<TypedVisitor<int64>>(this);
-  typed_visitors_[F16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
-    return Unimplemented("HloEvaluator: unhandled primitive type: F16.");
-  });
+  typed_visitors_[F16] = MakeUnique<TypedVisitor<Eigen::half, float>>(this);
   typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
   typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
   typed_visitors_[C64] = MakeUnique<TypedVisitor<complex64>>(this);
index 0ceb9aff378ae8aa8098be9360310b1d78d31ab2..1385b437fc47fe5289c401581fab8b5278872382 100644 (file)
@@ -53,157 +53,200 @@ class ConvolutionTest : public ClientLibraryTestBase {
 #endif
 };
 
-XLA_TEST_F(ConvolutionTest, ForwardPassConvolution_3x3x256_256_OutputZ_Iota) {
-  const int kInputActivationSizeY = 3;
-  const int kInputActivationSizeX = 3;
-  const int kInputActivationSizeZ = 256;
-  const int kKernelSizeX = 2;
-  const int kKernelSizeY = 2;
-  const int kOutputActivationSizeZ = 256;
-  const int kMiniBatchSize = 4;
-  auto alhs =
-      MakeUnique<Array4D<float>>(kMiniBatchSize, kInputActivationSizeZ,
-                                 kInputActivationSizeY, kInputActivationSizeX);
-  alhs->FillWithMultiples(1.0f);
-  ASSERT_EQ(3, alhs->width());
-  ASSERT_EQ(3, alhs->height());
-
-  auto arhs =
-      MakeUnique<Array4D<float>>(kOutputActivationSizeZ, kInputActivationSizeZ,
-                                 kKernelSizeY, kKernelSizeX);
-  Array2D<float> rhs_raster({
-      {1.0f, 0.0f},  // row 0
-      {0.0f, 0.0f},  // row 1
-  });
-  arhs->FillWithYX(rhs_raster);
-  ASSERT_EQ(2, arhs->width());
-  ASSERT_EQ(2, arhs->height());
+// TODO(b/72509305): Enable half data type tests for CPU
+#if (XLA_TEST_BACKEND_GPU)
+using TestTypes = ::testing::Types<float, Eigen::half>;
+#else
+using TestTypes = ::testing::Types<float>;
+#endif
 
-  ComputationBuilder builder(client_, TestName());
-  auto lhs = builder.ConstantR4FromArray4D<float>(*alhs);
-  auto rhs = builder.ConstantR4FromArray4D<float>(*arhs);
-  auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
+template <typename T>
+Shape MakeShapeWrapper(tensorflow::gtl::ArraySlice<int64> dimensions);
 
-  ComputeAndCompare(&builder, conv, {}, error_spec_);
+template <>
+Shape MakeShapeWrapper<float>(tensorflow::gtl::ArraySlice<int64> dimensions) {
+  return ShapeUtil::MakeShape(F32, dimensions);
 }
 
-TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) {
-  ComputationBuilder builder(client_, TestName());
-  Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
-  Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
-  auto input = builder.Parameter(0, input_shape, "input");
-  auto filter = builder.Parameter(1, filter_shape, "filter");
-  auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
+template <>
+Shape MakeShapeWrapper<Eigen::half>(
+    tensorflow::gtl::ArraySlice<int64> dimensions) {
+  return ShapeUtil::MakeShape(F16, dimensions);
+}
 
-  Array4D<float> input_data(1, 1, 1, 2);
-  input_data.FillWithYX(Array2D<float>({
-      {1, 2},
-  }));
-  Array4D<float> filter_data(1, 1, 1, 2);
-  filter_data.FillWithYX(Array2D<float>({
-      {5, 6},
-  }));
+template <typename T>
+class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest {
+ public:
+  void RunTest() {
+    const int kInputActivationSizeY = 3;
+    const int kInputActivationSizeX = 3;
+    const int kInputActivationSizeZ = 256;
+    const int kKernelSizeX = 2;
+    const int kKernelSizeY = 2;
+    const int kOutputActivationSizeZ = 256;
+    const int kMiniBatchSize = 4;
+    auto alhs =
+        MakeUnique<Array4D<T>>(kMiniBatchSize, kInputActivationSizeZ,
+                               kInputActivationSizeY, kInputActivationSizeX);
+    alhs->FillWithMultiples(static_cast<T>(1.0f));
+    ASSERT_EQ(3, alhs->width());
+    ASSERT_EQ(3, alhs->height());
+
+    auto arhs =
+        MakeUnique<Array4D<T>>(kOutputActivationSizeZ, kInputActivationSizeZ,
+                               kKernelSizeY, kKernelSizeX);
+    Array2D<T> rhs_raster({
+        {1.0f, 0.0f},  // row 0
+        {0.0f, 0.0f},  // row 1
+    });
+    arhs->FillWithYX(rhs_raster);
+    ASSERT_EQ(2, arhs->width());
+    ASSERT_EQ(2, arhs->height());
+
+    ComputationBuilder builder(client_, TestName());
+    auto lhs = builder.ConstantR4FromArray4D<T>(*alhs);
+    auto rhs = builder.ConstantR4FromArray4D<T>(*arhs);
+    auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
+
+    ComputeAndCompare(&builder, conv, {}, error_spec_);
+  }
+};
 
-  ComputeAndCompare(&builder, conv,
-                    {std::move(*Literal::CreateFromArray(input_data)),
-                     std::move(*Literal::CreateFromArray(filter_data))},
-                    error_spec_);
+TYPED_TEST_CASE(ForwardPassConvolution_3x3x256_256_OutputZ_Iota, TestTypes);
+XLA_TYPED_TEST(ForwardPassConvolution_3x3x256_256_OutputZ_Iota, Types) {
+  this->RunTest();
 }
 
+template <typename T>
+class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
+ public:
+  void RunTest() {
+    ComputationBuilder builder(client_, TestName());
+    Shape input_shape = MakeShapeWrapper<T>({1, 1, 1, 2});
+    Shape filter_shape = MakeShapeWrapper<T>({1, 1, 1, 2});
+    auto input = builder.Parameter(0, input_shape, "input");
+    auto filter = builder.Parameter(1, filter_shape, "filter");
+    auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+    Array4D<T> input_data(1, 1, 1, 2);
+    input_data.FillWithYX(Array2D<T>({
+        {1.0f, 2.0f},
+    }));
+    Array4D<T> filter_data(1, 1, 1, 2);
+    filter_data.FillWithYX(Array2D<T>({
+        {5.0f, 6.0f},
+    }));
+
+    ComputeAndCompare(&builder, conv,
+                      {std::move(*Literal::CreateFromArray(input_data)),
+                       std::move(*Literal::CreateFromArray(filter_data))},
+                      error_spec_);
+  }
+};
+
+TYPED_TEST_CASE(Convolve_1x1x1x2_1x1x1x2_Valid, TestTypes);
+TYPED_TEST(Convolve_1x1x1x2_1x1x1x2_Valid, Types) { this->RunTest(); }
+
 // Tests valid padding for 2D convolution in raster space.
-TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) {
-  ComputationBuilder builder(client_, TestName());
-  Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
-  Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
-  auto input = builder.Parameter(0, input_shape, "input");
-  auto filter = builder.Parameter(1, filter_shape, "filter");
-  auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
+template <typename T>
+class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
+ public:
+  void RunTest() {
+    ComputationBuilder builder(client_, TestName());
+    Shape input_shape = MakeShapeWrapper<T>({1, 1, 4, 4});
+    Shape filter_shape = MakeShapeWrapper<T>({1, 1, 2, 2});
+    auto input = builder.Parameter(0, input_shape, "input");
+    auto filter = builder.Parameter(1, filter_shape, "filter");
+    auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+    Array4D<T> input_data(1, 1, 4, 4);
+    input_data.FillWithYX(Array2D<T>({
+        {1.0f, 2.0f, 3.0f, 4.0f},
+        {5.0f, 6.0f, 7.0f, 8.0f},
+        {9.0f, 10.0f, 11.0f, 12.0f},
+        {13.0f, 14.0f, 15.0f, 16.0f},
+    }));
+    Array4D<T> filter_data(1, 1, 2, 2);
+    filter_data.FillWithYX(Array2D<T>({
+        {5.0f, 6.0f},
+        {7.0f, 8.0f},
+    }));
+    ComputeAndCompare(&builder, conv,
+                      {std::move(*Literal::CreateFromArray(input_data)),
+                       std::move(*Literal::CreateFromArray(filter_data))},
+                      error_spec_);
+  }
+};
 
-  Array4D<float> input_data(1, 1, 4, 4);
-  // clang-format off
-  input_data.FillWithYX(Array2D<float>({
-    {1,  2,  3,  4 },
-    {5,  6,  7,  8 },
-    {9,  10, 11, 12},
-    {13, 14, 15, 16},
-  }));
-  // clang-format on
-  Array4D<float> filter_data(1, 1, 2, 2);
-  // clang-format off
-  filter_data.FillWithYX(Array2D<float>({
-    {5, 6},
-    {7, 8},
-  }));
-  // clang-format on
-  ComputeAndCompare(&builder, conv,
-                    {std::move(*Literal::CreateFromArray(input_data)),
-                     std::move(*Literal::CreateFromArray(filter_data))},
-                    error_spec_);
-}
+TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x2x2_Valid, TestTypes);
+TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Valid, Types) { this->RunTest(); }
 
 // Tests same padding for 2D convolution in raster space.
-TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) {
-  ComputationBuilder builder(client_, TestName());
-  Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
-  Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
-  auto input = builder.Parameter(0, input_shape, "input");
-  auto filter = builder.Parameter(1, filter_shape, "filter");
-  auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
-
-  Array4D<float> input_data(1, 1, 4, 4);
-  // clang-format off
-  input_data.FillWithYX(Array2D<float>({
-    {1,  2,  3,  4 },
-    {5,  6,  7,  8 },
-    {9,  10, 11, 12},
-    {13, 14, 15, 16},
-  }));
-  // clang-format on
-  Array4D<float> filter_data(1, 1, 2, 2);
-  // clang-format off
-  filter_data.FillWithYX(Array2D<float>({
-    {5, 6},
-    {7, 8},
-  }));
-  // clang-format on
-  ComputeAndCompare(&builder, conv,
-                    {std::move(*Literal::CreateFromArray(input_data)),
-                     std::move(*Literal::CreateFromArray(filter_data))},
-                    error_spec_);
-}
+template <typename T>
+class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
+ public:
+  void RunTest() {
+    ComputationBuilder builder(client_, TestName());
+    Shape input_shape = MakeShapeWrapper<T>({1, 1, 4, 4});
+    Shape filter_shape = MakeShapeWrapper<T>({1, 1, 2, 2});
+    auto input = builder.Parameter(0, input_shape, "input");
+    auto filter = builder.Parameter(1, filter_shape, "filter");
+    auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
+
+    Array4D<T> input_data(1, 1, 4, 4);
+    input_data.FillWithYX(Array2D<T>({
+        {1.0f, 2.0f, 3.0f, 4.0f},
+        {5.0f, 6.0f, 7.0f, 8.0f},
+        {9.0f, 10.0f, 11.0f, 12.0f},
+        {13.0f, 14.0f, 15.0f, 16.0f},
+    }));
+    Array4D<T> filter_data(1, 1, 2, 2);
+    filter_data.FillWithYX(Array2D<T>({
+        {5.0f, 6.0f},
+        {7.0f, 8.0f},
+    }));
+
+    ComputeAndCompare(&builder, conv,
+                      {std::move(*Literal::CreateFromArray(input_data)),
+                       std::move(*Literal::CreateFromArray(filter_data))},
+                      error_spec_);
+  }
+};
+
+TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x2x2_Same, TestTypes);
+TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Same, Types) { this->RunTest(); }
 
 // Tests same padding for 2D convolution in raster space with an odd sized
 // kernel.
-TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) {
-  ComputationBuilder builder(client_, TestName());
-  Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
-  Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3});
-  auto input = builder.Parameter(0, input_shape, "input");
-  auto filter = builder.Parameter(1, filter_shape, "filter");
-  auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
-
-  Array4D<float> input_data(1, 1, 4, 4);
-  // clang-format off
-  input_data.FillWithYX(Array2D<float>({
-    {1,  2,  3,  4 },
-    {5,  6,  7,  8 },
-    {9,  10, 11, 12},
-    {13, 14, 15, 16},
-  }));
-  // clang-format on
-  Array4D<float> filter_data(1, 1, 3, 3);
-  // clang-format off
-  filter_data.FillWithYX(Array2D<float>({
-    { 5,  6,  7},
-    { 8,  9, 10},
-    {11, 12, 13},
-  }));
-  // clang-format on
-  ComputeAndCompare(&builder, conv,
-                    {std::move(*Literal::CreateFromArray(input_data)),
-                     std::move(*Literal::CreateFromArray(filter_data))},
-                    error_spec_);
-}
+template <typename T>
+class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
+ public:
+  void RunTest() {
+    ComputationBuilder builder(client_, TestName());
+    Shape input_shape = MakeShapeWrapper<T>({1, 1, 4, 4});
+    Shape filter_shape = MakeShapeWrapper<T>({1, 1, 3, 3});
+    auto input = builder.Parameter(0, input_shape, "input");
+    auto filter = builder.Parameter(1, filter_shape, "filter");
+    auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
+
+    Array4D<T> input_data(1, 1, 4, 4);
+    input_data.FillWithYX(Array2D<T>({{1.0f, 2.0f, 3.0f, 4.0f},
+                                      {5.0f, 6.0f, 7.0f, 8.0f},
+                                      {9.0f, 10.0f, 11.0f, 12.0f},
+                                      {13.0f, 14.0f, 15.0f, 16.0f}}));
+    Array4D<T> filter_data(1, 1, 3, 3);
+    filter_data.FillWithYX(Array2D<T>(
+        {{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}}));
+    // clang-format on
+    ComputeAndCompare(&builder, conv,
+                      {std::move(*Literal::CreateFromArray(input_data)),
+                       std::move(*Literal::CreateFromArray(filter_data))},
+                      error_spec_);
+  }
+};
+
+TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x3x3_Same, TestTypes);
+TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same, Types) { this->RunTest(); }
 
 XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
   ComputationBuilder builder(client_, TestName());
@@ -232,36 +275,44 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
                              error_spec_);
 }
 
-XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithRHSDilation) {
-  ComputationBuilder builder(client_, TestName());
-  {
-    Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
-    Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
-    auto input = builder.Parameter(0, input_shape, "input");
-    auto filter = builder.Parameter(1, filter_shape, "filter");
-    // Convolution dimensions are bf0_oi0->bo0.
-    builder.ConvGeneralDilated(
-        input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
-        /*lhs_dilation=*/{1}, /*rhs_dilation=*/{2},
-        /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
+template <typename T>
+class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
+ public:
+  void RunTest() {
+    ComputationBuilder builder(client_, TestName());
+    {
+      Shape input_shape = MakeShapeWrapper<T>({1, 2, 5});
+      Shape filter_shape = MakeShapeWrapper<T>({1, 2, 2});
+      auto input = builder.Parameter(0, input_shape, "input");
+      auto filter = builder.Parameter(1, filter_shape, "filter");
+      // Convolution dimensions are bf0_oi0->bo0.
+      builder.ConvGeneralDilated(
+          input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
+          /*lhs_dilation=*/{1}, /*rhs_dilation=*/{2},
+          /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
+    }
+
+    Array3D<T> input(
+        {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
+    Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
+
+    Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
+
+    auto input_literal =
+        client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
+            .ConsumeValueOrDie();
+    auto filter_literal =
+        client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
+            .ConsumeValueOrDie();
+
+    ComputeAndCompareR3<T>(&builder, expected,
+                           {input_literal.get(), filter_literal.get()},
+                           error_spec_);
   }
+};  // namespace
 
-  Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
-  Array3D<float> filter({{{10, 20}, {30, 40}}});
-
-  Array3D<float> expected({{{570, 670, 770}}});
-
-  auto input_literal =
-      client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
-          .ConsumeValueOrDie();
-  auto filter_literal =
-      client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
-          .ConsumeValueOrDie();
-
-  ComputeAndCompareR3<float>(&builder, expected,
-                             {input_literal.get(), filter_literal.get()},
-                             error_spec_);
-}
+TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes);
+TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); }
 
 XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
   ComputationBuilder builder(client_, TestName());
@@ -325,36 +376,45 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
                              error_spec_);
 }
 
-XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithPadding) {
-  ComputationBuilder builder(client_, TestName());
-  {
-    Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
-    Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
-    auto input = builder.Parameter(0, input_shape, "input");
-    auto filter = builder.Parameter(1, filter_shape, "filter");
-    // Convolution dimensions are bf0_oi0->bo0.
-    builder.ConvGeneralDilated(
-        input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}},
-        /*lhs_dilation=*/{1}, /*rhs_dilation=*/{1},
-        /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
+template <typename T>
+class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
+ public:
+  void RunTest() {
+    ComputationBuilder builder(client_, TestName());
+    {
+      Shape input_shape = MakeShapeWrapper<T>({1, 2, 5});
+      Shape filter_shape = MakeShapeWrapper<T>({1, 2, 2});
+      auto input = builder.Parameter(0, input_shape, "input");
+      auto filter = builder.Parameter(1, filter_shape, "filter");
+      // Convolution dimensions are bf0_oi0->bo0.
+      builder.ConvGeneralDilated(
+          input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}},
+          /*lhs_dilation=*/{1}, /*rhs_dilation=*/{1},
+          /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
+    }
+
+    Array3D<T> input(
+        {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
+    Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
+
+    Array3D<T> expected(
+        {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
+
+    auto input_literal =
+        client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
+            .ConsumeValueOrDie();
+    auto filter_literal =
+        client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
+            .ConsumeValueOrDie();
+
+    ComputeAndCompareR3<T>(&builder, expected,
+                           {input_literal.get(), filter_literal.get()},
+                           error_spec_);
   }
+};
 
-  Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
-  Array3D<float> filter({{{10, 20}, {30, 40}}});
-
-  Array3D<float> expected({{{0, 260, 510, 610, 710, 810, 350, 0}}});
-
-  auto input_literal =
-      client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
-          .ConsumeValueOrDie();
-  auto filter_literal =
-      client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
-          .ConsumeValueOrDie();
-
-  ComputeAndCompareR3<float>(&builder, expected,
-                             {input_literal.get(), filter_literal.get()},
-                             error_spec_);
-}
+TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithPadding, TestTypes);
+TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding, Types) { this->RunTest(); }
 
 XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
   ComputationBuilder builder(client_, TestName());
@@ -389,12 +449,12 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
   }
 
   std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
-  std::iota(input_elems.begin(), input_elems.end(), 1.0f);
+  iota(input_elems.begin(), input_elems.end(), 1.0f);
   auto input_r1 = Literal::CreateR1<float>(input_elems);
   auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
 
   std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
-  std::iota(filter_elems.begin(), filter_elems.end(), 1.0f);
+  iota(filter_elems.begin(), filter_elems.end(), 1.0f);
   auto filter_r1 = Literal::CreateR1<float>(filter_elems);
   auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
 
@@ -412,56 +472,73 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
                            error_spec_);
 }
 
-XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) {
-  ComputationBuilder builder(client_, TestName());
-  std::vector<int64> input_dims = {1, 3, 3, 5};
-  std::vector<int64> filter_dims = {3, 3, 5, 3};
-  Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);
-  Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims);
-  {
-    auto input = builder.Parameter(0, input_shape, "input");
-    auto filter = builder.Parameter(1, filter_shape, "filter");
-
-    // Tensorflow dimension numbers for 2D convolution.
-    ConvolutionDimensionNumbers dnums;
-    dnums.set_input_batch_dimension(0);
-    dnums.set_output_batch_dimension(0);
-    dnums.add_input_spatial_dimensions(1);
-    dnums.add_output_spatial_dimensions(1);
-    dnums.add_input_spatial_dimensions(2);
-    dnums.add_output_spatial_dimensions(2);
-    dnums.set_input_feature_dimension(3);
-    dnums.set_output_feature_dimension(3);
-    dnums.add_kernel_spatial_dimensions(0);
-    dnums.add_kernel_spatial_dimensions(1);
-    dnums.set_kernel_input_feature_dimension(2);
-    dnums.set_kernel_output_feature_dimension(3);
+// std::iota doesn't work when init_value has a type Eigen::half in some build
+// servers. The error message is missing the operator ++.
+template <typename T>
+void iota_int_init_value(std::vector<T>& values, int init_value) {
+  std::for_each(values.begin(), values.end(),
+                [&](T& value) { value = static_cast<T>(init_value++); });
+}
 
-    builder.ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid,
-                                      dnums);
+template <typename T>
+class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest {
+ public:
+  void RunTest() {
+    ComputationBuilder builder(client_, TestName());
+    std::vector<int64> input_dims = {1, 3, 3, 5};
+    std::vector<int64> filter_dims = {3, 3, 5, 3};
+    Shape input_shape = MakeShapeWrapper<T>(input_dims);
+    Shape filter_shape = MakeShapeWrapper<T>(filter_dims);
+    {
+      auto input = builder.Parameter(0, input_shape, "input");
+      auto filter = builder.Parameter(1, filter_shape, "filter");
+
+      // Tensorflow dimension numbers for 2D convolution.
+      ConvolutionDimensionNumbers dnums;
+      dnums.set_input_batch_dimension(0);
+      dnums.set_output_batch_dimension(0);
+      dnums.add_input_spatial_dimensions(1);
+      dnums.add_output_spatial_dimensions(1);
+      dnums.add_input_spatial_dimensions(2);
+      dnums.add_output_spatial_dimensions(2);
+      dnums.set_input_feature_dimension(3);
+      dnums.set_output_feature_dimension(3);
+      dnums.add_kernel_spatial_dimensions(0);
+      dnums.add_kernel_spatial_dimensions(1);
+      dnums.set_kernel_input_feature_dimension(2);
+      dnums.set_kernel_output_feature_dimension(3);
+
+      builder.ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid,
+                                        dnums);
+    }
+
+    std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
+    iota_int_init_value(input_elems, 1);
+    auto input_r1 = Literal::CreateR1<T>(input_elems);
+    auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+
+    std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
+    iota_int_init_value(filter_elems, 1);
+    auto filter_r1 = Literal::CreateR1<T>(filter_elems);
+    auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+
+    auto expected_r1 = Literal::CreateR1<T>(
+        {static_cast<T>(92115), static_cast<T>(93150), static_cast<T>(94185)});
+    auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
+
+    auto input_literal =
+        client_->TransferToServer(*input_r4).ConsumeValueOrDie();
+    auto filter_literal =
+        client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
+
+    ComputeAndCompareLiteral(&builder, *expected_r4,
+                             {input_literal.get(), filter_literal.get()},
+                             error_spec_);
   }
+};
 
-  std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
-  std::iota(input_elems.begin(), input_elems.end(), 1.0f);
-  auto input_r1 = Literal::CreateR1<float>(input_elems);
-  auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
-
-  std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
-  std::iota(filter_elems.begin(), filter_elems.end(), 1.0f);
-  auto filter_r1 = Literal::CreateR1<float>(filter_elems);
-  auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
-
-  auto expected_r1 = Literal::CreateR1<float>({92115, 93150, 94185});
-  auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
-
-  auto input_literal = client_->TransferToServer(*input_r4).ConsumeValueOrDie();
-  auto filter_literal =
-      client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
-
-  ComputeAndCompareLiteral(&builder, *expected_r4,
-                           {input_literal.get(), filter_literal.get()},
-                           error_spec_);
-}
+TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x5x5_Valid, TestTypes);
+TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x5_Valid, Types) { this->RunTest(); }
 
 // Test fixture to run convolution tests with and without convolution
 // canonicalization enabled.
@@ -519,67 +596,117 @@ struct Convolve1DTestParam {
   int64 num_windows;
 };
 
-class Convolve1D1WindowTest
+class Convolve1D1WindowTestBase
     : public ConvolutionTest,
-      public ::testing::WithParamInterface<Convolve1DTestParam> {};
-
-XLA_TEST_P(Convolve1D1WindowTest, Convolve1D1Window) {
-  ComputationBuilder builder(client_, TestName());
-  int64 input_feature = GetParam().input_feature;
-  int64 output_feature = GetParam().output_feature;
-  int64 batch = GetParam().batch;
-  int64 num_windows = GetParam().num_windows;
-  int64 window_size = GetParam().window_size;
-  std::vector<int64> input_dims = {batch, window_size + num_windows - 1,
-                                   input_feature};
-  std::vector<int64> filter_dims = {window_size, input_feature, output_feature};
-  Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);
-  Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims);
-  {
-    auto input = builder.Parameter(0, input_shape, "input");
-    auto filter = builder.Parameter(1, filter_shape, "filter");
-
-    // Tensorflow dimension numbers for 1D convolution.
-    ConvolutionDimensionNumbers dnums;
-    dnums.set_input_batch_dimension(0);
-    dnums.set_output_batch_dimension(0);
-    dnums.add_input_spatial_dimensions(1);
-    dnums.add_output_spatial_dimensions(1);
-    dnums.set_input_feature_dimension(2);
-    dnums.set_output_feature_dimension(2);
-    dnums.add_kernel_spatial_dimensions(0);
-    dnums.set_kernel_input_feature_dimension(1);
-    dnums.set_kernel_output_feature_dimension(2);
-
-    builder.ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid,
-                                      dnums);
+      public ::testing::WithParamInterface<Convolve1DTestParam> {
+ protected:
+  template <typename T>
+  void TestImpl() {
+    ComputationBuilder builder(client_, TestName());
+    int64 input_feature = GetParam().input_feature;
+    int64 output_feature = GetParam().output_feature;
+    int64 batch = GetParam().batch;
+    int64 num_windows = GetParam().num_windows;
+    int64 window_size = GetParam().window_size;
+    std::vector<int64> input_dims = {batch, window_size + num_windows - 1,
+                                     input_feature};
+    std::vector<int64> filter_dims = {window_size, input_feature,
+                                      output_feature};
+    Shape input_shape = MakeShapeWrapper<T>(input_dims);
+    Shape filter_shape = MakeShapeWrapper<T>(filter_dims);
+    {
+      auto input = builder.Parameter(0, input_shape, "input");
+      auto filter = builder.Parameter(1, filter_shape, "filter");
+
+      // Tensorflow dimension numbers for 1D convolution.
+      ConvolutionDimensionNumbers dnums;
+      dnums.set_input_batch_dimension(0);
+      dnums.set_output_batch_dimension(0);
+      dnums.add_input_spatial_dimensions(1);
+      dnums.add_output_spatial_dimensions(1);
+      dnums.set_input_feature_dimension(2);
+      dnums.set_output_feature_dimension(2);
+      dnums.add_kernel_spatial_dimensions(0);
+      dnums.set_kernel_input_feature_dimension(1);
+      dnums.set_kernel_output_feature_dimension(2);
+
+      builder.ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid,
+                                        dnums);
+    }
+
+    std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
+                               static_cast<T>(1.0f));
+    auto input_r1 = Literal::CreateR1<T>(input_elems);
+    auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+
+    std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
+                                static_cast<T>(1.0f));
+
+    auto filter_r1 = Literal::CreateR1<T>(filter_elems);
+    auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+
+    std::vector<T> expect_elems(batch * output_feature * num_windows,
+                                static_cast<T>(window_size * input_feature));
+    auto expected_r1 = Literal::CreateR1<T>(expect_elems);
+    auto expected_r3 =
+        expected_r1->Reshape({batch, num_windows, output_feature})
+            .ConsumeValueOrDie();
+
+    auto input_literal =
+        client_->TransferToServer(*input_r3).ConsumeValueOrDie();
+    auto filter_literal =
+        client_->TransferToServer(*filter_r3).ConsumeValueOrDie();
+    ComputeAndCompareLiteral(&builder, *expected_r3,
+                             {input_literal.get(), filter_literal.get()},
+                             error_spec_);
   }
+};
 
-  std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape), 1.0);
-  auto input_r1 = Literal::CreateR1<float>(input_elems);
-  auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+class Convolve1D1WindowTestFloat : public Convolve1D1WindowTestBase {};
 
-  std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape), 1.0);
+XLA_TEST_P(Convolve1D1WindowTestFloat, Convolve1D1Window) { TestImpl<float>(); }
 
-  auto filter_r1 = Literal::CreateR1<float>(filter_elems);
-  auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+INSTANTIATE_TEST_CASE_P(
+    Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestFloat,
+    ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
+                      Convolve1DTestParam{160, 1, 1, 5, 1},
+                      Convolve1DTestParam{24, 1, 1, 20, 1},
+                      Convolve1DTestParam{30, 1, 1, 20, 1},
+                      Convolve1DTestParam{23, 1, 1, 20, 20},
+                      Convolve1DTestParam{25, 1, 1, 20, 1},
+                      Convolve1DTestParam{24, 1, 1, 10, 5},
+                      Convolve1DTestParam{160, 1, 1, 10, 1},
+                      Convolve1DTestParam{255, 1, 1, 3, 1},
+                      Convolve1DTestParam{130, 1, 1, 1, 3},
+                      Convolve1DTestParam{64, 1, 1, 1, 1},
+                      Convolve1DTestParam{128, 1, 1, 1, 1},
+                      Convolve1DTestParam{139, 1, 1, 128, 1},
+                      Convolve1DTestParam{1, 10, 10, 1, 10},
+                      Convolve1DTestParam{1, 10, 130, 1, 2},
+                      Convolve1DTestParam{1, 10, 130, 1, 1},
+                      Convolve1DTestParam{1, 64, 64, 1, 10},
+                      Convolve1DTestParam{1, 65, 65, 1, 1},
+                      Convolve1DTestParam{1, 128, 128, 1, 1},
+                      Convolve1DTestParam{128, 128, 128, 128, 1},
+                      Convolve1DTestParam{1, 128, 128, 1, 1},
+                      Convolve1DTestParam{2, 2, 2, 2, 1},
+                      Convolve1DTestParam{161, 1, 1, 10, 1},
+                      Convolve1DTestParam{900, 1, 1, 10, 1},
+                      Convolve1DTestParam{640, 3, 3, 128, 1})
 
-  std::vector<float> expect_elems(batch * output_feature * num_windows,
-                                  window_size * input_feature);
-  auto expected_r1 = Literal::CreateR1<float>(expect_elems);
-  auto expected_r3 = expected_r1->Reshape({batch, num_windows, output_feature})
-                         .ConsumeValueOrDie();
+);
 
-  auto input_literal = client_->TransferToServer(*input_r3).ConsumeValueOrDie();
-  auto filter_literal =
-      client_->TransferToServer(*filter_r3).ConsumeValueOrDie();
-  ComputeAndCompareLiteral(&builder, *expected_r3,
-                           {input_literal.get(), filter_literal.get()},
-                           error_spec_);
+#if (XLA_TEST_BACKEND_GPU || XLA_TEST_BACKEND_CPU)
+class Convolve1D1WindowTestHalf : public Convolve1D1WindowTestBase {};
+
+// TODO(b/72509305): Enable half data type tests for CPU.
+XLA_TEST_P(Convolve1D1WindowTestHalf,
+           DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(Convolve1D1Window))) {
+  TestImpl<Eigen::half>();
 }
 
 INSTANTIATE_TEST_CASE_P(
-    Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTest,
+    Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestHalf,
     ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
                       Convolve1DTestParam{160, 1, 1, 5, 1},
                       Convolve1DTestParam{24, 1, 1, 20, 1},
@@ -592,7 +719,11 @@ INSTANTIATE_TEST_CASE_P(
                       Convolve1DTestParam{130, 1, 1, 1, 3},
                       Convolve1DTestParam{64, 1, 1, 1, 1},
                       Convolve1DTestParam{128, 1, 1, 1, 1},
+                      // TODO(b/72566306): the following three tests fail on CPU
+                      // backend due to result miscompare.
                       Convolve1DTestParam{139, 1, 1, 128, 1},
+                      Convolve1DTestParam{640, 3, 3, 128, 1},
+                      Convolve1DTestParam{900, 1, 1, 10, 1},
                       Convolve1DTestParam{1, 10, 10, 1, 10},
                       Convolve1DTestParam{1, 10, 130, 1, 2},
                       Convolve1DTestParam{1, 10, 130, 1, 1},
@@ -602,11 +733,10 @@ INSTANTIATE_TEST_CASE_P(
                       Convolve1DTestParam{128, 128, 128, 128, 1},
                       Convolve1DTestParam{1, 128, 128, 1, 1},
                       Convolve1DTestParam{2, 2, 2, 2, 1},
-                      Convolve1DTestParam{161, 1, 1, 10, 1},
-                      Convolve1DTestParam{900, 1, 1, 10, 1},
-                      Convolve1DTestParam{640, 3, 3, 128, 1})
+                      Convolve1DTestParam{161, 1, 1, 10, 1})
 
 );
+#endif
 
 TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
   ComputationBuilder builder(client_, TestName());
index cc4eaf62f50d1fa622c705fab810fe1e1b0fbf08..e2d406f66d94f8ec76faa5b7d2d2e84dcaf6db57 100644 (file)
@@ -161,4 +161,31 @@ string PrependDisabledIfIndicated(const string& test_case_name,
 
 #define XLA_TEST_P(test_case_name, test_name) \
   XLA_TEST_P_IMPL_(test_case_name, test_name)
+
+// This is identical to the TEST_F macro from "gtest", but it potentially
+// disables the test based on an external manifest file, DISABLED_MANIFEST.
+#define XLA_TYPED_TEST(CaseName, TestName)                                     \
+  template <typename gtest_TypeParam_>                                         \
+  class GTEST_TEST_CLASS_NAME_(CaseName, TestName)                             \
+      : public CaseName<gtest_TypeParam_> {                                    \
+   private:                                                                    \
+    typedef CaseName<gtest_TypeParam_> TestFixture;                            \
+    typedef gtest_TypeParam_ TypeParam;                                        \
+    virtual void TestBody();                                                   \
+  };                                                                           \
+  bool gtest_##CaseName##_##TestName##_registered_ GTEST_ATTRIBUTE_UNUSED_ =   \
+      ::testing::internal::TypeParameterizedTest<                              \
+          CaseName,                                                            \
+          ::testing::internal::TemplateSel<GTEST_TEST_CLASS_NAME_(CaseName,    \
+                                                                  TestName)>,  \
+          GTEST_TYPE_PARAMS_(CaseName)>::                                      \
+          Register(                                                            \
+              "", ::testing::internal::CodeLocation(__FILE__, __LINE__),       \
+              #CaseName,                                                       \
+              ::xla::PrependDisabledIfIndicated(#CaseName, #TestName).c_str(), \
+              0);                                                              \
+  template <typename gtest_TypeParam_>                                         \
+  void GTEST_TEST_CLASS_NAME_(CaseName,                                        \
+                              TestName)<gtest_TypeParam_>::TestBody()
+
 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_