[XLA] Redesign: implement Collapse and migrate reshape_test.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 7 Apr 2018 00:49:47 +0000 (17:49 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 7 Apr 2018 01:21:31 +0000 (18:21 -0700)
PiperOrigin-RevId: 191965245

tensorflow/compiler/xla/client/xla_client/xla_builder.cc
tensorflow/compiler/xla/tests/BUILD
tensorflow/compiler/xla/tests/reshape_test.cc

index e623639..3d0cb35 100644 (file)
@@ -593,7 +593,45 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
 
 XlaOp XlaBuilder::Collapse(const XlaOp& operand,
                            tensorflow::gtl::ArraySlice<int64> dimensions) {
-  return UnimplementedOp();
+  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    if (dimensions.size() <= 1) {
+      // Not collapsing anything, trivially we can return the operand versus
+      // enqueueing a trivial reshape.
+      return operand;
+    }
+
+    // Out-of-order collapse is not supported.
+    // Checks that the collapsed dimensions are in order and consecutive.
+    for (tensorflow::gtl::ArraySlice<int64>::size_type i = 1;
+         i < dimensions.size(); ++i) {
+      if (dimensions[i] - 1 != dimensions[i - 1]) {
+        return InvalidArgument(
+            "Collapsed dimensions are not in consecutive order.");
+      }
+    }
+
+    // Create a new sizes vector from the old shape, replacing the collapsed
+    // dimensions by the product of their sizes.
+    TF_ASSIGN_OR_RETURN(const Shape& original_shape, GetShape(operand));
+
+    VLOG(3) << "original shape: " << ShapeUtil::HumanString(original_shape);
+    VLOG(3) << "dims to collapse: "
+            << tensorflow::str_util::Join(dimensions, ",");
+
+    std::vector<int64> new_sizes;
+    for (int i = 0; i < ShapeUtil::Rank(original_shape); ++i) {
+      if (i <= dimensions.front() || i > dimensions.back()) {
+        new_sizes.push_back(original_shape.dimensions(i));
+      } else {
+        new_sizes.back() *= original_shape.dimensions(i);
+      }
+    }
+
+    VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",")
+            << "]";
+
+    return Reshape(operand, new_sizes);
+  });
 }
 
 void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
index 072c5cd..0276db9 100644 (file)
@@ -1372,11 +1372,10 @@ xla_test(
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:test_helpers",
         "//tensorflow/compiler/xla:xla_data_proto",
-        "//tensorflow/compiler/xla/client:computation",
-        "//tensorflow/compiler/xla/client:computation_builder",
         "//tensorflow/compiler/xla/client:global_data",
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+        "//tensorflow/compiler/xla/client/xla_client:xla_computation",
         "//tensorflow/compiler/xla/tests:client_library_test_base",
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
index 02272d6..d7462d5 100644 (file)
@@ -20,11 +20,10 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
 #include "tensorflow/compiler/xla/client/global_data.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/reference_util.h"
@@ -53,11 +52,11 @@ class ReshapeTest : public ::testing::WithParamInterface<bool>,
 
 // Collapses 2-dimensional pseudo-scalar (single-element array) to 1 dimension.
 XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   Array2D<float> input_array(1, 1);
   input_array.Fill(1.0f);
   auto input_literal = Literal::CreateR2FromArray2D(input_array);
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
                                                  &builder, &parameter);
   builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
@@ -68,9 +67,9 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) {
 }
 
 XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto input_literal = Literal::CreateR1<float>({1.0f});
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
                                                  &builder, &parameter);
   builder.Collapse(/*operand=*/parameter, /*dimensions=*/{});
@@ -81,9 +80,9 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) {
 }
 
 XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto input_literal = Literal::CreateR1<float>({1.0f});
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
                                                  &builder, &parameter);
   builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0});
@@ -95,11 +94,11 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) {
 
 // Collapses 2-dimensional pseudo-scalar (single-element array) to scalar.
 XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   Array2D<float> input_array(1, 1);
   input_array.Fill(1.0f);
   auto input_literal = Literal::CreateR2FromArray2D(input_array);
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
                                                  &builder, &parameter);
   auto reshape = builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
@@ -112,15 +111,14 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) {
 }
 
 XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
 
   std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(1.0f);
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0",
                                                  &builder, &parameter);
   auto a = builder.Neg(parameter);
-  auto reshape =
-      builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1});
+  builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1});
 
   auto expected_literal = Literal::CreateR1<float>({-1.0f});
   ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
@@ -131,10 +129,10 @@ XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) {
 // does not handle zero-sized shapes correctly. Failed last on 2017-11-30
 // with an incorrect result rank.
 XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   Array2D<float> input_array(0, 3);
   auto input_literal = Literal::CreateR2FromArray2D(input_array);
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
                                                  &builder, &parameter);
   builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
@@ -147,11 +145,11 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) {
 // does not handle zero-sized shapes correctly. Failed last on 2017-05-15
 // with an incorrect result rank.
 XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
 
   std::unique_ptr<Literal> param0_literal =
       Literal::CreateR2FromArray2D<float>(Array2D<float>(0, 3));
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0",
                                                  &builder, &parameter);
   builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
@@ -164,10 +162,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) {
 // does not handle zero-sized shapes correctly. Failed last on 2017-11-30
 // with an incorrect result rank.
 XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   Array2D<float> input_array(3, 0);
   auto input_literal = Literal::CreateR2FromArray2D(input_array);
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
                                                  &builder, &parameter);
   builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
@@ -178,9 +176,9 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) {
 
 // Collapses a 2-dimensional row vector to 1 dimension.
 XLA_TEST_P(ReshapeTest, Trivial1x3) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto input_literal = Literal::CreateR2<float>({{1.0f, 2.0f, 3.0f}});
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
                                                  &builder, &parameter);
   builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
@@ -191,9 +189,9 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) {
 
 // Collapses a 2-dimensional column vector to 1 dimension.
 XLA_TEST_P(ReshapeTest, Trivial3x1) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto input_literal = Literal::CreateR2<float>({{1.0f}, {2.0f}, {3.0f}});
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
                                                  &builder, &parameter);
   builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
@@ -344,9 +342,9 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) {
 // does not handle zero-sized shapes correctly. Failed last on 2017-11-30
 // with an incorrect result rank.
 XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto input_literal = Literal::CreateFromArray(Array4D<float>(2, 3, 4, 0));
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
                                                  &builder, &parameter);
   builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3},
@@ -359,10 +357,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) {
 // Reshapes a 2-dimensional array with dimensions that are not just a
 // rearrangement of the originals (split), but no reordering (no shuffle).
 XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
   auto input_literal = Literal::CreateFromArray(*a4x3);
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
                                                  &builder, &parameter);
   builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
@@ -379,9 +377,9 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) {
 // with an incorrect result rank.
 //
 XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto input_literal = Literal::CreateFromArray(Array2D<float>(0, 6));
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
                                                  &builder, &parameter);
   builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
@@ -394,10 +392,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) {
 // Reshapes a 2-dimensional array with dimensions that are not just a
 // rearrangement of the originals (split), and reorder the input (shuffle).
 XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
   auto input_literal = Literal::CreateFromArray(*a4x3);
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
                                                  &builder, &parameter);
   builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
@@ -421,9 +419,9 @@ static Array3D<float> ArrayForDocR3Tests() {
 }
 
 XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
                                                  &builder, &parameter);
   builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
@@ -436,9 +434,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) {
 }
 
 XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
                                                  &builder, &parameter);
   builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
@@ -456,9 +454,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) {
 }
 
 XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
                                                  &builder, &parameter);
   builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
@@ -471,9 +469,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) {
 }
 
 XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
                                                  &builder, &parameter);
   builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
@@ -491,9 +489,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) {
 }
 
 XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
                                                  &builder, &parameter);
   builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
@@ -521,12 +519,12 @@ XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) {
 //
 // 1 2 3 4 5 6 1 2 3 4 5 6
 XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   Array4D<float> t2x2x2x3(2, 2, 2, 3);
   auto filler2x3 = MakeLinspaceArray2D(1.0f, 6.0f, 2, 3);
   t2x2x2x3.FillWithYX(*filler2x3);
   auto input_literal = Literal::CreateFromArray(t2x2x2x3);
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
                                                  &builder, &parameter);
   builder.Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3});
@@ -540,7 +538,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) {
 
 // As above, but uses reshape directly.
 XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   Array4D<float> t(2, 1, 2, 2);
   t(0, 0, 0, 0) = 0;
   t(0, 0, 0, 1) = 1;
@@ -551,7 +549,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) {
   t(1, 0, 1, 0) = 6;
   t(1, 0, 1, 1) = 7;
   auto input_literal = Literal::CreateFromArray(t);
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
                                                  &builder, &parameter);
   builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3},
@@ -566,7 +564,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) {
 // Reshape various ranks to a scalar.
 XLA_TEST_P(ReshapeTest, ToScalar) {
   for (int rank = 0; rank < 8; ++rank) {
-    ComputationBuilder b(client_, TestName());
+    XlaBuilder b(TestName());
     std::vector<int64> ones(rank, 1);  // this is {1, ..., 1}.
     std::vector<int64> dimensions(rank);
     std::iota(dimensions.begin(), dimensions.end(), 0);
@@ -574,7 +572,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) {
     std::vector<int64> zeros(rank, 0);  // this is {0, ..., 0}.
     input_literal.Set<float>(zeros, 83.0f);
 
-    ComputationDataHandle parameter;
+    XlaOp parameter;
     auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                    &b, &parameter);
     b.Reshape(parameter, dimensions, {});
@@ -586,9 +584,9 @@ XLA_TEST_P(ReshapeTest, ToScalar) {
 }
 
 XLA_TEST_P(ReshapeTest, BadDimensions) {
-  ComputationBuilder b(client_, TestName());
+  XlaBuilder b(TestName());
   auto input_literal = Literal::CreateR1<float>({1.0f});
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b,
                                                  &parameter);
   b.Reshape(parameter, {}, {});
@@ -598,9 +596,9 @@ XLA_TEST_P(ReshapeTest, BadDimensions) {
 }
 
 XLA_TEST_P(ReshapeTest, BadNewSizes) {
-  ComputationBuilder b(client_, TestName());
+  XlaBuilder b(TestName());
   auto input_literal = Literal::CreateR1<float>({1.0f, 2.0f});
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b,
                                                  &parameter);
   b.Reshape(parameter, {1}, {});
@@ -609,7 +607,7 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) {
 }
 
 XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   // clang-format off
   auto input_literal = Literal::CreateR4FromArray4DWithLayout(Array4D<float>{
     {
@@ -635,7 +633,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
   },
        LayoutUtil::MakeLayout({0, 1, 2, 3}));
   // clang-format on
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
                                                  &builder, &parameter);
 
@@ -646,7 +644,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
       {222, 333, 444, 555, 666, 777, 888, 999},
   });
 
-  Computation computation = builder.Build().ConsumeValueOrDie();
+  XlaComputation computation = builder.Build().ConsumeValueOrDie();
   ExecutionOptions execution_options = execution_options_;
   *execution_options.mutable_shape_with_output_layout() =
       ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8},
@@ -664,13 +662,13 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
 }
 
 XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   std::unique_ptr<Literal> input_literal = Literal::CreateR2<float>({
       {0, 1, 2, 3, 4, 5, 6, 7},
       {100, 101, 102, 103, 104, 105, 106, 107},
       {200, 201, 202, 203, 204, 205, 206, 207},
   });
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
                                                  &builder, &parameter);
   builder.Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4});
@@ -691,13 +689,13 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
 
 // Tests R2->R4 reshape with the reshape dimensions {1, 0}.
 XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   std::unique_ptr<Literal> input_literal = Literal::CreateR2<float>({
       {0, 1, 2, 3, 4, 5, 6, 7},
       {100, 101, 102, 103, 104, 105, 106, 107},
       {200, 201, 202, 203, 204, 205, 206, 207},
   });
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
                                                  &builder, &parameter);
   builder.Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4});
@@ -717,7 +715,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) {
 }
 
 XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   std::mt19937 rng;
   std::uniform_real_distribution<float> distribution;
   Array4D<float> input(2, 1, 1, 1);
@@ -727,7 +725,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
   std::unique_ptr<Literal> input_literal =
       Literal::CreateR4FromArray4DWithLayout(
           input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input_data = CreateParameterAndTransferLiteral(
       0, *input_literal, "input", &builder, &parameter);
   builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1});
@@ -739,7 +737,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
 }
 
 XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   std::mt19937 rng;
   std::uniform_real_distribution<float> distribution;
   Array4D<float> input(2, 1, 4, 1);
@@ -749,7 +747,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
   std::unique_ptr<Literal> input_literal =
       Literal::CreateR4FromArray4DWithLayout(
           input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input_data = CreateParameterAndTransferLiteral(
       0, *input_literal, "input", &builder, &parameter);
   builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2});
@@ -762,7 +760,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
 
 // Tests R4->R2 reshape with the reshape dimensions {0, 2, 1, 3}.
 XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   std::mt19937 rng;
   std::uniform_real_distribution<float> distribution;
   Array4D<float> input(5, 10, 2, 3);
@@ -772,7 +770,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
   std::unique_ptr<Literal> input_literal =
       Literal::CreateR4FromArray4DWithLayout(
           input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input_data = CreateParameterAndTransferLiteral(
       0, *input_literal, "input", &builder, &parameter);
   builder.Reshape(parameter, /*dimensions=*/{0, 2, 1, 3},
@@ -789,7 +787,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
 }
 
 XLA_TEST_P(ReshapeTest, NoopReshape) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   std::mt19937 rng;
   std::uniform_real_distribution<float> distribution;
   Array4D<float> input_array(2, 3, 5, 7);
@@ -799,12 +797,12 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
   std::unique_ptr<Literal> input_literal =
       Literal::CreateR4FromArray4DWithLayout(
           input_array, LayoutUtil::MakeLayout({1, 2, 3, 0}));
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input_data = CreateParameterAndTransferLiteral(
       0, *input_literal, "input", &builder, &parameter);
   builder.Reshape(parameter, /*dimensions=*/{3, 0, 1, 2},
                   /*new_sizes=*/{7, 2, 3, 5});
-  Computation computation = builder.Build().ConsumeValueOrDie();
+  XlaComputation computation = builder.Build().ConsumeValueOrDie();
 
   ExecutionOptions execution_options = execution_options_;
   *execution_options.mutable_shape_with_output_layout() =
@@ -827,12 +825,12 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
 }
 
 XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto literal_1x2x3x4 = Literal::CreateR4<float>(
       {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
         {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
 
-  ComputationDataHandle parameter;
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input",
                                                  &builder, &parameter);
   builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3},
@@ -846,8 +844,8 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) {
       {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
         {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
 
-  ComputationBuilder builder(client_, TestName());
-  ComputationDataHandle parameter;
+  XlaBuilder builder(TestName());
+  XlaOp parameter;
   auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input",
                                                  &builder, &parameter);
   builder.Reshape(parameter, /*dimensions=*/{1, 3, 2, 0},
@@ -880,8 +878,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) {
   std::unique_ptr<Literal> input_literal =
       Literal::CreateR4FromArray4DWithLayout(
           input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
-  ComputationBuilder builder(client_, TestName());
-  ComputationDataHandle parameter;
+  XlaBuilder builder(TestName());
+  XlaOp parameter;
   auto input_data = CreateParameterAndTransferLiteral(
       0, *input_literal, "input", &builder, &parameter);
   builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
@@ -909,8 +907,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
   std::unique_ptr<Literal> input_literal =
       Literal::CreateR4FromArray4DWithLayout(
           input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
-  ComputationBuilder builder(client_, TestName());
-  ComputationDataHandle parameter;
+  XlaBuilder builder(TestName());
+  XlaOp parameter;
   auto input_data = CreateParameterAndTransferLiteral(
       0, *input_literal, "input", &builder, &parameter);
   builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
@@ -938,8 +936,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
   std::unique_ptr<Literal> input_literal =
       Literal::CreateR4FromArray4DWithLayout(
           input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
-  ComputationBuilder builder(client_, TestName());
-  ComputationDataHandle parameter;
+  XlaBuilder builder(TestName());
+  XlaOp parameter;
   auto input_data = CreateParameterAndTransferLiteral(
       0, *input_literal, "input", &builder, &parameter);
   builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
@@ -968,8 +966,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
   std::unique_ptr<Literal> input_literal =
       Literal::CreateR4FromArray4DWithLayout(
           input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
-  ComputationBuilder builder(client_, TestName());
-  ComputationDataHandle parameter;
+  XlaBuilder builder(TestName());
+  XlaOp parameter;
   auto input_data = CreateParameterAndTransferLiteral(
       0, *input_literal, "input", &builder, &parameter);
   builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
@@ -997,8 +995,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
   std::unique_ptr<Literal> input_literal =
       Literal::CreateR4FromArray4DWithLayout(
           input, LayoutUtil::MakeLayout({0, 1, 2, 3}));
-  ComputationBuilder builder(client_, TestName());
-  ComputationDataHandle parameter;
+  XlaBuilder builder(TestName());
+  XlaOp parameter;
   auto input_data = CreateParameterAndTransferLiteral(
       0, *input_literal, "input", &builder, &parameter);
   builder.Reshape(parameter, /*dimensions=*/{1, 0, 2, 3},