[XLA] Redesign: implement Tuple and GetTupleElement.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 28 Mar 2018 00:02:59 +0000 (17:02 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 28 Mar 2018 00:05:17 +0000 (17:05 -0700)
PiperOrigin-RevId: 190698245

tensorflow/compiler/xla/client/xla_client/xla_builder.cc
tensorflow/compiler/xla/service/shape_inference.cc
tensorflow/compiler/xla/service/shape_inference.h
tensorflow/compiler/xla/tests/BUILD
tensorflow/compiler/xla/tests/client_library_test_base.cc
tensorflow/compiler/xla/tests/client_library_test_base.h
tensorflow/compiler/xla/tests/tuple_test.cc

index fcaf393..7d39701 100644 (file)
@@ -491,11 +491,40 @@ XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true,
 }
 
 XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) {
-  return UnimplementedOp();
+  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+    std::vector<const Shape*> operand_shape_ptrs;
+    std::vector<Shape> operand_shapes;
+    for (const XlaOp& e : elements) {
+      TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(e));
+      operand_shapes.push_back(shape);
+    }
+    c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
+                [](const Shape& shape) { return &shape; });
+    TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+                        ShapeInference::InferVariadicOpShape(
+                            HloOpcode::kTuple, operand_shape_ptrs));
+    return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
+  }());
 }
 
 XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) {
-  return UnimplementedOp();
+  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+    TF_ASSIGN_OR_RETURN(const Shape& tuple_shape, GetShape(tuple_data));
+    if (!ShapeUtil::IsTuple(tuple_shape)) {
+      return InvalidArgument(
+          "Operand to GetTupleElement() is not a tuple; got %s",
+          ShapeUtil::HumanString(tuple_shape).c_str());
+    }
+    *instr.mutable_shape() =
+        ShapeUtil::GetTupleElementShape(tuple_shape, index);
+
+    instr.set_tuple_index(index);
+
+    return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
+                          {tuple_data});
+  }());
 }
 
 XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs,
index 36456d5..77e12d3 100644 (file)
@@ -1070,6 +1070,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
   for (const HloInstruction* operand : operands) {
     operand_shapes.push_back(&operand->shape());
   }
+  return InferVariadicOpShape(opcode, operand_shapes);
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
+    HloOpcode opcode,
+    tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
   return InferVariadicOpShape(OpcodeToVariadicOperation(opcode),
                               operand_shapes);
 }
index 88830e6..9da2c99 100644 (file)
@@ -85,6 +85,9 @@ class ShapeInference {
       tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
   static StatusOr<Shape> InferVariadicOpShape(
       HloOpcode opcode,
+      tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
+  static StatusOr<Shape> InferVariadicOpShape(
+      HloOpcode opcode,
       tensorflow::gtl::ArraySlice<const HloInstruction*> operands);
 
   // Infers the shape produced by applying the given mapping computation shape
index 5ab25f2..2fd97fa 100644 (file)
@@ -1011,6 +1011,8 @@ xla_test(
         "//tensorflow/compiler/xla/client:computation",
         "//tensorflow/compiler/xla/client:computation_builder",
         "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+        "//tensorflow/compiler/xla/client/xla_client:xla_computation",
         "//tensorflow/compiler/xla/tests:client_library_test_base",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:literal_test_util",
index ec95a68..4a9faef 100644 (file)
@@ -441,8 +441,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
   EXPECT_EQ(expected, actual->GetR1U8AsString());
 }
 
+template <typename BuilderT>
 void ClientLibraryTestBase::ComputeAndCompareTuple(
-    ComputationBuilder* builder, const Literal& expected,
+    BuilderT* builder, const Literal& expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
   auto actual_status = ExecuteAndTransfer(builder, arguments);
   EXPECT_IS_OK(actual_status.status());
@@ -453,8 +454,9 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
   LiteralTestUtil::ExpectEqual(expected, *actual);
 }
 
+template <typename BuilderT>
 void ClientLibraryTestBase::ComputeAndCompareTuple(
-    ComputationBuilder* builder, const Literal& expected,
+    BuilderT* builder, const Literal& expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
   auto actual_status = ExecuteAndTransfer(builder, arguments);
   EXPECT_IS_OK(actual_status.status());
@@ -619,4 +621,20 @@ template void ClientLibraryTestBase::ComputeAndCompareLiteral(
     tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
     const Shape* shape_with_layout);
 
+template void ClientLibraryTestBase::ComputeAndCompareTuple(
+    ComputationBuilder* builder, const Literal& expected,
+    tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+
+template void ClientLibraryTestBase::ComputeAndCompareTuple(
+    XlaBuilder* builder, const Literal& expected,
+    tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+
+template void ClientLibraryTestBase::ComputeAndCompareTuple(
+    ComputationBuilder* builder, const Literal& expected,
+    tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
+
+template void ClientLibraryTestBase::ComputeAndCompareTuple(
+    XlaBuilder* builder, const Literal& expected,
+    tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
+
 }  // namespace xla
index 5ff200b..be90f14 100644 (file)
@@ -217,11 +217,13 @@ class ClientLibraryTestBase : public ::testing::Test {
 
   // Convenience method for running a built computation, transferring the
   // result, and comparing it to the expected tuple literal.
+  template <typename BuilderT>
   void ComputeAndCompareTuple(
-      ComputationBuilder* builder, const Literal& expected,
+      BuilderT* builder, const Literal& expected,
       tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+  template <typename BuilderT>
   void ComputeAndCompareTuple(
-      ComputationBuilder* builder, const Literal& expected,
+      BuilderT* builder, const Literal& expected,
       tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
 
   // Convenience method for running a built computation and comparing the result
index fa60af4..098be6d 100644 (file)
@@ -20,6 +20,8 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/computation_builder.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -41,7 +43,7 @@ class TupleTest : public ClientLibraryTestBase {
 
 // Tests a tuple-shaped constant.
 XLA_TEST_F(TupleTest, TupleConstant) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
 
   const float constant_scalar = 7.3f;
   std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
@@ -54,13 +56,13 @@ XLA_TEST_F(TupleTest, TupleConstant) {
                           Literal::CreateR1<float>(constant_vector).get(),
                           Literal::CreateR2<float>(constant_matrix).get()});
 
-  auto result = builder.ConstantLiteral(*value);
+  builder.ConstantLiteral(*value);
   ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
 }
 
 // Tests a tuple made of scalar constants.
 XLA_TEST_F(TupleTest, TupleScalarConstant) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
 
   const float constant_scalar1 = 7.3f;
   const float constant_scalar2 = 1.2f;
@@ -68,13 +70,13 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) {
       Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar1).get(),
                           Literal::CreateR0<float>(constant_scalar2).get()});
 
-  auto result = builder.ConstantLiteral(*value);
+  builder.ConstantLiteral(*value);
   ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
 }
 
 // Tests the creation of tuple data.
 XLA_TEST_F(TupleTest, TupleCreate) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
 
   const float constant_scalar = 7.3f;
   std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
@@ -82,9 +84,9 @@ XLA_TEST_F(TupleTest, TupleCreate) {
       {1.1f, 2.2f, 3.5f},  // row 0
       {4.8f, 5.0f, 6.7f},  // row 1
   };
-  auto result = builder.Tuple({builder.ConstantR0<float>(constant_scalar),
-                               builder.ConstantR1<float>(constant_vector),
-                               builder.ConstantR2<float>(constant_matrix)});
+  builder.Tuple({builder.ConstantR0<float>(constant_scalar),
+                 builder.ConstantR1<float>(constant_vector),
+                 builder.ConstantR2<float>(constant_matrix)});
 
   auto expected =
       Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar).get(),
@@ -95,9 +97,9 @@ XLA_TEST_F(TupleTest, TupleCreate) {
 
 // Tests the creation of tuple data.
 XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
 
-  auto result = builder.Tuple(
+  builder.Tuple(
       {builder.ConstantR0<float>(7.0), builder.ConstantR1<float>({})});
 
   auto expected = Literal::MakeTuple({Literal::CreateR0<float>(7.0).get(),
@@ -107,15 +109,15 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
 
 // Tests the creation of an empty tuple.
 XLA_TEST_F(TupleTest, EmptyTupleCreate) {
-  ComputationBuilder builder(client_, TestName());
-  auto result = builder.Tuple({});
+  XlaBuilder builder(TestName());
+  builder.Tuple({});
   auto expected = Literal::MakeTuple({});
   ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
 }
 
 // Trivial test for extracting a tuple element with GetTupleElement.
 XLA_TEST_F(TupleTest, GetTupleElement) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
   std::initializer_list<std::initializer_list<float>> constant_matrix = {
       {1.f, 2.f, 3.f},  // row 0
@@ -123,23 +125,23 @@ XLA_TEST_F(TupleTest, GetTupleElement) {
   };
   auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
                                    builder.ConstantR2<float>(constant_matrix)});
-  auto matrix_element = builder.GetTupleElement(tuple_data, 1);
+  builder.GetTupleElement(tuple_data, 1);
   ComputeAndCompareR2<float>(&builder, Array2D<float>(constant_matrix), {},
                              error_spec_);
 }
 
 // Trivial test for extracting a tuple element with GetTupleElement.
 XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto tuple_data = builder.Tuple(
       {builder.ConstantR1<float>({}),
        builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 101))});
-  auto matrix_element = builder.GetTupleElement(tuple_data, 1);
+  builder.GetTupleElement(tuple_data, 1);
   ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 101), {}, error_spec_);
 }
 
 XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto value = builder.ConstantR1<float>({4.5f});
   builder.GetTupleElement(value, 1);
   auto result_status = builder.Build();
@@ -152,7 +154,7 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) {
 // Extracts both elements from a tuple with GetTupleElement and then adds them
 // together.
 XLA_TEST_F(TupleTest, AddTupleElements) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
   std::initializer_list<std::initializer_list<float>> constant_matrix = {
       {1.f, 2.f, 3.f},  // row 0
@@ -164,22 +166,22 @@ XLA_TEST_F(TupleTest, AddTupleElements) {
   auto matrix_element = builder.GetTupleElement(tuple_data, 1);
   auto vector_shape = builder.GetShape(vector_element).ConsumeValueOrDie();
   auto matrix_shape = builder.GetShape(matrix_element).ConsumeValueOrDie();
-  auto result = builder.Add(matrix_element, vector_element,
-                            /*broadcast_dimensions=*/{1});
+  builder.Add(matrix_element, vector_element,
+              /*broadcast_dimensions=*/{1});
 
   Array2D<float> expected({
       {2.f, 4.f, 6.f},  // row 0
       {5.f, 7.f, 9.f},  // row 1
   });
-  ASSERT_TRUE(ShapeUtil::ShapeIs(*vector_shape, F32, {3}));
-  ASSERT_TRUE(ShapeUtil::ShapeIs(*matrix_shape, F32, {/*y=*/2, /*x=*/3}));
+  ASSERT_TRUE(ShapeUtil::ShapeIs(vector_shape, F32, {3}));
+  ASSERT_TRUE(ShapeUtil::ShapeIs(matrix_shape, F32, {/*y=*/2, /*x=*/3}));
   ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
 }
 
 // Extracts both elements from a tuple and then puts them into a new tuple in
 // the opposite order.
 XLA_TEST_F(TupleTest, TupleGTEToTuple) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
   std::initializer_list<std::initializer_list<float>> constant_matrix = {
       {1.f, 2.f, 3.f},  // row 0
@@ -187,8 +189,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) {
   };
   auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
                                    builder.ConstantR2<float>(constant_matrix)});
-  auto new_tuple = builder.Tuple({builder.GetTupleElement(tuple_data, 1),
-                                  builder.GetTupleElement(tuple_data, 0)});
+  builder.Tuple({builder.GetTupleElement(tuple_data, 1),
+                 builder.GetTupleElement(tuple_data, 0)});
   auto expected =
       Literal::MakeTuple({Literal::CreateR2<float>(constant_matrix).get(),
                           Literal::CreateR1<float>(constant_vector).get()});
@@ -196,8 +198,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) {
 }
 
 XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
-  ComputationBuilder b(client_, TestName());
-  ComputationDataHandle v1, v2;
+  XlaBuilder b(TestName());
+  XlaOp v1, v2;
 
   for (bool direction : {false, true}) {
     std::unique_ptr<GlobalData> v1_data =
@@ -210,7 +212,7 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
     auto v2_gt = b.Gt(v2, v1);             // true
     auto v1_v2 = b.Tuple({v1_gt, v2_gt});  // {false, true}
     auto v2_v1 = b.Tuple({v2_gt, v1_gt});  // {true, false}
-    auto select = b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
+    b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
     auto expected =
         Literal::MakeTuple({Literal::CreateR0<bool>(direction).get(),
                             Literal::CreateR0<bool>(!direction).get()});
@@ -237,7 +239,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
   //              \                (tuple10)--                     /
   //               \              /           \                   /
   //                -----(GTE 0)--             --(GTE 1)----------
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
   std::initializer_list<std::initializer_list<float>> constant_matrix = {
       {1.f, 2.f, 3.f},  // row 0
@@ -257,8 +259,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
   auto addvectors = builder.Add(vector_from_01, vector_from_10);
   auto addmatrices = builder.Add(matrix_from_01, matrix_from_10);
 
-  auto result = builder.Add(addmatrices, addvectors,
-                            /*broadcast_dimensions=*/{1});
+  builder.Add(addmatrices, addvectors,
+              /*broadcast_dimensions=*/{1});
 
   Array2D<float> expected({
       {4.f, 8.f, 12.f},    // row 0
@@ -269,7 +271,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
 
 XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) {
   // Tests a selection between tuples with "false" path taken.
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
 
   std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
   std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
@@ -278,8 +280,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) {
   auto tuple21 = builder.Tuple(
       {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
 
-  auto select =
-      builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+  builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
   auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec2).get(),
                                       Literal::CreateR1<float>(vec1).get()});
   ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
@@ -314,7 +315,7 @@ XLA_TEST_F(TupleTest, TuplesInAMap) {
 
 XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) {
   // Tests a selection between tuples with "true" path taken.
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
 
   std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
   std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
@@ -323,8 +324,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) {
   auto tuple21 = builder.Tuple(
       {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
 
-  auto select =
-      builder.Select(builder.ConstantR0<bool>(true), tuple12, tuple21);
+  builder.Select(builder.ConstantR0<bool>(true), tuple12, tuple21);
   auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec1).get(),
                                       Literal::CreateR1<float>(vec2).get()});
   ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
@@ -333,7 +333,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) {
 XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
   // Tests a selection between tuples but the final result is an element of the
   // tuple, not the whole tuple.
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
 
   std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
   std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
@@ -344,7 +344,7 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
 
   auto select =
       builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
-  auto element = builder.GetTupleElement(select, 0);
+  builder.GetTupleElement(select, 0);
 
   ComputeAndCompareR1<float>(&builder, vec2, {}, error_spec_);
 }
@@ -368,7 +368,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) {
   //                                /             --(GTE 1)--
   //                               /
   //                          (tuple 21)
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
 
   std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
   std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
@@ -384,8 +384,8 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) {
       builder.Select(builder.GetTupleElement(pred_tuple, 0), tuple12, tuple21);
   auto select2 =
       builder.Select(builder.GetTupleElement(pred_tuple, 1), tuple21, select1);
-  auto result = builder.Add(builder.GetTupleElement(select2, 0),
-                            builder.GetTupleElement(select2, 1));
+  builder.Add(builder.GetTupleElement(select2, 0),
+              builder.GetTupleElement(select2, 1));
 
   ComputeAndCompareR1<float>(&builder, {3.f, 6.f, 9.f}, {}, error_spec_);
 }
@@ -394,7 +394,7 @@ XLA_TEST_F(TupleTest,
            DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesReuseConstants)) {
   // Similar to SelectBetweenTuples, but the constants are shared between the
   // input tuples.
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
 
   std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
   std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
@@ -403,19 +403,18 @@ XLA_TEST_F(TupleTest,
   auto tuple12 = builder.Tuple({c1, c2});
   auto tuple21 = builder.Tuple({c2, c1});
 
-  auto select =
-      builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+  builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+
   auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec2).get(),
                                       Literal::CreateR1<float>(vec1).get()});
   ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
 }
 
 XLA_TEST_F(TupleTest, NestedTuples) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto inner_tuple = builder.Tuple(
       {builder.ConstantR1<float>({1.0, 2.0}), builder.ConstantR0<float>(42.0)});
-  auto outer_tuple =
-      builder.Tuple({inner_tuple, builder.ConstantR1<float>({22.0, 44.0})});
+  builder.Tuple({inner_tuple, builder.ConstantR1<float>({22.0, 44.0})});
 
   auto expected_v1 = Literal::CreateR1<float>({1.0, 2.0});
   auto expected_s = Literal::CreateR0<float>(42.0);
@@ -429,7 +428,7 @@ XLA_TEST_F(TupleTest, NestedTuples) {
 }
 
 XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
 
   Shape data_shape = ShapeUtil::MakeShape(F32, {3});
   Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape});
@@ -460,7 +459,7 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
 }
 
 XLA_TEST_F(TupleTest, ComplexTuples) {
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   {
     Shape c64r0 = ShapeUtil::MakeShape(C64, {});
     Shape c64r1 = ShapeUtil::MakeShape(C64, {2});