CHECK(idx == num_elements());
+  // Creates a 2D array of Eigen::half from the given nested initializer list of
+  // float values. Enabled only when T is Eigen::half and T2 is float. The
+  // outer list is the first dimension; the extent of the second dimension is
+  // taken from the first row, so `values` must contain at least one row.
+  template <typename T2, typename = typename std::enable_if<
+                             std::is_same<T, Eigen::half>::value &&
+                             std::is_same<T2, float>::value>::type>
+  Array(std::initializer_list<std::initializer_list<T2>> values)
+      : Array(ToInt64Vector({values.size(), values.begin()->size()})) {
+    int64 idx = 0;
+    for (const auto& it1 : values) {
+      for (const auto& it2 : it1) {
+        // A row longer than the first one would otherwise write past the end
+        // of values_ before the final size check below gets a chance to fire.
+        CHECK(idx < num_elements());
+        values_[idx] = static_cast<T>(it2);
+        ++idx;
+      }
+    }
+    // Catches the opposite mismatch: rows shorter than the first one.
+    CHECK(idx == num_elements());
+  }
// Creates a 3D array from the given nested initializer list. The outer
// initializer list is the first dimension, and so on.
Array(InitializerList3D values)
CHECK(idx == num_elements());
+  // Creates a 3D array of Eigen::half from the given nested initializer list of
+  // float values. Enabled only when T is Eigen::half and T2 is float. The
+  // outer list is the first dimension, and so on; the extent of each inner
+  // dimension is taken from the first element at that nesting level, so the
+  // first element must exist at every level.
+  template <typename T2, typename = typename std::enable_if<
+                             std::is_same<T, Eigen::half>::value &&
+                             std::is_same<T2, float>::value>::type>
+  Array(std::initializer_list<std::initializer_list<std::initializer_list<T2>>>
+            values)
+      : Array(ToInt64Vector({values.size(), values.begin()->size(),
+                             values.begin()->begin()->size()})) {
+    int64 idx = 0;
+    for (const auto& it1 : values) {
+      for (const auto& it2 : it1) {
+        for (const auto& it3 : it2) {
+          // Ragged input that holds more elements than the shape derived from
+          // the first entries would otherwise overrun values_.
+          CHECK(idx < num_elements());
+          values_[idx] = static_cast<T>(it3);
+          ++idx;
+        }
+      }
+    }
+    // Catches ragged input that holds fewer elements than the derived shape.
+    CHECK(idx == num_elements());
+  }
// Creates a 4D array from the given nested initializer list. The outer
// initializer list is the first dimension, and so on.
Array(InitializerList4D values)
CHECK(idx == num_elements());
+  // Creates a 4D array of Eigen::half from the given nested initializer list of
+  // float values. Enabled only when T is Eigen::half and T2 is float. The
+  // outer list is the first dimension, and so on; the extent of each inner
+  // dimension is taken from the first element at that nesting level, so the
+  // first element must exist at every level.
+  template <typename T2, typename = typename std::enable_if<
+                             std::is_same<T, Eigen::half>::value &&
+                             std::is_same<T2, float>::value>::type>
+  Array(std::initializer_list<
+        std::initializer_list<std::initializer_list<std::initializer_list<T2>>>>
+            values)
+      : Array(ToInt64Vector({values.size(), values.begin()->size(),
+                             values.begin()->begin()->size(),
+                             values.begin()->begin()->begin()->size()})) {
+    int64 idx = 0;
+    for (const auto& it1 : values) {
+      for (const auto& it2 : it1) {
+        for (const auto& it3 : it2) {
+          for (const auto& it4 : it3) {
+            // Ragged input that holds more elements than the shape derived
+            // from the first entries would otherwise overrun values_.
+            CHECK(idx < num_elements());
+            values_[idx] = static_cast<T>(it4);
+            ++idx;
+          }
+        }
+      }
+    }
+    // Catches ragged input that holds fewer elements than the derived shape.
+    CHECK(idx == num_elements());
+  }
Array(const Array<T>& other)
: sizes_(other.sizes_), values_(new T[num_elements()]) {
std::copy(&other.values_[0], &other.values_[0] + num_elements(),
// Fills the array with the sequence i*multiplier for i=0,1,...
void FillWithMultiples(const T& multiplier) {
for (int64 i = 0; i < num_elements(); ++i) {
- values_[i] = i * multiplier;
+ values_[i] = static_cast<T>(i) * multiplier;
Array2D(std::initializer_list<std::initializer_list<T>> values)
: Array<T>(values) {}
+ // Creates a 2D array of Eigen::half from a nested initializer list of float
+ // values (only when T is Eigen::half and T2 is float); forwards to Array<T>.
+ template <typename T2, typename = typename std::enable_if<
+ std::is_same<T, Eigen::half>::value &&
+ std::is_same<T2, float>::value>::type>
+ Array2D(std::initializer_list<std::initializer_list<T2>> values)
+ : Array<T>(values) {}
Array2D(const Array2D<T>& other) : Array<T>(other) {}
int64 n1() const { return this->dim(0); }
EXPECT_EQ(arr(1, 2), 6);
+TEST(Array2dTest, InitializerListCtorHalf) {  // float literals -> Eigen::half cells
+ Array2D<Eigen::half> arr = {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}};
+ EXPECT_EQ(arr.n1(), 2);  // number of rows in the initializer list
+ EXPECT_EQ(arr.n2(), 3);  // number of values in each row
+ EXPECT_EQ(arr(0, 0), static_cast<Eigen::half>(1));  // first row
+ EXPECT_EQ(arr(0, 1), static_cast<Eigen::half>(2));
+ EXPECT_EQ(arr(0, 2), static_cast<Eigen::half>(3));
+ EXPECT_EQ(arr(1, 0), static_cast<Eigen::half>(4));  // second row
+ EXPECT_EQ(arr(1, 1), static_cast<Eigen::half>(5));
+ EXPECT_EQ(arr(1, 2), static_cast<Eigen::half>(6));
TEST(Array2dTest, Accessors) {
Array2D<int> arr = {{1, 2, 3}, {4, 5, 6}};
: Array<T>(values) {}
+ // Creates a 3D array of Eigen::half from a nested initializer list of float
+ // values (only when T is Eigen::half and T2 is float); forwards to Array<T>.
+ template <typename T2, typename = typename std::enable_if<
+ std::is_same<T, Eigen::half>::value &&
+ std::is_same<T2, float>::value>::type>
+ Array3D(
+ std::initializer_list<std::initializer_list<std::initializer_list<T2>>>
+ values)
+ : Array<T>(values) {}
int64 n1() const { return this->dim(0); }
int64 n2() const { return this->dim(1); }
int64 n3() const { return this->dim(2); }
EXPECT_EQ(arr(2, 3, 1), 24);
+TEST(Array3dTest, InitializerListCtorHalf) {  // float literals -> Eigen::half cells
+ Array3D<Eigen::half> arr = {
+ {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}, {7.0f, 8.0f}},
+ {{9.0f, 10.0f}, {11.0f, 12.0f}, {13.0f, 14.0f}, {15.0f, 16.0f}},
+ {{17.0f, 18.0f}, {19.0f, 20.0f}, {21.0f, 22.0f}, {23.0f, 24.0f}}};
+ EXPECT_EQ(arr.n1(), 3);  // outermost list length
+ EXPECT_EQ(arr.n2(), 4);  // middle list length
+ EXPECT_EQ(arr.n3(), 2);  // innermost list length
+ EXPECT_EQ(arr.num_elements(), 24);  // 3 * 4 * 2
+ EXPECT_EQ(arr(0, 0, 0), static_cast<Eigen::half>(1));  // spot checks, first slab
+ EXPECT_EQ(arr(0, 0, 1), static_cast<Eigen::half>(2));
+ EXPECT_EQ(arr(0, 1, 0), static_cast<Eigen::half>(3));
+ EXPECT_EQ(arr(0, 3, 1), static_cast<Eigen::half>(8));
+ EXPECT_EQ(arr(1, 0, 0), static_cast<Eigen::half>(9));  // second slab
+ EXPECT_EQ(arr(1, 1, 1), static_cast<Eigen::half>(12));
+ EXPECT_EQ(arr(2, 0, 0), static_cast<Eigen::half>(17));  // third slab
+ EXPECT_EQ(arr(2, 1, 1), static_cast<Eigen::half>(20));
+ EXPECT_EQ(arr(2, 2, 0), static_cast<Eigen::half>(21));
+ EXPECT_EQ(arr(2, 3, 1), static_cast<Eigen::half>(24));  // last element
TEST(Array3dTest, Fill) {
Array3D<int> fullof7(2, 3, 4, 7);
for (int64 n1 = 0; n1 < fullof7.n1(); ++n1) {
: Array<T>(values) {}
+ // Creates a 4D array of Eigen::half from a nested initializer list of float
+ // values (only when T is Eigen::half and T2 is float); forwards to Array<T>.
+ template <typename T2, typename = typename std::enable_if<
+ std::is_same<T, Eigen::half>::value &&
+ std::is_same<T2, float>::value>::type>
+ Array4D(std::initializer_list<std::initializer_list<
+ std::initializer_list<std::initializer_list<T2>>>>
+ values)
+ : Array<T>(values) {}
// Numerically-named aliases for the various dimensions. This matches the
// dimension names used in array3d.
int64 n4() const { return this->dim(3); }
EXPECT_EQ(arr(2, 3, 1, 0), 24);
+// Builds an Array4D<Eigen::half> from nested float literals and spot-checks
+// its dimensions and a selection of elements. Registered under Array4dTest,
+// matching the other Array4D tests, rather than under the Array3dTest suite.
+TEST(Array4dTest, InitializerListCtorHalf) {
+  Array4D<Eigen::half> arr = {
+      {{{1.0f}, {2.0f}}, {{3.0f}, {4.0f}}, {{5.0f}, {6.0f}}, {{7.0f}, {8.0f}}},
+      {{{9.0f}, {10.0f}},
+       {{11.0f}, {12.0f}},
+       {{13.0f}, {14.0f}},
+       {{15.0f}, {16.0f}}},
+      {{{17.0f}, {18.0f}},
+       {{19.0f}, {20.0f}},
+       {{21.0f}, {22.0f}},
+       {{23.0f}, {24.0f}}}};
+  EXPECT_EQ(arr.n1(), 3);
+  EXPECT_EQ(arr.n2(), 4);
+  EXPECT_EQ(arr.n3(), 2);
+  EXPECT_EQ(arr.n4(), 1);
+  EXPECT_EQ(arr.num_elements(), 24);
+  EXPECT_EQ(arr(0, 0, 0, 0), static_cast<Eigen::half>(1));
+  EXPECT_EQ(arr(0, 0, 1, 0), static_cast<Eigen::half>(2));
+  EXPECT_EQ(arr(0, 1, 0, 0), static_cast<Eigen::half>(3));
+  EXPECT_EQ(arr(0, 3, 1, 0), static_cast<Eigen::half>(8));
+  EXPECT_EQ(arr(1, 0, 0, 0), static_cast<Eigen::half>(9));
+  EXPECT_EQ(arr(1, 1, 1, 0), static_cast<Eigen::half>(12));
+  EXPECT_EQ(arr(2, 0, 0, 0), static_cast<Eigen::half>(17));
+  EXPECT_EQ(arr(2, 1, 1, 0), static_cast<Eigen::half>(20));
+  EXPECT_EQ(arr(2, 2, 0, 0), static_cast<Eigen::half>(21));
+  EXPECT_EQ(arr(2, 3, 1, 0), static_cast<Eigen::half>(24));
TEST(Array4dTest, Fill) {
Array4D<int> fullof7(2, 3, 4, 5, 7);
fullof7.Each([](tensorflow::gtl::ArraySlice<int64> idx, int* cell) {
EXPECT_EQ(arr(1, 2), 6);
+TEST(ArrayTest, InitializerListCtorHalf) {  // checks dims deduced from float lists
+ Array<Eigen::half> d2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});  // 2 x 3
+ EXPECT_EQ(d2.dim(0), 2);
+ EXPECT_EQ(d2.dim(1), 3);
+ Array<Eigen::half> d3({{{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}});  // 3 x 2 x 1
+ EXPECT_EQ(d3.dim(0), 3);
+ EXPECT_EQ(d3.dim(1), 2);
+ EXPECT_EQ(d3.dim(2), 1);
+ Array<Eigen::half> d4(  // 2 x 3 x 2 x 1
+ {{{{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}},
+ {{{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}}});
+ EXPECT_EQ(d4.dim(0), 2);
+ EXPECT_EQ(d4.dim(1), 3);
+ EXPECT_EQ(d4.dim(2), 2);
+ EXPECT_EQ(d4.dim(3), 1);
TEST(ArrayTest, IndexingReadWrite) {
Array<int> arr({2, 3});
extern const char* const kEigenMatMulF64SymbolName =
+extern const char* const kEigenConvF16SymbolName =
+ "__xla_cpu_runtime_EigenConvF16";
extern const char* const kEigenConvF32SymbolName =
extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft";
extern const char* const kEigenSingleThreadedMatMulF64SymbolName =
+extern const char* const kEigenSingleThreadedConvF16SymbolName =
+ "__xla_cpu_runtime_EigenSingleThreadedConvF16";
extern const char* const kEigenSingleThreadedConvF32SymbolName =
extern const char* const kAcquireInfeedBufferForDequeueSymbolName =
// because it is a symbol in the cpu_runtime library.
extern const char* const kEigenMatMulF32SymbolName;
extern const char* const kEigenMatMulF64SymbolName;
+extern const char* const kEigenConvF16SymbolName;
extern const char* const kEigenConvF32SymbolName;
extern const char* const kEigenFftSymbolName;
extern const char* const kEigenSingleThreadedMatMulF32SymbolName;
extern const char* const kEigenSingleThreadedMatMulF64SymbolName;
+extern const char* const kEigenSingleThreadedConvF16SymbolName;
extern const char* const kEigenSingleThreadedConvF32SymbolName;
extern const char* const kAcquireInfeedBufferForDequeueSymbolName;
extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName;
const HloModuleConfig& hlo_module_config,
const TargetMachineFeatures& target_machine_features) {
PrimitiveType type = target_array.GetShape().element_type();
- TF_RET_CHECK(F32 == type || F64 == type || C64 == type);
+ TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type);
DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array,
lhs_array, rhs_array, addend_array,
executable_run_options_value, ir_builder,
auto rhs = dot->operand(1);
/*instruction=*/*dot, /*operands=*/{lhs, rhs},
- /*supported_types=*/{F32, F64, C64}));
+ /*supported_types=*/{F16, F32, F64, C64}));
const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
if (dnums.lhs_batch_dimensions_size() > 0 ||
dnums.rhs_batch_dimensions_size() > 0) {
const auto& window = convolution->window();
/*instruction=*/*convolution, /*operands=*/{lhs, rhs},
- /*supported_types=*/{F32, C64}));
+ /*supported_types=*/{F16, F32, C64}));
const ConvolutionDimensionNumbers& dnums =
int64 rhs_col_dilation =
one_dim_convolution ? 1 : window.dimensions(1).window_dilation();
- // Args have been computed, make the call.
- llvm::Type* float_ptr_type = ir_builder_.getFloatTy()->getPointerTo();
+ PrimitiveType primitive_type = lhs->shape().element_type();
+ llvm::Type* ir_ptr_type = primitive_type == F16
+ ? ir_builder_.getHalfTy()->getPointerTo()
+ : ir_builder_.getFloatTy()->getPointerTo();
llvm::Type* int64_type = ir_builder_.getInt64Ty();
llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo();
llvm::FunctionType* conv_type = llvm::FunctionType::get(
- {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
- int64_type, int64_type, int64_type, int64_type,
- int64_type, int64_type, int64_type, int64_type,
- int64_type, int64_type, int64_type, int64_type,
- int64_type, int64_type, int64_type, int64_type,
- int64_type, int64_type, int64_type, int64_type},
+ {int8_ptr_type, ir_ptr_type, ir_ptr_type, ir_ptr_type, int64_type,
+ int64_type, int64_type, int64_type, int64_type, int64_type,
+ int64_type, int64_type, int64_type, int64_type, int64_type,
+ int64_type, int64_type, int64_type, int64_type, int64_type,
+ int64_type, int64_type, int64_type, int64_type},
bool multi_threaded_eigen =
const char* fn_name =
- (multi_threaded_eigen
- ? runtime::kEigenConvF32SymbolName
- : runtime::kEigenSingleThreadedConvF32SymbolName);
+ primitive_type == F16
+ ? (multi_threaded_eigen
+ ? runtime::kEigenConvF16SymbolName
+ : runtime::kEigenSingleThreadedConvF16SymbolName)
+ : (multi_threaded_eigen
+ ? runtime::kEigenConvF32SymbolName
+ : runtime::kEigenSingleThreadedConvF32SymbolName);
llvm::Function* conv_func = llvm::cast<llvm::Function>(
module_->getOrInsertFunction(fn_name, conv_type));
conv_func, {
- GetEmittedValueFor(convolution), float_ptr_type),
- ir_builder_.CreateBitCast(lhs_address, float_ptr_type),
- ir_builder_.CreateBitCast(rhs_address, float_ptr_type),
+ GetEmittedValueFor(convolution), ir_ptr_type),
+ ir_builder_.CreateBitCast(lhs_address, ir_ptr_type),
+ ir_builder_.CreateBitCast(rhs_address, ir_ptr_type),
int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) {
const xla::ExecutableRunOptions* run_options =
static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
- tensorflow::xla::EigenConvF32Impl(
+ tensorflow::xla::EigenConvImpl(
+ *run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch,
+ input_rows, input_cols, input_channels, kernel_rows, kernel_cols,
+ kernel_channels, kernel_filters, output_rows, output_cols, row_stride,
+ col_stride, padding_top, padding_bottom, padding_left, padding_right,
+ lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation);
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConvF16(
+ const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs,
+ Eigen::half* rhs, int64 input_batch, int64 input_rows, int64 input_cols,
+ int64 input_channels, int64 kernel_rows, int64 kernel_cols,
+ int64 kernel_channels, int64 kernel_filters, int64 output_rows,
+ int64 output_cols, int64 row_stride, int64 col_stride, int64 padding_top,
+ int64 padding_bottom, int64 padding_left, int64 padding_right,
+ int64 lhs_row_dilation, int64 lhs_col_dilation, int64 rhs_row_dilation,
+ int64 rhs_col_dilation) {
+ const xla::ExecutableRunOptions* run_options =
+ static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
+ tensorflow::xla::EigenConvImpl(
*run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch,
input_rows, input_cols, input_channels, kernel_rows, kernel_cols,
kernel_channels, kernel_filters, output_rows, output_cols, row_stride,
tensorflow::int64 lhs_col_dilation, tensorflow::int64 rhs_row_dilation,
tensorflow::int64 rhs_col_dilation);
+extern void __xla_cpu_runtime_EigenConvF16(  // Eigen::half variant of the conv entry point
+ const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
+ Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs,  // result, input and kernel buffers
+ tensorflow::int64 input_batch, tensorflow::int64 input_rows,
+ tensorflow::int64 input_cols, tensorflow::int64 input_channels,
+ tensorflow::int64 kernel_rows, tensorflow::int64 kernel_cols,
+ tensorflow::int64 kernel_channels, tensorflow::int64 kernel_filters,
+ tensorflow::int64 output_rows, tensorflow::int64 output_cols,
+ tensorflow::int64 row_stride, tensorflow::int64 col_stride,
+ tensorflow::int64 padding_top, tensorflow::int64 padding_bottom,
+ tensorflow::int64 padding_left, tensorflow::int64 padding_right,
+ tensorflow::int64 lhs_row_dilation, tensorflow::int64 lhs_col_dilation,
+ tensorflow::int64 rhs_row_dilation, tensorflow::int64 rhs_col_dilation);
} // extern "C"
namespace tensorflow {
namespace xla {
-template <typename EigenDevice>
-void EigenConvF32Impl(const EigenDevice& device, float* out, float* lhs,
- float* rhs, int64 input_batch, int64 input_rows,
- int64 input_cols, int64 input_channels, int64 kernel_rows,
- int64 kernel_cols, int64 kernel_channels,
- int64 kernel_filters, int64 output_rows,
- int64 output_cols, int64 row_stride, int64 col_stride,
- int64 padding_top, int64 padding_bottom,
- int64 padding_left, int64 padding_right,
- int64 lhs_row_dilation, int64 lhs_col_dilation,
- int64 rhs_row_dilation, int64 rhs_col_dilation) {
- const Eigen::TensorMap<Eigen::Tensor<const float, 4, Eigen::RowMajor>,
+template <typename EigenDevice, typename ScalarType>
+void EigenConvImpl(const EigenDevice& device, ScalarType* out, ScalarType* lhs,
+ ScalarType* rhs, int64 input_batch, int64 input_rows,
+ int64 input_cols, int64 input_channels, int64 kernel_rows,
+ int64 kernel_cols, int64 kernel_channels,
+ int64 kernel_filters, int64 output_rows, int64 output_cols,
+ int64 row_stride, int64 col_stride, int64 padding_top,
+ int64 padding_bottom, int64 padding_left,
+ int64 padding_right, int64 lhs_row_dilation,
+ int64 lhs_col_dilation, int64 rhs_row_dilation,
+ int64 rhs_col_dilation) {
+ const Eigen::TensorMap<Eigen::Tensor<const ScalarType, 4, Eigen::RowMajor>,
input(lhs, input_batch, input_rows, input_cols, input_channels);
- const Eigen::TensorMap<Eigen::Tensor<const float, 4, Eigen::RowMajor>,
+ const Eigen::TensorMap<Eigen::Tensor<const ScalarType, 4, Eigen::RowMajor>,
kernel(rhs, kernel_rows, kernel_cols, kernel_channels, kernel_filters);
- Eigen::TensorMap<Eigen::Tensor<float, 4, Eigen::RowMajor>, Eigen::Aligned>
+ Eigen::TensorMap<Eigen::Tensor<ScalarType, 4, Eigen::RowMajor>,
+ Eigen::Aligned>
output(out, input_batch, output_rows, output_cols, kernel_filters);
Eigen::array<Eigen::IndexPair<int64>, 1> contract_dims;
row_stride, rhs_col_dilation, rhs_row_dilation,
lhs_col_dilation, lhs_row_dilation,
padding_left, padding_right, padding_top,
- padding_bottom, 0.0f)
+ padding_bottom, static_cast<ScalarType>(0.0f))
.contract(kernel.reshape(kernel_dims), contract_dims)
using tensorflow::int64;
+ const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs,
+ Eigen::half* rhs, int64 input_batch, int64 input_rows, int64 input_cols,
+ int64 input_channels, int64 kernel_rows, int64 kernel_cols,
+ int64 kernel_channels, int64 kernel_filters, int64 output_rows,
+ int64 output_cols, int64 row_stride, int64 col_stride, int64 padding_top,
+ int64 padding_bottom, int64 padding_left, int64 padding_right,
+ int64 lhs_row_dilation, int64 lhs_col_dilation, int64 rhs_row_dilation,
+ int64 rhs_col_dilation) {
+ tensorflow::xla::EigenConvImpl(
+ Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_rows,
+ input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels,
+ kernel_filters, output_rows, output_cols, row_stride, col_stride,
+ padding_top, padding_bottom, padding_left, padding_right,
+ lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation);
const void* run_options_ptr, float* out, float* lhs, float* rhs,
int64 row_stride, int64 col_stride, int64 padding_top, int64 padding_bottom,
int64 padding_left, int64 padding_right, int64 lhs_row_dilation,
int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) {
- tensorflow::xla::EigenConvF32Impl(
+ tensorflow::xla::EigenConvImpl(
Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_rows,
input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels,
kernel_filters, output_rows, output_cols, row_stride, col_stride,
extern "C" {
+extern void __xla_cpu_runtime_EigenSingleThreadedConvF16(  // Eigen::half variant of the single-threaded conv entry point
+ const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
+ Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs,  // result, input and kernel buffers
+ tensorflow::int64 input_batch, tensorflow::int64 input_rows,
+ tensorflow::int64 input_cols, tensorflow::int64 input_channels,
+ tensorflow::int64 kernel_rows, tensorflow::int64 kernel_cols,
+ tensorflow::int64 kernel_channels, tensorflow::int64 kernel_filters,
+ tensorflow::int64 output_rows, tensorflow::int64 output_cols,
+ tensorflow::int64 row_stride, tensorflow::int64 col_stride,
+ tensorflow::int64 padding_top, tensorflow::int64 padding_bottom,
+ tensorflow::int64 padding_left, tensorflow::int64 padding_right,
+ tensorflow::int64 lhs_row_dilation, tensorflow::int64 lhs_col_dilation,
+ tensorflow::int64 rhs_row_dilation, tensorflow::int64 rhs_col_dilation);
extern void __xla_cpu_runtime_EigenSingleThreadedConvF32(
const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
float* lhs, float* rhs, tensorflow::int64 input_batch,
+ REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16);
Status ConvolutionThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream) {
- se::DeviceMemory<float> input_data(
- buffer_allocations.GetDeviceAddress(input_buffer_));
- se::DeviceMemory<float> filter_data(
- buffer_allocations.GetDeviceAddress(filter_buffer_));
- se::DeviceMemory<float> output_data(
- buffer_allocations.GetDeviceAddress(output_buffer_));
+ se::DeviceMemoryBase input_data =
+ buffer_allocations.GetDeviceAddress(input_buffer_);
+ se::DeviceMemoryBase filter_data =
+ buffer_allocations.GetDeviceAddress(filter_buffer_);
+ se::DeviceMemoryBase output_data =
+ buffer_allocations.GetDeviceAddress(output_buffer_);
se::DeviceMemoryBase scratch =
filter_data, output_data, scratch, window_, dim_nums_, algorithm_config,
- // Figure out which of output/input/filter is the result produced by this op,
- // and write the result tuple.
+ // Figure out which of output/input/filter is the result produced by
+ // this op, and write the result tuple.
void* result_ptr = [&] {
switch (convolution_kind_) {
case CudnnConvKind::kForward:
- // Remove any algorithms with tensor math enabled. These have lower precision
- // than regular algorithms, and we don't yet have a way to turn this on/off in
- // XLA.
- algorithms.erase(std::remove_if(algorithms.begin(), algorithms.end(),
- [&](const AlgorithmDesc& a) {
- return a.tensor_ops_enabled();
- }),
- algorithms.end());
return algorithms;
ShouldIncludeWinogradNonfusedAlgo(input_shape, output_shape, dnums);
se::dnn::ProfileResult best_result;
int64 best_result_bytes_used = 0;
for (const AlgorithmDesc& alg :
GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) {
ScratchAllocator scratch_allocator(device_ordinal, allocator);
VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
<< instr->ToString();
- bool launch_ok =
- RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<float>(input_buf.ValueOrDie()),
- se::DeviceMemory<float>(filter_buf.ValueOrDie()),
- se::DeviceMemory<float>(output_buf.ValueOrDie()),
- &scratch_allocator, window, dnums,
- AlgorithmConfig(alg), &stream, &profile_result)
- .ok();
+ bool launch_ok = RunCudnnConvolution(
+ kind, input_shape, filter_shape, output_shape,
+ input_buf.ValueOrDie(), filter_buf.ValueOrDie(),
+ output_buf.ValueOrDie(), &scratch_allocator, window,
+ dnums, AlgorithmConfig(alg), &stream, &profile_result)
+ .ok();
if (launch_ok && profile_result.is_valid()) {
int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
bool allocated_ = false;
-} // anonymous namespace
-string CudnnConvKindToString(CudnnConvKind kind) {
- switch (kind) {
- case CudnnConvKind::kForward:
- return "forward";
- case CudnnConvKind::kBackwardFilter:
- return "backward_filter";
- case CudnnConvKind::kBackwardInput:
- return "backward_input";
- }
-Status RunCudnnConvolution(CudnnConvKind kind, const Shape& input_shape,
- const Shape& filter_shape, const Shape& output_shape,
- DeviceMemory<float> input_buf,
- DeviceMemory<float> filter_buf,
- DeviceMemory<float> output_buf,
- DeviceMemoryBase scratch_buf, const Window& window,
- const ConvolutionDimensionNumbers& dnums,
- AlgorithmConfig algorithm, Stream* stream,
- ProfileResult* profile_result /*= nullptr*/) {
- ScratchBufAllocator scratch_allocator(scratch_buf);
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- input_buf, filter_buf, output_buf,
- &scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
+template <typename T>
Status RunCudnnConvolution(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, DeviceMemory<float> input_buf,
- DeviceMemory<float> filter_buf, DeviceMemory<float> output_buf,
+ const Shape& output_shape, DeviceMemory<T> input_buf,
+ DeviceMemory<T> filter_buf, DeviceMemory<T> output_buf,
se::ScratchAllocator* scratch_allocator, const Window& window,
const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm,
Stream* stream, ProfileResult* profile_result /*= nullptr*/) {
// tensorflow/python/ops/
const int effective_num_dimensions = std::max(2, num_dimensions);
- CHECK_EQ(F32, output_shape.element_type())
- << ShapeUtil::HumanString(output_shape);
+ if (std::is_same<T, float>::value) {
+ CHECK_EQ(F32, output_shape.element_type())
+ << ShapeUtil::HumanString(output_shape);
+ } else if (std::is_same<T, Eigen::half>::value) {
+ CHECK_EQ(F16, output_shape.element_type())
+ << ShapeUtil::HumanString(output_shape);
+ } else {
+ LOG(FATAL) << ShapeUtil::HumanString(output_shape);
+ }
CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size());
CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size());
CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size());
return Status::OK();
+} // anonymous namespace
+string CudnnConvKindToString(CudnnConvKind kind) {
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ return "forward";
+ case CudnnConvKind::kBackwardFilter:
+ return "backward_filter";
+ case CudnnConvKind::kBackwardInput:
+ return "backward_input";
+ }
+Status RunCudnnConvolution(  // overload that takes an already-allocated scratch buffer
+ CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
+ const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf,
+ perftools::gputools::DeviceMemoryBase filter_buf,
+ perftools::gputools::DeviceMemoryBase output_buf,
+ perftools::gputools::DeviceMemoryBase scratch_buf, const Window& window,
+ const ConvolutionDimensionNumbers& dnums,
+ perftools::gputools::dnn::AlgorithmConfig algorithm,
+ perftools::gputools::Stream* stream,
+ perftools::gputools::dnn::ProfileResult* profile_result) {
+ ScratchBufAllocator scratch_allocator(scratch_buf);  // wraps scratch_buf so the allocator-based overload can be reused
+ return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
+ input_buf, filter_buf, output_buf,
+ &scratch_allocator, window, dnums, algorithm,
+ stream, profile_result);
+Status RunCudnnConvolution(  // untyped overload: dispatches on the output element type
+ CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
+ const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf,
+ perftools::gputools::DeviceMemoryBase filter_buf,
+ perftools::gputools::DeviceMemoryBase output_buf,
+ perftools::gputools::ScratchAllocator* scratch_allocator,
+ const Window& window, const ConvolutionDimensionNumbers& dnums,
+ perftools::gputools::dnn::AlgorithmConfig algorithm,
+ perftools::gputools::Stream* stream,
+ perftools::gputools::dnn::ProfileResult* profile_result) {
+ PrimitiveType output_primitive_type = output_shape.element_type();
+ CHECK(output_primitive_type == F32 || output_primitive_type == F16)  // only float and half are handled
+ << ShapeUtil::HumanString(output_shape);
+ if (output_primitive_type == F32) {  // view all three buffers as float
+ return RunCudnnConvolution(
+ kind, input_shape, filter_shape, output_shape,
+ se::DeviceMemory<float>(input_buf), se::DeviceMemory<float>(filter_buf),
+ se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums,
+ algorithm, stream, profile_result);
+ }
+ return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,  // F16: view the buffers as Eigen::half
+ se::DeviceMemory<Eigen::half>(input_buf),
+ se::DeviceMemory<Eigen::half>(filter_buf),
+ se::DeviceMemory<Eigen::half>(output_buf),
+ scratch_allocator, window, dnums, algorithm,
+ stream, profile_result);
} // namespace gpu
} // namespace xla
// Note that depending on the value of CudnnConvKind, the result of this call
// may be written into input_buf, filter_buf, or output_buf!
-// At the moment we only support cudnn convolutions over floats.
+// At the moment we only support cudnn convolutions over float and half, and
+// convolution with half data type is implemented with cudnn PSEUDO_HALF
+// configuration, that is, the input values are half and the internal
+// computation type is float.
// We provide one overload which takes a scratch buffer, and another which takes
// an allocator which is responsible for allocating the scratch space. In
// that size, if you like.
Status RunCudnnConvolution(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape,
- perftools::gputools::DeviceMemory<float> input_buf,
- perftools::gputools::DeviceMemory<float> filter_buf,
- perftools::gputools::DeviceMemory<float> output_buf,
+ const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf,
+ perftools::gputools::DeviceMemoryBase filter_buf,
+ perftools::gputools::DeviceMemoryBase output_buf,
perftools::gputools::DeviceMemoryBase scratch_buf, const Window& window,
const ConvolutionDimensionNumbers& dnums,
perftools::gputools::dnn::AlgorithmConfig algorithm,
Status RunCudnnConvolution(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape,
- perftools::gputools::DeviceMemory<float> input_buf,
- perftools::gputools::DeviceMemory<float> filter_buf,
- perftools::gputools::DeviceMemory<float> output_buf,
+ const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf,
+ perftools::gputools::DeviceMemoryBase filter_buf,
+ perftools::gputools::DeviceMemoryBase output_buf,
perftools::gputools::ScratchAllocator* scratch_allocator,
const Window& window, const ConvolutionDimensionNumbers& dnums,
perftools::gputools::dnn::AlgorithmConfig algorithm,
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int64>(map));
+ case F16: {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[map],
+ MapImpl<Eigen::half>(map));
+ break;
+ }
case F32: {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<float>(map));
typed_visitors_[S32] = MakeUnique<TypedVisitor<int32>>(this);
typed_visitors_[S64] = MakeUnique<TypedVisitor<int64>>(this);
- typed_visitors_[F16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
- return Unimplemented("HloEvaluator: unhandled primitive type: F16.");
- });
+ typed_visitors_[F16] = MakeUnique<TypedVisitor<Eigen::half, float>>(this);
typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
typed_visitors_[C64] = MakeUnique<TypedVisitor<complex64>>(this);
-XLA_TEST_F(ConvolutionTest, ForwardPassConvolution_3x3x256_256_OutputZ_Iota) {
- const int kInputActivationSizeY = 3;
- const int kInputActivationSizeX = 3;
- const int kInputActivationSizeZ = 256;
- const int kKernelSizeX = 2;
- const int kKernelSizeY = 2;
- const int kOutputActivationSizeZ = 256;
- const int kMiniBatchSize = 4;
- auto alhs =
- MakeUnique<Array4D<float>>(kMiniBatchSize, kInputActivationSizeZ,
- kInputActivationSizeY, kInputActivationSizeX);
- alhs->FillWithMultiples(1.0f);
- ASSERT_EQ(3, alhs->width());
- ASSERT_EQ(3, alhs->height());
- auto arhs =
- MakeUnique<Array4D<float>>(kOutputActivationSizeZ, kInputActivationSizeZ,
- kKernelSizeY, kKernelSizeX);
- Array2D<float> rhs_raster({
- {1.0f, 0.0f}, // row 0
- {0.0f, 0.0f}, // row 1
- });
- arhs->FillWithYX(rhs_raster);
- ASSERT_EQ(2, arhs->width());
- ASSERT_EQ(2, arhs->height());
+// TODO(b/72509305): Enable half data type tests for CPU
+using TestTypes = ::testing::Types<float, Eigen::half>;
+using TestTypes = ::testing::Types<float>;
- ComputationBuilder builder(client_, TestName());
- auto lhs = builder.ConstantR4FromArray4D<float>(*alhs);
- auto rhs = builder.ConstantR4FromArray4D<float>(*arhs);
- auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
+template <typename T>  // primary template is declared only; T picks the element type
+Shape MakeShapeWrapper(tensorflow::gtl::ArraySlice<int64> dimensions);
- ComputeAndCompare(&builder, conv, {}, error_spec_);
+template <>
+Shape MakeShapeWrapper<float>(tensorflow::gtl::ArraySlice<int64> dimensions) {
+ return ShapeUtil::MakeShape(F32, dimensions);
-TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) {
- ComputationBuilder builder(client_, TestName());
- Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
- Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
+template <>
+Shape MakeShapeWrapper<Eigen::half>(
+ tensorflow::gtl::ArraySlice<int64> dimensions) {
+ return ShapeUtil::MakeShape(F16, dimensions);
- Array4D<float> input_data(1, 1, 1, 2);
- input_data.FillWithYX(Array2D<float>({
- {1, 2},
- }));
- Array4D<float> filter_data(1, 1, 1, 2);
- filter_data.FillWithYX(Array2D<float>({
- {5, 6},
- }));
+template <typename T>
+class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest {
+ public:
+  // Convolves a 4x256x3x3 operand filled with multiples of one against a
+  // 256x256x2x2 kernel whose only non-zero tap is the top-left one, and
+  // compares the result with the reference using error_spec_.
+  void RunTest() {
+    const int kBatch = 4;
+    const int kInDepth = 256;
+    const int kInHeight = 3;
+    const int kInWidth = 3;
+    const int kOutDepth = 256;
+    const int kFilterHeight = 2;
+    const int kFilterWidth = 2;
+    // Operand: every cell holds its flat index times 1.
+    auto activations =
+        MakeUnique<Array4D<T>>(kBatch, kInDepth, kInHeight, kInWidth);
+    activations->FillWithMultiples(static_cast<T>(1.0f));
+    ASSERT_EQ(3, activations->width());
+    ASSERT_EQ(3, activations->height());
+    // Kernel: the same 2x2 raster is replicated into every YX plane.
+    auto filters = MakeUnique<Array4D<T>>(kOutDepth, kInDepth, kFilterHeight,
+                                          kFilterWidth);
+    Array2D<T> tap_raster({
+        {1.0f, 0.0f},  // row 0
+        {0.0f, 0.0f},  // row 1
+    });
+    filters->FillWithYX(tap_raster);
+    ASSERT_EQ(2, filters->width());
+    ASSERT_EQ(2, filters->height());
+    ComputationBuilder builder(client_, TestName());
+    auto lhs_op = builder.ConstantR4FromArray4D<T>(*activations);
+    auto rhs_op = builder.ConstantR4FromArray4D<T>(*filters);
+    auto result = builder.Conv(lhs_op, rhs_op, {1, 1}, Padding::kValid);
+    ComputeAndCompare(&builder, result, {}, error_spec_);
+  }
- ComputeAndCompare(&builder, conv,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))},
- error_spec_);
+TYPED_TEST_CASE(ForwardPassConvolution_3x3x256_256_OutputZ_Iota, TestTypes);
+XLA_TYPED_TEST(ForwardPassConvolution_3x3x256_256_OutputZ_Iota, Types) {
+ this->RunTest();
+template <typename T>
+class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
+ public:
+  // Convolves a 1x1x1x2 parameter with a 1x1x1x2 filter parameter using valid
+  // padding and unit strides, then compares against the reference result.
+  void RunTest() {
+    ComputationBuilder builder(client_, TestName());
+    const Shape lhs_shape = MakeShapeWrapper<T>({1, 1, 1, 2});
+    const Shape rhs_shape = MakeShapeWrapper<T>({1, 1, 1, 2});
+    auto lhs_param = builder.Parameter(0, lhs_shape, "input");
+    auto rhs_param = builder.Parameter(1, rhs_shape, "filter");
+    auto result = builder.Conv(lhs_param, rhs_param, {1, 1}, Padding::kValid);
+    // Argument values fed to the two parameters.
+    const Array2D<T> lhs_raster({
+        {1.0f, 2.0f},
+    });
+    Array4D<T> lhs_values(1, 1, 1, 2);
+    lhs_values.FillWithYX(lhs_raster);
+    const Array2D<T> rhs_raster({
+        {5.0f, 6.0f},
+    });
+    Array4D<T> rhs_values(1, 1, 1, 2);
+    rhs_values.FillWithYX(rhs_raster);
+    ComputeAndCompare(&builder, result,
+                      {std::move(*Literal::CreateFromArray(lhs_values)),
+                       std::move(*Literal::CreateFromArray(rhs_values))},
+                      error_spec_);
+  }
+TYPED_TEST_CASE(Convolve_1x1x1x2_1x1x1x2_Valid, TestTypes);
+TYPED_TEST(Convolve_1x1x1x2_1x1x1x2_Valid, Types) { this->RunTest(); }
// Tests valid padding for 2D convolution in raster space.
-TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) {
- ComputationBuilder builder(client_, TestName());
- Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
- Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
// Typed fixture: convolves a {1, 1, 4, 4} input with a {1, 1, 2, 2} filter
// using Padding::kValid, with T as the element type.
+template <typename T>
+class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
+ public:
+ void RunTest() {  // Builds the computation, then runs it on fixed data.
+ ComputationBuilder builder(client_, TestName());
+ Shape input_shape = MakeShapeWrapper<T>({1, 1, 4, 4});  // helper defined outside this chunk
+ Shape filter_shape = MakeShapeWrapper<T>({1, 1, 2, 2});
+ auto input = builder.Parameter(0, input_shape, "input");  // parameter 0
+ auto filter = builder.Parameter(1, filter_shape, "filter");  // parameter 1
+ auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Array4D<T> input_data(1, 1, 4, 4);
+ input_data.FillWithYX(Array2D<T>({  // 4x4 raster holding 1..16 in row order
+ {1.0f, 2.0f, 3.0f, 4.0f},
+ {5.0f, 6.0f, 7.0f, 8.0f},
+ {9.0f, 10.0f, 11.0f, 12.0f},
+ {13.0f, 14.0f, 15.0f, 16.0f},
+ }));
+ Array4D<T> filter_data(1, 1, 2, 2);
+ filter_data.FillWithYX(Array2D<T>({  // 2x2 kernel
+ {5.0f, 6.0f},
+ {7.0f, 8.0f},
+ }));
+ ComputeAndCompare(&builder, conv,  // argument order matches parameter numbers 0 and 1
+ {std::move(*Literal::CreateFromArray(input_data)),
+ std::move(*Literal::CreateFromArray(filter_data))},
+ error_spec_);
+ }
- Array4D<float> input_data(1, 1, 4, 4);
- // clang-format off
- input_data.FillWithYX(Array2D<float>({
- {1, 2, 3, 4 },
- {5, 6, 7, 8 },
- {9, 10, 11, 12},
- {13, 14, 15, 16},
- }));
- // clang-format on
- Array4D<float> filter_data(1, 1, 2, 2);
- // clang-format off
- filter_data.FillWithYX(Array2D<float>({
- {5, 6},
- {7, 8},
- }));
- // clang-format on
- ComputeAndCompare(&builder, conv,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))},
- error_spec_);
// Instantiate the fixture for each type in TestTypes (declared outside this
// chunk) and call RunTest() once per type.
+TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x2x2_Valid, TestTypes);
+TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Valid, Types) { this->RunTest(); }
// Tests same padding for 2D convolution in raster space.
-TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) {
- ComputationBuilder builder(client_, TestName());
- Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
- Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
- Array4D<float> input_data(1, 1, 4, 4);
- // clang-format off
- input_data.FillWithYX(Array2D<float>({
- {1, 2, 3, 4 },
- {5, 6, 7, 8 },
- {9, 10, 11, 12},
- {13, 14, 15, 16},
- }));
- // clang-format on
- Array4D<float> filter_data(1, 1, 2, 2);
- // clang-format off
- filter_data.FillWithYX(Array2D<float>({
- {5, 6},
- {7, 8},
- }));
- // clang-format on
- ComputeAndCompare(&builder, conv,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))},
- error_spec_);
// Typed fixture: convolves a {1, 1, 4, 4} input with a {1, 1, 2, 2} filter
// using Padding::kSame, with T as the element type.
+template <typename T>
+class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
+ public:
+ void RunTest() {  // Builds the computation, then runs it on fixed data.
+ ComputationBuilder builder(client_, TestName());
+ Shape input_shape = MakeShapeWrapper<T>({1, 1, 4, 4});  // helper defined outside this chunk
+ Shape filter_shape = MakeShapeWrapper<T>({1, 1, 2, 2});
+ auto input = builder.Parameter(0, input_shape, "input");  // parameter 0
+ auto filter = builder.Parameter(1, filter_shape, "filter");  // parameter 1
+ auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
+ Array4D<T> input_data(1, 1, 4, 4);
+ input_data.FillWithYX(Array2D<T>({  // 4x4 raster holding 1..16 in row order
+ {1.0f, 2.0f, 3.0f, 4.0f},
+ {5.0f, 6.0f, 7.0f, 8.0f},
+ {9.0f, 10.0f, 11.0f, 12.0f},
+ {13.0f, 14.0f, 15.0f, 16.0f},
+ }));
+ Array4D<T> filter_data(1, 1, 2, 2);
+ filter_data.FillWithYX(Array2D<T>({  // 2x2 kernel
+ {5.0f, 6.0f},
+ {7.0f, 8.0f},
+ }));
+ ComputeAndCompare(&builder, conv,  // argument order matches parameter numbers 0 and 1
+ {std::move(*Literal::CreateFromArray(input_data)),
+ std::move(*Literal::CreateFromArray(filter_data))},
+ error_spec_);
+ }
// Instantiate the fixture for each type in TestTypes (declared outside this
// chunk) and call RunTest() once per type.
+TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x2x2_Same, TestTypes);
+TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Same, Types) { this->RunTest(); }
// Tests same padding for 2D convolution in raster space with an odd sized
// kernel.
-TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) {
- ComputationBuilder builder(client_, TestName());
- Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
- Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
- Array4D<float> input_data(1, 1, 4, 4);
- // clang-format off
- input_data.FillWithYX(Array2D<float>({
- {1, 2, 3, 4 },
- {5, 6, 7, 8 },
- {9, 10, 11, 12},
- {13, 14, 15, 16},
- }));
- // clang-format on
- Array4D<float> filter_data(1, 1, 3, 3);
- // clang-format off
- filter_data.FillWithYX(Array2D<float>({
- { 5, 6, 7},
- { 8, 9, 10},
- {11, 12, 13},
- }));
- // clang-format on
- ComputeAndCompare(&builder, conv,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))},
- error_spec_);
// Typed fixture: convolves a {1, 1, 4, 4} input with an odd-sized
// {1, 1, 3, 3} filter using Padding::kSame, with T as the element type.
+template <typename T>
+class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
+ public:
+ void RunTest() {  // Builds the computation, then runs it on fixed data.
+ ComputationBuilder builder(client_, TestName());
+ Shape input_shape = MakeShapeWrapper<T>({1, 1, 4, 4});  // helper defined outside this chunk
+ Shape filter_shape = MakeShapeWrapper<T>({1, 1, 3, 3});
+ auto input = builder.Parameter(0, input_shape, "input");  // parameter 0
+ auto filter = builder.Parameter(1, filter_shape, "filter");  // parameter 1
+ auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
+ Array4D<T> input_data(1, 1, 4, 4);
+ input_data.FillWithYX(Array2D<T>({{1.0f, 2.0f, 3.0f, 4.0f},  // 4x4 raster holding 1..16
+ {5.0f, 6.0f, 7.0f, 8.0f},
+ {9.0f, 10.0f, 11.0f, 12.0f},
+ {13.0f, 14.0f, 15.0f, 16.0f}}));
+ Array4D<T> filter_data(1, 1, 3, 3);
+ filter_data.FillWithYX(Array2D<T>(  // 3x3 kernel holding 5..13
+ {{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}}));
+ // Pass both arrays as literal arguments, in parameter-number order.
+ ComputeAndCompare(&builder, conv,
+ {std::move(*Literal::CreateFromArray(input_data)),
+ std::move(*Literal::CreateFromArray(filter_data))},
+ error_spec_);
+ }
// Instantiate the fixture for each type in TestTypes (declared outside this
// chunk) and call RunTest() once per type.
+TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x3x3_Same, TestTypes);
+TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same, Types) { this->RunTest(); }
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
ComputationBuilder builder(client_, TestName());
-XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithRHSDilation) {
- ComputationBuilder builder(client_, TestName());
- {
- Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
- Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- // Convolution dimensions are bf0_oi0->bo0.
- builder.ConvGeneralDilated(
- input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
- /*lhs_dilation=*/{1}, /*rhs_dilation=*/{2},
- /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
// Typed fixture: 1D convolution of a {1, 2, 5} input with a {1, 2, 2} filter
// using rhs_dilation = {2} and no padding, checked against a hand-written
// expected result.
+template <typename T>
+class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
+ public:
+ void RunTest() {
+ ComputationBuilder builder(client_, TestName());
+ {  // Scoped so these handles do not clash with the Array3D locals named input/filter below.
+ Shape input_shape = MakeShapeWrapper<T>({1, 2, 5});
+ Shape filter_shape = MakeShapeWrapper<T>({1, 2, 2});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ // Convolution dimensions are bf0_oi0->bo0.
+ builder.ConvGeneralDilated(
+ input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
+ /*lhs_dilation=*/{1}, /*rhs_dilation=*/{2},
+ /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
+ }
+ Array3D<T> input(
+ {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
+ Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
+ Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});  // e.g. 1*10 + 3*20 + 6*30 + 8*40 = 570
+ auto input_literal =
+ client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
+ .ConsumeValueOrDie();
+ ComputeAndCompareR3<T>(&builder, expected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+}; // class Convolve1D_1x2x5_1x2x2_WithRHSDilation
- Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
- Array3D<float> filter({{{10, 20}, {30, 40}}});
- Array3D<float> expected({{{570, 670, 770}}});
- auto input_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
- .ConsumeValueOrDie();
- auto filter_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
- .ConsumeValueOrDie();
- ComputeAndCompareR3<float>(&builder, expected,
- {input_literal.get(), filter_literal.get()},
- error_spec_);
// Instantiate the fixture for each type in TestTypes (declared outside this
// chunk) and call RunTest() once per type.
+TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes);
+TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); }
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
ComputationBuilder builder(client_, TestName());
-XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithPadding) {
- ComputationBuilder builder(client_, TestName());
- {
- Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
- Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- // Convolution dimensions are bf0_oi0->bo0.
- builder.ConvGeneralDilated(
- input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}},
- /*lhs_dilation=*/{1}, /*rhs_dilation=*/{1},
- /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
// Typed fixture: 1D convolution of a {1, 2, 5} input with a {1, 2, 2} filter
// using padding {2, 2} and no dilation, checked against a hand-written
// expected result.
+template <typename T>
+class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
+ public:
+ void RunTest() {
+ ComputationBuilder builder(client_, TestName());
+ {  // Scoped so these handles do not clash with the Array3D locals named input/filter below.
+ Shape input_shape = MakeShapeWrapper<T>({1, 2, 5});
+ Shape filter_shape = MakeShapeWrapper<T>({1, 2, 2});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ // Convolution dimensions are bf0_oi0->bo0.
+ builder.ConvGeneralDilated(
+ input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}},
+ /*lhs_dilation=*/{1}, /*rhs_dilation=*/{1},
+ /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
+ }
+ Array3D<T> input(
+ {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
+ Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
+ Array3D<T> expected(  // 8 outputs; the first and last are 0 because those windows lie entirely in the padding
+ {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
+ auto input_literal =
+ client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
+ .ConsumeValueOrDie();
+ ComputeAndCompareR3<T>(&builder, expected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
- Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
- Array3D<float> filter({{{10, 20}, {30, 40}}});
- Array3D<float> expected({{{0, 260, 510, 610, 710, 810, 350, 0}}});
- auto input_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
- .ConsumeValueOrDie();
- auto filter_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
- .ConsumeValueOrDie();
- ComputeAndCompareR3<float>(&builder, expected,
- {input_literal.get(), filter_literal.get()},
- error_spec_);
// Instantiate the fixture for each type in TestTypes (declared outside this
// chunk) and call RunTest() once per type.
+TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithPadding, TestTypes);
+TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding, Types) { this->RunTest(); }
XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
ComputationBuilder builder(client_, TestName());
std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
- std::iota(input_elems.begin(), input_elems.end(), 1.0f);
+ iota(input_elems.begin(), input_elems.end(), 1.0f);
auto input_r1 = Literal::CreateR1<float>(input_elems);
auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
- std::iota(filter_elems.begin(), filter_elems.end(), 1.0f);
+ iota(filter_elems.begin(), filter_elems.end(), 1.0f);
auto filter_r1 = Literal::CreateR1<float>(filter_elems);
auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
-XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) {
- ComputationBuilder builder(client_, TestName());
- std::vector<int64> input_dims = {1, 3, 3, 5};
- std::vector<int64> filter_dims = {3, 3, 5, 3};
- Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);
- Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims);
- {
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- // Tensorflow dimension numbers for 2D convolution.
- ConvolutionDimensionNumbers dnums;
- dnums.set_input_batch_dimension(0);
- dnums.set_output_batch_dimension(0);
- dnums.add_input_spatial_dimensions(1);
- dnums.add_output_spatial_dimensions(1);
- dnums.add_input_spatial_dimensions(2);
- dnums.add_output_spatial_dimensions(2);
- dnums.set_input_feature_dimension(3);
- dnums.set_output_feature_dimension(3);
- dnums.add_kernel_spatial_dimensions(0);
- dnums.add_kernel_spatial_dimensions(1);
- dnums.set_kernel_input_feature_dimension(2);
- dnums.set_kernel_output_feature_dimension(3);
+// std::iota does not compile for Eigen::half on some build servers (the error
+// reports a missing operator++), so count with an int and cast each value to T.
// Fills `values` in order with init_value, init_value + 1, ... converted to T.
// `init_value` is taken by value, so the caller's variable is not modified.
+template <typename T>
+void iota_int_init_value(std::vector<T>& values, int init_value) {
+ std::for_each(values.begin(), values.end(),
+ [&](T& value) { value = static_cast<T>(init_value++); });  // post-increment: the first element gets init_value itself
- builder.ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid,
- dnums);
// Typed fixture: 2D convolution with explicitly specified (TensorFlow-style)
// dimension numbers and Padding::kValid. Input and filter are filled with
// 1, 2, 3, ... and the result is compared with a hard-coded {1, 1, 1, 3}
// literal.
// NOTE(review): the fixture name says 3x3x5x5 but filter_dims below is
// {3, 3, 5, 3} -- confirm which is intended.
+template <typename T>
+class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest {
+ public:
+ void RunTest() {
+ ComputationBuilder builder(client_, TestName());
+ std::vector<int64> input_dims = {1, 3, 3, 5};
+ std::vector<int64> filter_dims = {3, 3, 5, 3};
+ Shape input_shape = MakeShapeWrapper<T>(input_dims);
+ Shape filter_shape = MakeShapeWrapper<T>(filter_dims);
+ {  // Scope for the builder handles and the dimension-number setup.
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ // Tensorflow dimension numbers for 2D convolution.
+ ConvolutionDimensionNumbers dnums;
+ dnums.set_input_batch_dimension(0);  // input/output: batch = 0, spatial = 1 and 2, feature = 3
+ dnums.set_output_batch_dimension(0);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
+ dnums.set_input_feature_dimension(3);
+ dnums.set_output_feature_dimension(3);
+ dnums.add_kernel_spatial_dimensions(0);  // kernel: spatial = 0 and 1, input feature = 2, output feature = 3
+ dnums.add_kernel_spatial_dimensions(1);
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(3);
+ builder.ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid,
+ dnums);
+ }
+ std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
+ iota_int_init_value(input_elems, 1);  // 1, 2, 3, ... in flat order
+ auto input_r1 = Literal::CreateR1<T>(input_elems);
+ auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
+ iota_int_init_value(filter_elems, 1);  // 1, 2, 3, ... in flat order
+ auto filter_r1 = Literal::CreateR1<T>(filter_elems);
+ auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ auto expected_r1 = Literal::CreateR1<T>(  // NOTE(review): these exceed 65504, the largest finite IEEE half; with T = Eigen::half the casts presumably yield +inf -- confirm
+ {static_cast<T>(92115), static_cast<T>(93150), static_cast<T>(94185)});
+ auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
+ auto input_literal =
+ client_->TransferToServer(*input_r4).ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
+ ComputeAndCompareLiteral(&builder, *expected_r4,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
- std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
- std::iota(input_elems.begin(), input_elems.end(), 1.0f);
- auto input_r1 = Literal::CreateR1<float>(input_elems);
- auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
- std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
- std::iota(filter_elems.begin(), filter_elems.end(), 1.0f);
- auto filter_r1 = Literal::CreateR1<float>(filter_elems);
- auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
- auto expected_r1 = Literal::CreateR1<float>({92115, 93150, 94185});
- auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
- auto input_literal = client_->TransferToServer(*input_r4).ConsumeValueOrDie();
- auto filter_literal =
- client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r4,
- {input_literal.get(), filter_literal.get()},
- error_spec_);
// Instantiate the fixture for each type in TestTypes (declared outside this
// chunk) and call RunTest() once per type.
+TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x5x5_Valid, TestTypes);
+TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x5_Valid, Types) { this->RunTest(); }
// Test fixture to run convolution tests with and without convolution
// canonicalization enabled.
int64 num_windows;
-class Convolve1D1WindowTest
+class Convolve1D1WindowTestBase
: public ConvolutionTest,
- public ::testing::WithParamInterface<Convolve1DTestParam> {};
-XLA_TEST_P(Convolve1D1WindowTest, Convolve1D1Window) {
- ComputationBuilder builder(client_, TestName());
- int64 input_feature = GetParam().input_feature;
- int64 output_feature = GetParam().output_feature;
- int64 batch = GetParam().batch;
- int64 num_windows = GetParam().num_windows;
- int64 window_size = GetParam().window_size;
- std::vector<int64> input_dims = {batch, window_size + num_windows - 1,
- input_feature};
- std::vector<int64> filter_dims = {window_size, input_feature, output_feature};
- Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);
- Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims);
- {
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- // Tensorflow dimension numbers for 1D convolution.
- ConvolutionDimensionNumbers dnums;
- dnums.set_input_batch_dimension(0);
- dnums.set_output_batch_dimension(0);
- dnums.add_input_spatial_dimensions(1);
- dnums.add_output_spatial_dimensions(1);
- dnums.set_input_feature_dimension(2);
- dnums.set_output_feature_dimension(2);
- dnums.add_kernel_spatial_dimensions(0);
- dnums.set_kernel_input_feature_dimension(1);
- dnums.set_kernel_output_feature_dimension(2);
- builder.ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid,
- dnums);
+ public ::testing::WithParamInterface<Convolve1DTestParam> {
+ protected:
// Shared body for the parameterized 1D-window tests, templated on element
// type T. Sizes come from GetParam(). Input and filter are filled with ones,
// so every output element is expected to equal window_size * input_feature.
+ template <typename T>
+ void TestImpl() {
+ ComputationBuilder builder(client_, TestName());
+ int64 input_feature = GetParam().input_feature;
+ int64 output_feature = GetParam().output_feature;
+ int64 batch = GetParam().batch;
+ int64 num_windows = GetParam().num_windows;
+ int64 window_size = GetParam().window_size;
+ std::vector<int64> input_dims = {batch, window_size + num_windows - 1,  // spatial length sized to give num_windows outputs (see the expected shape below)
+ input_feature};
+ std::vector<int64> filter_dims = {window_size, input_feature,
+ output_feature};
+ Shape input_shape = MakeShapeWrapper<T>(input_dims);
+ Shape filter_shape = MakeShapeWrapper<T>(filter_dims);
+ {  // Scope for the builder handles and the dimension-number setup.
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ // Tensorflow dimension numbers for 1D convolution.
+ ConvolutionDimensionNumbers dnums;
+ dnums.set_input_batch_dimension(0);  // input/output: batch = 0, spatial = 1, feature = 2
+ dnums.set_output_batch_dimension(0);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
+ dnums.set_input_feature_dimension(2);
+ dnums.set_output_feature_dimension(2);
+ dnums.add_kernel_spatial_dimensions(0);  // kernel: spatial = 0, input feature = 1, output feature = 2
+ dnums.set_kernel_input_feature_dimension(1);
+ dnums.set_kernel_output_feature_dimension(2);
+ builder.ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid,
+ dnums);
+ }
+ std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),  // all ones
+ static_cast<T>(1.0f));
+ auto input_r1 = Literal::CreateR1<T>(input_elems);
+ auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+ std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),  // all ones
+ static_cast<T>(1.0f));
+ auto filter_r1 = Literal::CreateR1<T>(filter_elems);
+ auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ std::vector<T> expect_elems(batch * output_feature * num_windows,  // one element per output position, each window_size * input_feature
+ static_cast<T>(window_size * input_feature));
+ auto expected_r1 = Literal::CreateR1<T>(expect_elems);
+ auto expected_r3 =
+ expected_r1->Reshape({batch, num_windows, output_feature})
+ .ConsumeValueOrDie();
+ auto input_literal =
+ client_->TransferToServer(*input_r3).ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*filter_r3).ConsumeValueOrDie();
+ ComputeAndCompareLiteral(&builder, *expected_r3,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
- std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape), 1.0);
- auto input_r1 = Literal::CreateR1<float>(input_elems);
- auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+class Convolve1D1WindowTestFloat : public Convolve1D1WindowTestBase {};
- std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape), 1.0);
+XLA_TEST_P(Convolve1D1WindowTestFloat, Convolve1D1Window) { TestImpl<float>(); }
- auto filter_r1 = Literal::CreateR1<float>(filter_elems);
- auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+ Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestFloat,
+ ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
+ Convolve1DTestParam{160, 1, 1, 5, 1},
+ Convolve1DTestParam{24, 1, 1, 20, 1},
+ Convolve1DTestParam{30, 1, 1, 20, 1},
+ Convolve1DTestParam{23, 1, 1, 20, 20},
+ Convolve1DTestParam{25, 1, 1, 20, 1},
+ Convolve1DTestParam{24, 1, 1, 10, 5},
+ Convolve1DTestParam{160, 1, 1, 10, 1},
+ Convolve1DTestParam{255, 1, 1, 3, 1},
+ Convolve1DTestParam{130, 1, 1, 1, 3},
+ Convolve1DTestParam{64, 1, 1, 1, 1},
+ Convolve1DTestParam{128, 1, 1, 1, 1},
+ Convolve1DTestParam{139, 1, 1, 128, 1},
+ Convolve1DTestParam{1, 10, 10, 1, 10},
+ Convolve1DTestParam{1, 10, 130, 1, 2},
+ Convolve1DTestParam{1, 10, 130, 1, 1},
+ Convolve1DTestParam{1, 64, 64, 1, 10},
+ Convolve1DTestParam{1, 65, 65, 1, 1},
+ Convolve1DTestParam{1, 128, 128, 1, 1},
+ Convolve1DTestParam{128, 128, 128, 128, 1},
+ Convolve1DTestParam{1, 128, 128, 1, 1},
+ Convolve1DTestParam{2, 2, 2, 2, 1},
+ Convolve1DTestParam{161, 1, 1, 10, 1},
+ Convolve1DTestParam{900, 1, 1, 10, 1},
+ Convolve1DTestParam{640, 3, 3, 128, 1})
- std::vector<float> expect_elems(batch * output_feature * num_windows,
- window_size * input_feature);
- auto expected_r1 = Literal::CreateR1<float>(expect_elems);
- auto expected_r3 = expected_r1->Reshape({batch, num_windows, output_feature})
- .ConsumeValueOrDie();
- auto input_literal = client_->TransferToServer(*input_r3).ConsumeValueOrDie();
- auto filter_literal =
- client_->TransferToServer(*filter_r3).ConsumeValueOrDie();
- ComputeAndCompareLiteral(&builder, *expected_r3,
- {input_literal.get(), filter_literal.get()},
- error_spec_);
+class Convolve1D1WindowTestHalf : public Convolve1D1WindowTestBase {};
+// TODO(b/72509305): Enable half data type tests for CPU.
+ TestImpl<Eigen::half>();
- Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTest,
+ Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestHalf,
::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
Convolve1DTestParam{160, 1, 1, 5, 1},
Convolve1DTestParam{24, 1, 1, 20, 1},
Convolve1DTestParam{130, 1, 1, 1, 3},
Convolve1DTestParam{64, 1, 1, 1, 1},
Convolve1DTestParam{128, 1, 1, 1, 1},
+ // TODO(b/72566306): the following three tests fail on CPU
+ // backend due to result miscompare.
Convolve1DTestParam{139, 1, 1, 128, 1},
+ Convolve1DTestParam{640, 3, 3, 128, 1},
+ Convolve1DTestParam{900, 1, 1, 10, 1},
Convolve1DTestParam{1, 10, 10, 1, 10},
Convolve1DTestParam{1, 10, 130, 1, 2},
Convolve1DTestParam{1, 10, 130, 1, 1},
Convolve1DTestParam{128, 128, 128, 128, 1},
Convolve1DTestParam{1, 128, 128, 1, 1},
Convolve1DTestParam{2, 2, 2, 2, 1},
- Convolve1DTestParam{161, 1, 1, 10, 1},
- Convolve1DTestParam{900, 1, 1, 10, 1},
- Convolve1DTestParam{640, 3, 3, 128, 1})
+ Convolve1DTestParam{161, 1, 1, 10, 1})
TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
ComputationBuilder builder(client_, TestName());
#define XLA_TEST_P(test_case_name, test_name) \
XLA_TEST_P_IMPL_(test_case_name, test_name)
+// This is identical to the TYPED_TEST macro from "gtest", but it potentially
+// disables the test based on an external manifest file, DISABLED_MANIFEST.
// The registered test name is the string returned by
// ::xla::PrependDisabledIfIndicated(#CaseName, #TestName).
+#define XLA_TYPED_TEST(CaseName, TestName) \
+ template <typename gtest_TypeParam_> \
+ class GTEST_TEST_CLASS_NAME_(CaseName, TestName) \
+ : public CaseName<gtest_TypeParam_> { \
+ private: \
+ typedef CaseName<gtest_TypeParam_> TestFixture; \
+ typedef gtest_TypeParam_ TypeParam; \
+ virtual void TestBody(); \
+ }; \
+ bool gtest_##CaseName##_##TestName##_registered_ GTEST_ATTRIBUTE_UNUSED_ = \
+ ::testing::internal::TypeParameterizedTest< \
+ CaseName, \
+ ::testing::internal::TemplateSel<GTEST_TEST_CLASS_NAME_(CaseName, \
+ TestName)>, \
+ GTEST_TYPE_PARAMS_(CaseName)>:: \
+ Register( \
+ "", ::testing::internal::CodeLocation(__FILE__, __LINE__), \
+ #CaseName, \
+ ::xla::PrependDisabledIfIndicated(#CaseName, #TestName).c_str(), \
+ 0); \
+ template <typename gtest_TypeParam_> \
+ void GTEST_TEST_CLASS_NAME_(CaseName, \
+ TestName)<gtest_TypeParam_>::TestBody()