[XLA] Redesign: implement and test ReduceWindow.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 10 Apr 2018 23:32:05 +0000 (16:32 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 10 Apr 2018 23:36:42 +0000 (16:36 -0700)
PiperOrigin-RevId: 192368401

tensorflow/compiler/xla/client/xla_client/xla_builder.cc
tensorflow/compiler/xla/tests/BUILD
tensorflow/compiler/xla/tests/reduce_window_test.cc

index 9e4b9cc..c869eb2 100644 (file)
@@ -1425,7 +1425,21 @@ XlaOp XlaBuilder::ReduceWindow(
     const XlaComputation& computation,
     tensorflow::gtl::ArraySlice<int64> window_dimensions,
     tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
-  return UnimplementedOp();
+  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+
+    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+    TF_RETURN_IF_ERROR(
+        ValidatePaddingValues(AsInt64Slice(operand_shape.dimensions()),
+                              window_dimensions, window_strides));
+
+    std::vector<std::pair<int64, int64>> padding_values =
+        MakePadding(AsInt64Slice(operand_shape.dimensions()), window_dimensions,
+                    window_strides, padding);
+    return ReduceWindowWithGeneralPadding(operand, init_value, computation,
+                                          window_dimensions, window_strides,
+                                          padding_values);
+  });
 }
 
 XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
@@ -1434,7 +1448,25 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
     tensorflow::gtl::ArraySlice<int64> window_dimensions,
     tensorflow::gtl::ArraySlice<int64> window_strides,
     tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
-  return UnimplementedOp();
+  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+
+    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+    TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value));
+    TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
+                        computation.GetProgramShape());
+    TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
+                        MakeWindow(window_dimensions, window_strides, padding,
+                                   /*lhs_dilation=*/{}, /*rhs_dilation=*/{}));
+    TF_ASSIGN_OR_RETURN(
+        *instr.mutable_shape(),
+        ShapeInference::InferReduceWindowShape(operand_shape, init_shape,
+                                               instr.window(), to_apply_shape));
+
+    AddCalledComputation(computation, &instr);
+    return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
+                          {operand, init_value});
+  });
 }
 
 XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
index 67c53c6..a615acd 100644 (file)
@@ -1091,10 +1091,11 @@ xla_test_library(
         "//tensorflow/compiler/xla:reference_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:xla_data_proto",
-        "//tensorflow/compiler/xla/client:computation_builder",
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/client:padding",
         "//tensorflow/compiler/xla/client/lib:arithmetic",
+        "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+        "//tensorflow/compiler/xla/client/xla_client:xla_computation",
         "//tensorflow/compiler/xla/tests:client_library_test_base",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:literal_test_util",
index 8dd24f1..8ef980e 100644 (file)
@@ -21,10 +21,11 @@ limitations under the License.
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/array3d.h"
 #include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 #include "tensorflow/compiler/xla/reference_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -63,11 +64,9 @@ class ReduceWindowTestBase : public ClientLibraryTestBase {
 class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
                          public ReduceWindowTestBase {
  public:
-  ReduceWindowTest() : builder_(client_, TestName()) {
-    set_use_bfloat16(GetParam());
-  }
+  ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); }
 
-  void ReduceWindowAdd(const ComputationDataHandle& input,
+  void ReduceWindowAdd(const XlaOp& input,
                        tensorflow::gtl::ArraySlice<int64> window_dimensions,
                        tensorflow::gtl::ArraySlice<int64> window_strides,
                        Padding padding) {
@@ -78,16 +77,17 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
                           window_dimensions, window_strides, padding);
   }
 
-  void ReduceWindowMax(const ComputationDataHandle& input,
+  void ReduceWindowMax(const XlaOp& input,
                        tensorflow::gtl::ArraySlice<int64> window_dimensions,
                        tensorflow::gtl::ArraySlice<int64> window_strides,
                        Padding padding) {
     auto init = CreateConstantFromLiteral(Literal::MinValue(F32), &builder_);
-    builder_.ReduceWindow(input, init, CreateScalarMax(), window_dimensions,
-                          window_strides, padding);
+    builder_.ReduceWindow(input, init,
+                          CreateScalarMaxComputation(FloatType(), &builder_),
+                          window_dimensions, window_strides, padding);
   }
 
-  void ReduceWindowMin(const ComputationDataHandle& input,
+  void ReduceWindowMin(const XlaOp& input,
                        tensorflow::gtl::ArraySlice<int64> window_dimensions,
                        tensorflow::gtl::ArraySlice<int64> window_strides,
                        Padding padding) {
@@ -97,7 +97,7 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
                           window_dimensions, window_strides, padding);
   }
 
-  ComputationBuilder builder_;
+  XlaBuilder builder_;
 };
 
 TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
@@ -310,7 +310,7 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
   auto rhs = b->Parameter(1, scalar, "rhs");
   b->Min(b->Add(lhs, rhs),
          CreateConstantFromLiteral(*Literal::CreateR0<float>(8.0f), b.get()));
-  Computation reduce_fn = b->BuildAndNoteError();
+  XlaComputation reduce_fn = b->BuildAndNoteError();
 
   builder_.ReduceWindow(
       input,
@@ -338,7 +338,7 @@ TEST_P(ReduceWindowTest, R4UnitWindow) {
   std::unique_ptr<Literal> input_literal =
       Literal::CreateR4FromArray4DWithLayout(
           input_array, LayoutUtil::MakeLayout({0, 3, 2, 1}));
-  ComputationDataHandle input;
+  XlaOp input;
   auto input_data = CreateParameterAndTransferLiteral(
       0, *input_literal, "parameter", &builder_, &input);
 
@@ -406,7 +406,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
   std::unique_ptr<Literal> input_literal =
       Literal::CreateR4FromArray4DWithLayout(
           input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
-  ComputationDataHandle input;
+  XlaOp input;
   auto input_data = CreateParameterAndTransferLiteral(
       0, *input_literal, "parameter", &builder_, &input);
 
@@ -428,7 +428,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
   std::unique_ptr<Literal> input_literal =
       Literal::CreateR4FromArray4DWithLayout(
           input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
-  ComputationDataHandle input;
+  XlaOp input;
   auto input_data = CreateParameterAndTransferLiteral(
       0, *input_literal, "parameter", &builder_, &input);
 
@@ -450,7 +450,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
   std::unique_ptr<Literal> input_literal =
       Literal::CreateR4FromArray4DWithLayout(
           input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
-  ComputationDataHandle input;
+  XlaOp input;
   auto input_data = CreateParameterAndTransferLiteral(
       0, *input_literal, "parameter", &builder_, &input);
 
@@ -551,7 +551,7 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) {
 
 TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
   Array2D<float> input_array(6, 4, 1.0f);
-  ComputationDataHandle input = builder_.Broadcast(
+  XlaOp input = builder_.Broadcast(
       CreateConstantFromLiteral(Literal::One(F32), &builder_), {6, 4});
 
   Padding padding = Padding::kSame;
@@ -610,7 +610,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
   R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
 
   void DoIt() {
-    ComputationBuilder b(client_, TestName());
+    XlaBuilder b(TestName());
     const auto& param = ::testing::get<0>(GetParam());
 
     const float kInitValue = 0.0f;
@@ -621,7 +621,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
     std::unique_ptr<Literal> input_literal =
         Literal::CreateR4FromArray4DWithLayout(
             input, LayoutUtil::MakeLayout(param.layout));
-    ComputationDataHandle parameter;
+    XlaOp parameter;
     auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
                                                        &b, &parameter);
 
@@ -962,7 +962,7 @@ class R3ReduceWindowTest : public ReduceWindowTestBase,
 };
 
 TEST_P(R3ReduceWindowTest, Add) {
-  ComputationBuilder b(client_, TestName());
+  XlaBuilder b(TestName());
   const auto& param = ::testing::get<0>(GetParam());
   CHECK(param.reducer == kAdd);
 
@@ -973,7 +973,7 @@ TEST_P(R3ReduceWindowTest, Add) {
       Literal::CreateR3FromArray3DWithLayout(
           input, LayoutUtil::MakeLayout(param.layout));
 
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
                                                      &b, &parameter);
   auto init_value =
@@ -1100,7 +1100,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
   R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
 
   void DoIt() {
-    ComputationBuilder b(client_, TestName());
+    XlaBuilder b(TestName());
     const auto& param = ::testing::get<0>(GetParam());
     CHECK(param.reducer == kAdd);
 
@@ -1110,7 +1110,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
         Literal::CreateR2FromArray2DWithLayout(
             input, LayoutUtil::MakeLayout(param.layout));
 
-    ComputationDataHandle parameter;
+    XlaOp parameter;
     auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
                                                        &b, &parameter);
     std::vector<std::pair<int64, int64>> padding(2);
@@ -1298,7 +1298,7 @@ class R1ReduceWindowTest : public ReduceWindowTestBase,
 };
 
 TEST_P(R1ReduceWindowTest, DoIt) {
-  ComputationBuilder b(client_, TestName());
+  XlaBuilder b(TestName());
   const auto& param = ::testing::get<0>(GetParam());
   CHECK(param.reducer == kAdd || param.reducer == kMax);
 
@@ -1307,7 +1307,7 @@ TEST_P(R1ReduceWindowTest, DoIt) {
   std::iota(std::begin(input_vector), std::end(input_vector), 0);
   std::unique_ptr<Literal> input_literal =
       Literal::CreateR1(tensorflow::gtl::ArraySlice<float>(input_vector));
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
                                                      &b, &parameter);