From 3b36867affdf446d69ab68639152026a8f164084 Mon Sep 17 00:00:00 2001 From: Tayo Oguntebi Date: Tue, 13 Mar 2018 19:00:06 -0700 Subject: [PATCH] Adds R2 ReduceWindow test and generalizes the R2 test suite, particularly to allow arbitary padding. PiperOrigin-RevId: 188966155 --- tensorflow/compiler/xla/reference_util.cc | 44 +++++---- tensorflow/compiler/xla/reference_util.h | 7 +- .../compiler/xla/tests/reduce_window_test.cc | 108 +++++++++++++-------- 3 files changed, 98 insertions(+), 61 deletions(-) diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 8711b8a..b894880 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -189,18 +189,6 @@ ReferenceUtil::ReduceWindow1DGeneric( const tensorflow::gtl::ArraySlice& operand, float init, const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { - std::vector dim_lengths{static_cast(operand.size())}; - return ReduceWindow1DGeneric( - operand, init, reduce_func, window, stride, - xla::MakePadding(dim_lengths, window, stride, padding)); -} - -/* static */ std::unique_ptr> -ReferenceUtil::ReduceWindow1DGeneric( - const tensorflow::gtl::ArraySlice& operand, float init, - const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, const tensorflow::gtl::ArraySlice>& padding) { std::vector dim_lengths{static_cast(operand.size())}; @@ -235,23 +223,28 @@ ReferenceUtil::ReduceWindow1DAdd( const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; - return ReduceWindow1DGeneric(operand, init, add_reduce, window, stride, - padding); + std::vector dim_lengths{static_cast(operand.size())}; + return ReduceWindow1DGeneric( + operand, init, add_reduce, window, stride, + xla::MakePadding(dim_lengths, window, stride, padding)); } -/* static */ std::unique_ptr> ReferenceUtil::ReduceWindow2DAdd( +/* static */ std::unique_ptr> +ReferenceUtil::ReduceWindow2DGeneric( const Array2D& operand, float init, + const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding) { std::vector dim_lengths{operand.height(), operand.width()}; - auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); std::vector window_counts(window.size(), 0); std::vector pad_low(window.size(), 0); for (int64 i = 0; i < window.size(); ++i) { + int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; window_counts[i] = - WindowCount(dim_lengths[i], window[i], stride[i], padding); - pad_low[i] = padding_both[i].first; + window_util::StridedBound(padded_width, window[i], stride[i]); + pad_low[i] = padding[i].first; } auto result = MakeUnique>(window_counts[0], window_counts[1]); @@ -267,7 +260,7 @@ ReferenceUtil::ReduceWindow1DAdd( if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && i0_base + i0_win < operand.n1() && i1_base + i1_win < operand.n2()) { - val += operand(i0_base + i0_win, i1_base + i1_win); + val = reduce_func(val, operand(i0_base + i0_win, i1_base + i1_win)); } } } @@ -277,6 +270,17 @@ ReferenceUtil::ReduceWindow1DAdd( return result; } +/* static */ std::unique_ptr> ReferenceUtil::ReduceWindow2DAdd( + const Array2D& operand, float init, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; + std::vector dim_lengths{operand.height(), operand.width()}; + return ReduceWindow2DGeneric( + operand, init, add_reduce, window, stride, + xla::MakePadding(dim_lengths, window, stride, padding)); +} + /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow3DAdd( const Array3D& operand, float init, const tensorflow::gtl::ArraySlice& window, diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 57b0218..c3b0406 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -199,9 +199,10 @@ class ReferenceUtil { const tensorflow::gtl::ArraySlice& operand, float init, const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); - static std::unique_ptr> ReduceWindow1DGeneric( - const tensorflow::gtl::ArraySlice& operand, float init, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding); + static std::unique_ptr> ReduceWindow2DGeneric( + const Array2D& operand, float init, const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 6f3b8ea..8b736f6 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -960,49 +960,67 @@ struct R2ReduceWindowTestData { int64 base_bounds[2]; int64 window_bounds[2]; int64 strides[2]; + int64 pad_low[2]; + int64 pad_high[2]; int64 layout[2]; - Padding padding; Reducer reducer; } kR2TestCases[] = { {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4}, - /*strides=*/{1, 2}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 2}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{2, 5}, /*window_bounds=*/{2, 4}, - /*strides=*/{1, 1}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 2}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{1, 3}, /*window_bounds=*/{2, 3}, - /*strides=*/{1, 1}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3, 129}, /*window_bounds=*/{1, 100}, - /*strides=*/{2, 99}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{2, 99}, /*pad_low=*/{0, 0}, /*pad_high=*/{35, 35}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, // TODO(b/74260408): This test last failed on GPU on 2018-03-08, likely due to a // ptxas bug. #ifndef XLA_TEST_BACKEND_GPU {/*base_bounds=*/{6, 152}, /*window_bounds=*/{2, 25}, - /*strides=*/{5, 4}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{5, 4}, /*pad_low=*/{0, 1}, /*pad_high=*/{10, 11}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, #endif {/*base_bounds=*/{6, 4}, /*window_bounds=*/{4, 2}, - /*strides=*/{3, 3}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{3, 3}, /*pad_low=*/{0, 1}, /*pad_high=*/{0, 1}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{5, 147}, /*window_bounds=*/{1, 36}, - /*strides=*/{4, 5}, /*layout=*/{1, 0}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{4, 5}, /*pad_low=*/{0, 0}, /*pad_high=*/{17, 17}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{4, 153}, /*window_bounds=*/{2, 93}, - /*strides=*/{1, 1}, /*layout=*/{1, 0}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{46, 46}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, // Regression test for a bug that appeared in Inception (b/34784899). {/*base_bounds=*/{28, 28}, /*window_bounds=*/{3, 3}, - /*strides=*/{1, 1}, /*layout=*/{1, 0}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{1, 1}, /*pad_high=*/{1, 1}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, // Regression test for a bug that appeared in Inception (b/34784899). {/*base_bounds=*/{4, 32}, /*window_bounds=*/{2, 2}, - /*strides=*/{2, 2}, /*layout=*/{1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, - {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2}, - /*strides=*/{1, 1}, /*layout=*/{1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*strides=*/{2, 2}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, + // Regression test for b/73903312: bf16 lacks precision to store result of + // very large windows. Testing with a reasonable window larger than 128. + {/*base_bounds=*/{8, 130}, /*window_bounds=*/{1, 130}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 130}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, }; string R2ReduceWindowTestDataToString( @@ -1012,10 +1030,11 @@ string R2ReduceWindowTestDataToString( string str = tensorflow::strings::StrCat( "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // "__window_bounds_", - tensorflow::str_util::Join(param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(param.strides, "x"), // - "__padding_", param.padding == Padding::kSame ? "same" : "valid", // - "__layout_", param.layout[0], "_", param.layout[1], // + tensorflow::str_util::Join(param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(param.strides, "x"), // + "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), + "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), + "__layout_", param.layout[0], "_", param.layout[1], // "__reducer_", param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { str = tensorflow::strings::StrCat(str, "_bfloat16"); @@ -1043,17 +1062,29 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, ComputationDataHandle parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); + std::vector> padding(2); + for (int i = 0; i < 2; ++i) { + padding[i] = {param.pad_low[i], param.pad_high[i]}; + } + auto computation = param.reducer == kAdd + ? CreateScalarAddComputation(FloatType(), &b) + : CreateScalarMaxComputation(FloatType(), &b); auto init_value = CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); - b.ReduceWindow(/*operand=*/parameter, - /*init_value=*/init_value, - /*computation=*/CreateScalarAddComputation(FloatType(), &b), - /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/param.padding); + b.ReduceWindowWithGeneralPadding( + /*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/computation, + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/padding); - auto expected = ReferenceUtil::ReduceWindow2DAdd( - /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, - /*stride=*/param.strides, /*padding=*/param.padding); + auto reduce_func = param.reducer == kAdd + ? +[](float a, float b) { return a + b; } + : +[](float a, float b) { return std::max(a, b); }; + auto expected = ReferenceUtil::ReduceWindow2DGeneric( + /*operand=*/input, /*init=*/kInitValue, /*reduce_func=*/reduce_func, + /*window=*/param.window_bounds, + /*stride=*/param.strides, /*padding=*/padding); ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), {input_arg.get()}, DefaultErrorSpec()); @@ -1078,8 +1109,9 @@ XLA_TEST_P(R2ReduceWindowFailingCpuGpuBf16Test, const R2ReduceWindowTestData kR2FailingValuesCpuGpuBf16Test[] = { {/*base_bounds=*/{8, 128}, /*window_bounds=*/{8, 128}, - /*strides=*/{1, 1}, /*layout=*/{1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, }; INSTANTIATE_TEST_CASE_P( -- 2.7.4