[XLA] Redesign: migrate while_test to use XlaBuilder, and implement the ops needed...
author A. Unique TensorFlower <gardener@tensorflow.org>
Tue, 3 Apr 2018 01:18:11 +0000 (18:18 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Tue, 3 Apr 2018 01:21:03 +0000 (18:21 -0700)
Also, when a module has embedded computations, the service side complains if instruction names are not unique within the scope of the module. To ensure instruction names are unique in the module, use both the computation id and the instruction id as the suffix.
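
As an illustration of that naming scheme, here is a minimal standalone sketch; the helper name and the id values are hypothetical, while the real change builds the suffix with StrCat inside XlaBuilder::AddInstruction (see the xla_builder.cc hunk below):

  #include <cstdint>
  #include <iostream>
  #include <string>

  // Build an instruction name suffixed with the computation id and the
  // instruction id, so names stay unique across a module's embedded
  // computations.
  std::string UniqueInstructionName(const std::string& opcode,
                                    int64_t computation_id,
                                    int64_t instruction_id) {
    return opcode + "." + std::to_string(computation_id) + "." +
           std::to_string(instruction_id);
  }

  int main() {
    // Two "add" instructions with the same instruction id in different
    // computations no longer collide.
    std::cout << UniqueInstructionName("add", 1, 5) << "\n";  // add.1.5
    std::cout << UniqueInstructionName("add", 2, 5) << "\n";  // add.2.5
    return 0;
  }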

PiperOrigin-RevId: 191379697

tensorflow/compiler/xla/client/lib/BUILD
tensorflow/compiler/xla/client/lib/arithmetic.cc
tensorflow/compiler/xla/client/lib/arithmetic.h
tensorflow/compiler/xla/client/xla_client/xla_builder.cc
tensorflow/compiler/xla/client/xla_client/xla_builder.h
tensorflow/compiler/xla/tests/BUILD
tensorflow/compiler/xla/tests/reduce_test.cc
tensorflow/compiler/xla/tests/while_test.cc

index d02972f..f4673a8 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -24,6 +24,8 @@ cc_library(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/client:computation",
         "//tensorflow/compiler/xla/client:computation_builder",
+        "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+        "//tensorflow/compiler/xla/client/xla_client:xla_computation",
         "//tensorflow/core:lib",
     ],
 )
index 24048a1..63df449 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -26,6 +26,7 @@ limitations under the License.
 
 namespace xla {
 namespace {
+
 using InstructionGenerator =
     ComputationDataHandle (*)(ComputationBuilder*, const ComputationDataHandle&,
                               const ComputationDataHandle&);
@@ -47,6 +48,27 @@ Computation CreateScalarComputation(const string& name, PrimitiveType type,
   generator(b.get(), lhs, rhs);
   return b->BuildAndNoteError();
 }
+
+using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&);
+
+XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
+                                       XlaBuilder* builder,
+                                       XlaOpGenerator generator) {
+  std::unique_ptr<XlaBuilder> b;
+  if (type == PRED) {
+    b = builder->CreateSubBuilder(name);
+  } else {
+    b = builder->CreateSubBuilder(
+        tensorflow::strings::StrCat(name, "_", PrimitiveType_Name(type)));
+  }
+
+  const Shape scalar = ShapeUtil::MakeShape(type, {});
+  auto lhs = b->Parameter(0, scalar, "lhs");
+  auto rhs = b->Parameter(1, scalar, "rhs");
+  generator(b.get(), lhs, rhs);
+  return b->BuildAndNoteError();
+}
+
 }  // namespace
 
 Computation CreateScalarAddComputation(PrimitiveType type,
@@ -60,7 +82,7 @@ Computation CreateScalarAddComputation(PrimitiveType type,
 Computation CreateScalarMultiplyComputation(PrimitiveType type,
                                             ComputationBuilder* builder) {
   return CreateScalarComputation(
-      "add", type, builder,
+      "mul", type, builder,
       [](ComputationBuilder* b, const ComputationDataHandle& lhs,
          const ComputationDataHandle& rhs) { return b->Mul(lhs, rhs); });
 }
@@ -114,4 +136,75 @@ StatusOr<ComputationDataHandle> Any(const ComputationDataHandle& predicates,
   return builder->Reduce(predicates, f, logical_or, all_dimensions);
 }
 
+XlaComputation CreateScalarAddComputation(PrimitiveType type,
+                                          XlaBuilder* builder) {
+  return CreateScalarComputation(
+      "add", type, builder,
+      [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
+        return b->Add(lhs, rhs);
+      });
+}
+
+XlaComputation CreateScalarMultiplyComputation(PrimitiveType type,
+                                               XlaBuilder* builder) {
+  return CreateScalarComputation(
+      "mul", type, builder,
+      [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
+        return b->Mul(lhs, rhs);
+      });
+}
+
+XlaComputation CreateScalarGeComputation(PrimitiveType type,
+                                         XlaBuilder* builder) {
+  return CreateScalarComputation(
+      "ge", type, builder,
+      [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
+        return b->Ge(lhs, rhs);
+      });
+}
+
+XlaComputation CreateScalarMaxComputation(PrimitiveType type,
+                                          XlaBuilder* builder) {
+  return CreateScalarComputation(
+      "max", type, builder,
+      [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
+        return b->Max(lhs, rhs);
+      });
+}
+
+XlaComputation CreateScalarMinComputation(PrimitiveType type,
+                                          XlaBuilder* builder) {
+  return CreateScalarComputation(
+      "min", type, builder,
+      [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
+        return b->Min(lhs, rhs);
+      });
+}
+
+XlaComputation CreateScalarAndComputation(XlaBuilder* builder) {
+  return CreateScalarComputation(
+      "and", PRED, builder,
+      [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
+        return b->And(lhs, rhs);
+      });
+}
+
+XlaComputation CreateScalarOrComputation(XlaBuilder* builder) {
+  return CreateScalarComputation(
+      "or", PRED, builder,
+      [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
+        return b->Or(lhs, rhs);
+      });
+}
+
+StatusOr<XlaOp> Any(const XlaOp& predicates, XlaBuilder* builder) {
+  auto f = builder->ConstantR0<bool>(false);
+  XlaComputation logical_or = CreateScalarOrComputation(builder);
+  TF_ASSIGN_OR_RETURN(const Shape& predicates_shape,
+                      builder->GetShape(predicates));
+  std::vector<int64> all_dimensions(ShapeUtil::Rank(predicates_shape));
+  std::iota(all_dimensions.begin(), all_dimensions.end(), 0);
+  return builder->Reduce(predicates, f, logical_or, all_dimensions);
+}
+
 }  // namespace xla
index ae89784..f4d3fc8 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.h
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.h
@@ -20,6 +20,8 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
 namespace xla {
@@ -56,6 +58,48 @@ Computation CreateScalarOrComputation(ComputationBuilder* builder);
 StatusOr<ComputationDataHandle> Any(const ComputationDataHandle& predicates,
                                     ComputationBuilder* builder);
 
+// TODO(b/74197823): This is a part of a NOT YET ready refactor.
+//
+// Creates a scalar add computation and returns it.
+XlaComputation CreateScalarAddComputation(PrimitiveType type,
+                                          XlaBuilder* builder);
+// TODO(b/74197823): This is a part of a NOT YET ready refactor.
+//
+// Creates a scalar multiply computation and returns it.
+XlaComputation CreateScalarMultiplyComputation(PrimitiveType type,
+                                               XlaBuilder* builder);
+// TODO(b/74197823): This is a part of a NOT YET ready refactor.
+//
+// Creates a scalar ge computation and returns it.
+XlaComputation CreateScalarGeComputation(PrimitiveType type,
+                                         XlaBuilder* builder);
+// TODO(b/74197823): This is a part of a NOT YET ready refactor.
+//
+// Creates a scalar max computation and returns it.
+XlaComputation CreateScalarMaxComputation(PrimitiveType type,
+                                          XlaBuilder* builder);
+// TODO(b/74197823): This is a part of a NOT YET ready refactor.
+//
+// Creates a scalar min computation and returns it.
+XlaComputation CreateScalarMinComputation(PrimitiveType type,
+                                          XlaBuilder* builder);
+// TODO(b/74197823): This is a part of a NOT YET ready refactor.
+//
+// Creates a scalar logical AND computation and returns it.
+XlaComputation CreateScalarAndComputation(XlaBuilder* builder);
+
+// TODO(b/74197823): This is a part of a NOT YET ready refactor.
+//
+// Creates a scalar logical OR computation and returns it.
+XlaComputation CreateScalarOrComputation(XlaBuilder* builder);
+
+// TODO(b/74197823): This is a part of a NOT YET ready refactor.
+//
+// Returns whether any predicate in "predicates" is set.
+//
+// Note: if predicates is zero-sized, Any() vacuously returns false.
+StatusOr<XlaOp> Any(const XlaOp& predicates, XlaBuilder* builder);
+
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_
index ec23621..c2e661c 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -81,7 +81,7 @@ StatusOr<Shape> XlaOp::GetShape() const {
 }
 
 XlaBuilder::XlaBuilder(const string& computation_name)
-    : name_(computation_name) {}
+    : name_(computation_name), unique_id_(GetUniqueId()) {}
 
 XlaBuilder::~XlaBuilder() {}
 
@@ -179,7 +179,6 @@ StatusOr<XlaComputation> XlaBuilder::Build() {
   }
 
   HloComputationProto entry;
-  entry.set_name(name_);
 
   {
     int64 root_id;
@@ -193,9 +192,9 @@ StatusOr<XlaComputation> XlaBuilder::Build() {
     entry.add_instructions()->Swap(&instruction);
   }
 
-  const int64 id = GetUniqueId();
-  entry.set_id(id);
-  XlaComputation computation(id);
+  entry.set_id(unique_id_);
+  entry.set_name(StrCat(name_, entry.id()));  // Ensure that the name is unique.
+  XlaComputation computation(entry.id());
   HloModuleProto* module = computation.mutable_proto();
   module->set_name(entry.name());
   module->set_id(entry.id());
@@ -407,12 +406,7 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation,
         ShapeInference::InferCallShape(operand_shape_ptrs,
                                        /*to_apply=*/called_program_shape));
 
-    // Add called computation.
-    instr.add_called_computation_ids(
-        computation.proto().entry_computation_id());
-    for (const HloComputationProto& e : computation.proto().computations()) {
-      embedded_.insert({e.id(), e});
-    }
+    AddCalledComputation(computation, &instr);
 
     return AddInstruction(std::move(instr), HloOpcode::kCall, operands);
   });
@@ -470,7 +464,22 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand,
                         tensorflow::gtl::ArraySlice<int64> start_indices,
                         tensorflow::gtl::ArraySlice<int64> limit_indices,
                         tensorflow::gtl::ArraySlice<int64> strides) {
-  return UnimplementedOp();
+  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+    TF_ASSIGN_OR_RETURN(
+        *instr.mutable_shape(),
+        ShapeInference::InferSliceShape(operand_shape, start_indices,
+                                        limit_indices, strides));
+    for (int i = 0; i < start_indices.size(); i++) {
+      auto* slice_config = instr.add_slice_dimensions();
+      slice_config->set_start(start_indices[i]);
+      slice_config->set_limit(limit_indices[i]);
+      slice_config->set_stride(strides[i]);
+    }
+
+    return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
+  });
 }
 
 XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index,
@@ -485,7 +494,20 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
 
 XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                                      const XlaOp& start_indices) {
-  return UnimplementedOp();
+  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+
+    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+    TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update));
+    TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
+                        GetShape(start_indices));
+    TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+                        ShapeInference::InferDynamicUpdateSliceShape(
+                            operand_shape, update_shape, start_indices_shape));
+
+    return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
+                          {operand, update, start_indices});
+  });
 }
 
 XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands,
@@ -620,12 +642,29 @@ XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs,
 }
 
 XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) {
-  return UnimplementedOp();
+  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
+
+    DotDimensionNumbers dimension_numbers;
+    dimension_numbers.add_lhs_contracting_dimensions(
+        lhs_shape.dimensions_size() == 1 ? 0 : 1);
+    dimension_numbers.add_rhs_contracting_dimensions(0);
+    return DotGeneral(lhs, rhs, dimension_numbers);
+  });
 }
 
 XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
                              const DotDimensionNumbers& dimension_numbers) {
-  return UnimplementedOp();
+  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
+    TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
+    TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+                        ShapeInference::InferDotOpShape(lhs_shape, rhs_shape,
+                                                        dimension_numbers));
+    *instr.mutable_dot_dimension_numbers() = dimension_numbers;
+    return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs});
+  });
 }
 
 XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs,
@@ -860,7 +899,14 @@ XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs,
 
 XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand,
                                      PrimitiveType new_element_type) {
-  return UnimplementedOp();
+  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+    TF_ASSIGN_OR_RETURN(
+        *instr.mutable_shape(),
+        ShapeInference::InferConvertShape(operand_shape, new_element_type));
+    return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand});
+  });
 }
 
 XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand,
@@ -894,19 +940,64 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
   return UnimplementedOp();
 }
 
+XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
+                        tensorflow::gtl::ArraySlice<XlaOp> parameters,
+                        const Shape& shape) {
+  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+
+    // Check the number of parameters per RNG distribution.
+    switch (distribution) {
+      case RandomDistribution::RNG_NORMAL:
+      case RandomDistribution::RNG_UNIFORM:
+        if (parameters.size() != 2) {
+          return InvalidArgument(
+              "RNG distribution (%s) expects 2 parameters, but got %ld",
+              RandomDistribution_Name(distribution).c_str(), parameters.size());
+        }
+        break;
+      default:
+        LOG(FATAL) << "unhandled distribution " << distribution;
+    }
+
+    TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
+    *instr.mutable_shape() = shape;
+
+    instr.set_distribution(distribution);
+
+    return AddInstruction(std::move(instr), HloOpcode::kRng, parameters);
+  });
+}
+
 XlaOp XlaBuilder::RngNormal(const XlaOp& mu, const XlaOp& sigma,
                             const Shape& shape) {
-  return UnimplementedOp();
+  return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape);
 }
 
 XlaOp XlaBuilder::RngUniform(const XlaOp& a, const XlaOp& b,
                              const Shape& shape) {
-  return UnimplementedOp();
+  return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
 }
 
 XlaOp XlaBuilder::While(const XlaComputation& condition,
                         const XlaComputation& body, const XlaOp& init) {
-  return UnimplementedOp();
+  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+
+    // Infer shape.
+    TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape());
+    TF_ASSIGN_OR_RETURN(const auto& condition_program_shape,
+                        condition.GetProgramShape());
+    TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init));
+    TF_ASSIGN_OR_RETURN(
+        *instr.mutable_shape(),
+        ShapeInference::InferWhileShape(condition_program_shape,
+                                        body_program_shape, init_shape));
+    // Body comes before condition computation in the vector.
+    AddCalledComputation(body, &instr);
+    AddCalledComputation(condition, &instr);
+    return AddInstruction(std::move(instr), HloOpcode::kWhile, {init});
+  });
 }
 
 XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices,
@@ -926,7 +1017,27 @@ XlaOp XlaBuilder::Reduce(
     const XlaOp& operand, const XlaOp& init_value,
     const XlaComputation& computation,
     tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
-  return UnimplementedOp();
+  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+
+    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+    TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value));
+    TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
+                        computation.GetProgramShape());
+    TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+                        ShapeInference::InferReduceShape(
+                            operand_shape, init_shape, dimensions_to_reduce,
+                            called_program_shape));
+
+    for (int64 dim : dimensions_to_reduce) {
+      instr.add_dimensions(dim);
+    }
+
+    AddCalledComputation(computation, &instr);
+
+    return AddInstruction(std::move(instr), HloOpcode::kReduce,
+                          {operand, init_value});
+  });
 }
 
 XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value,
@@ -1109,10 +1220,10 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(
   instr.set_id(handle);
   instr.set_opcode(HloOpcodeString(opcode));
   if (instr.name().empty()) {
-    instr.set_name(StrCat(instr.opcode(), ".", handle));
+    instr.set_name(StrCat(instr.opcode(), ".", unique_id_, ".", handle));
   } else {
     // Append the handle to make sure the name is unique.
-    instr.set_name(StrCat(instr.name(), ".", handle));
+    instr.set_name(StrCat(instr.name(), ".", unique_id_, ".", handle));
   }
   for (const auto& operand : operands) {
     if (operand.builder_ == nullptr) {
@@ -1138,6 +1249,14 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(
   return op;
 }
 
+void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
+                                      HloInstructionProto* instr) {
+  instr->add_called_computation_ids(computation.proto().entry_computation_id());
+  for (const HloComputationProto& e : computation.proto().computations()) {
+    embedded_.insert({e.id(), e});
+  }
+}
+
 StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
     const XlaOp& op) const {
   TF_RETURN_IF_ERROR(first_error_);
index f43101d..0673b86 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -803,6 +803,9 @@ class XlaBuilder {
       HloInstructionProto&& instr, HloOpcode opcode,
       tensorflow::gtl::ArraySlice<XlaOp> operands = {});
 
+  void AddCalledComputation(const XlaComputation& computation,
+                            HloInstructionProto* instr);
+
   // Notes that the error occurred by:
   // * storing it internally and capturing a backtrace if it's the first error
   //   (this deferred value will be produced on the call to Build())
@@ -829,6 +832,10 @@ class XlaBuilder {
   XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
                   const XlaOp& ehs);
 
+  XlaOp RngOp(RandomDistribution distribution,
+              tensorflow::gtl::ArraySlice<XlaOp> parameters,
+              const Shape& shape);
+
   StatusOr<XlaOp> InDimBroadcast(
       const Shape& shape, const XlaOp& operand,
       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
@@ -846,7 +853,8 @@ class XlaBuilder {
   // computation and fills the root_id in the pointer.
   StatusOr<ProgramShape> GetProgramShape(int64* root_id);
 
-  string name_;  // Name to use for the built computation.
+  string name_;      // Name to use for the built computation.
+  int64 unique_id_;  // The unique id for the built computation.
 
   // The first error encountered while building the computation.
   // This is OK until the first error is encountered.
index 9cead12..aba61fb 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -347,10 +347,10 @@ xla_test(
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/client:client_library",
-        "//tensorflow/compiler/xla/client:computation",
-        "//tensorflow/compiler/xla/client:computation_builder",
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/client/lib:arithmetic",
+        "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+        "//tensorflow/compiler/xla/client/xla_client:xla_computation",
         "//tensorflow/compiler/xla/service:platform_util",
         "//tensorflow/compiler/xla/tests:client_library_test_base",
         "//tensorflow/compiler/xla/tests:literal_test_util",
index 3a097a0..d24927d 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -57,6 +57,11 @@ limitations under the License.
 namespace xla {
 namespace {
 
+using FuncGeneratorForType = Computation (*)(PrimitiveType,
+                                             ComputationBuilder*);
+
+using FuncGenerator = Computation (*)(ComputationBuilder*);
+
 class ReduceTest : public ClientLibraryTestBase {
  protected:
   ReduceTest() {
@@ -755,53 +760,57 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) {
 }
 
 XLA_TEST_F(ReduceTest, VectorizedReduce_Add) {
-  RunVectorizedReduceTest(CreateScalarAddComputation,
-                          [](float a, float b) { return a + b; },
-                          [](int32 a, int32 b) {
-                            return static_cast<int32>(static_cast<uint32>(a) +
-                                                      static_cast<uint32>(b));
-                          },
-                          [](uint32 a, uint32 b) { return a + b; }, 0.0, 0, 0);
+  RunVectorizedReduceTest(
+      static_cast<FuncGeneratorForType>(CreateScalarAddComputation),
+      [](float a, float b) { return a + b; },
+      [](int32 a, int32 b) {
+        return static_cast<int32>(static_cast<uint32>(a) +
+                                  static_cast<uint32>(b));
+      },
+      [](uint32 a, uint32 b) { return a + b; }, 0.0, 0, 0);
 }
 
 XLA_TEST_F(ReduceTest, VectorizedReduce_Multiply) {
-  RunVectorizedReduceTest(CreateScalarMultiplyComputation,
-                          [](float a, float b) { return a * b; },
-                          [](int32 a, int32 b) {
-                            return static_cast<int32>(static_cast<uint32>(a) *
-                                                      static_cast<uint32>(b));
-                          },
-                          [](uint32 a, uint32 b) { return a * b; }, 1.0, 1, 1);
+  RunVectorizedReduceTest(
+      static_cast<FuncGeneratorForType>(CreateScalarMultiplyComputation),
+      [](float a, float b) { return a * b; },
+      [](int32 a, int32 b) {
+        return static_cast<int32>(static_cast<uint32>(a) *
+                                  static_cast<uint32>(b));
+      },
+      [](uint32 a, uint32 b) { return a * b; }, 1.0, 1, 1);
 }
 
 XLA_TEST_F(ReduceTest, VectorizedReduce_Max) {
-  RunVectorizedReduceTest(CreateScalarMaxComputation,
-                          [](float a, float b) { return std::max(a, b); },
-                          [](int32 a, int32 b) { return std::max(a, b); },
-                          [](uint32 a, uint32 b) { return std::max(a, b); },
-                          std::numeric_limits<float>::min(),
-                          std::numeric_limits<int32>::min(),
-                          std::numeric_limits<uint32>::min());
+  RunVectorizedReduceTest(
+      static_cast<FuncGeneratorForType>(CreateScalarMaxComputation),
+      [](float a, float b) { return std::max(a, b); },
+      [](int32 a, int32 b) { return std::max(a, b); },
+      [](uint32 a, uint32 b) { return std::max(a, b); },
+      std::numeric_limits<float>::min(), std::numeric_limits<int32>::min(),
+      std::numeric_limits<uint32>::min());
 }
 
 XLA_TEST_F(ReduceTest, VectorizedReduce_Min) {
-  RunVectorizedReduceTest(CreateScalarMinComputation,
-                          [](float a, float b) { return std::min(a, b); },
-                          [](int32 a, int32 b) { return std::min(a, b); },
-                          [](uint32 a, uint32 b) { return std::min(a, b); },
-                          std::numeric_limits<float>::max(),
-                          std::numeric_limits<int32>::max(),
-                          std::numeric_limits<uint32>::max());
+  RunVectorizedReduceTest(
+      static_cast<FuncGeneratorForType>(CreateScalarMinComputation),
+      [](float a, float b) { return std::min(a, b); },
+      [](int32 a, int32 b) { return std::min(a, b); },
+      [](uint32 a, uint32 b) { return std::min(a, b); },
+      std::numeric_limits<float>::max(), std::numeric_limits<int32>::max(),
+      std::numeric_limits<uint32>::max());
 }
 
 XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanAnd) {
   RunVectorizedReduceTestForType<bool>(
-      CreateScalarAndComputation, [](bool a, bool b) { return a && b; }, true);
+      static_cast<FuncGenerator>(CreateScalarAndComputation),
+      [](bool a, bool b) { return a && b; }, true);
 }
 
 XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanOr) {
   RunVectorizedReduceTestForType<bool>(
-      CreateScalarOrComputation, [](bool a, bool b) { return a || b; }, false);
+      static_cast<FuncGenerator>(CreateScalarOrComputation),
+      [](bool a, bool b) { return a || b; }, false);
 }
 
 class ReduceR3ToR2Test : public ReduceTest,
index 33d457c..89ce2ce 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -18,10 +18,10 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -54,29 +54,28 @@ TEST_F(WhileTest, WhileWithScalarS32Result) {
   auto result_shape = ShapeUtil::MakeShape(S32, {});
 
   // Create a computation for the condition: repeat for 5 iterations.
-  Computation condition;
+  XlaComputation condition;
   {
-    ComputationBuilder builder(client_, "condition");
+    XlaBuilder builder("condition");
     auto prev = builder.Parameter(0, result_shape, "prev");
     builder.Gt(builder.ConstantR0<int32>(5), prev);
     condition = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a computation for the body: add 1 to the result variable.
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto input = builder.ConstantR0<int32>(1);
-    auto result = builder.Add(input, prev);
+    builder.Add(input, prev);
     body = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a While node with computations for the condition and the body.
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto init = builder.ConstantR0<int32>(0);
-  auto result = builder.While(condition, body, init);
-  auto shape = builder.GetShape(result).ConsumeValueOrDie();
+  builder.While(condition, body, init);
 
   ComputeAndCompareR0<int32>(&builder, 5, {});
 }
@@ -91,29 +90,28 @@ TEST_F(WhileTest, WhileWithScalarS64Result) {
   auto result_shape = ShapeUtil::MakeShape(S64, {});
 
   // Create a computation for the condition: repeat for 5 iterations.
-  Computation condition;
+  XlaComputation condition;
   {
-    ComputationBuilder builder(client_, "condition");
+    XlaBuilder builder("condition");
     auto prev = builder.Parameter(0, result_shape, "prev");
     builder.Gt(builder.ConstantR0<int64>(5), prev);
     condition = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a computation for the body: add 1 to the result variable.
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto input = builder.ConstantR0<int64>(1);
-    auto result = builder.Add(input, prev);
+    builder.Add(input, prev);
     body = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a While node with computations for the condition and the body.
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto init = builder.ConstantR0<int64>(0);
-  auto result = builder.While(condition, body, init);
-  auto shape = builder.GetShape(result).ConsumeValueOrDie();
+  builder.While(condition, body, init);
 
   ComputeAndCompareR0<int64>(&builder, 5, {});
 }
@@ -123,31 +121,30 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
   auto orig_shape = ShapeUtil::MakeShape(S32, {2});
 
   // Create a computation for the condition: repeat for 5 iterations.
-  Computation condition;
+  XlaComputation condition;
   {
-    ComputationBuilder builder(client_, "condition");
+    XlaBuilder builder("condition");
     auto prev = builder.Parameter(0, result_shape, "prev");
     builder.Gt(builder.ConstantR0<int32>(5), prev);
     condition = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a computation for the body: add 1 to the result variable.
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto input = builder.ConstantR0<int32>(1);
-    auto result = builder.Add(input, prev);
+    builder.Add(input, prev);
     body = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a While node with computations for the condition and the body.
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto init = builder.Reduce(builder.ConstantR1<int32>(2, 1),
                              builder.ConstantR0<int32>(0),
                              CreateScalarAddComputation(S32, &builder), {0});
-  auto result = builder.While(condition, body, init);
-  auto shape = builder.GetShape(result).ConsumeValueOrDie();
+  builder.While(condition, body, init);
 
   ComputeAndCompareR0<int32>(&builder, 5, {});
 }
@@ -156,28 +153,28 @@ TEST_F(WhileTest, WhileWithPredicateResult) {
   auto result_shape = ShapeUtil::MakeShape(PRED, {});
 
   // Create a computation for the condition: run until condition is true.
-  Computation condition;
+  XlaComputation condition;
   {
-    ComputationBuilder builder(client_, "condition");
+    XlaBuilder builder("condition");
     auto prev = builder.Parameter(0, result_shape, "prev");
     builder.Ne(builder.ConstantR0<bool>(true), prev);
     condition = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a computation for the body: or condition with true.
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, result_shape, "prev");
-    auto result = builder.Or(prev, builder.ConstantR0<bool>(true));
+    builder.Or(prev, builder.ConstantR0<bool>(true));
     body = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a While node with computations for the condition and the body.
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto init = builder.Ne(builder.ConstantR0<bool>(false),
                          builder.ConstantR0<bool>(true));
-  auto result = builder.While(condition, body, init);
+  builder.While(condition, body, init);
 
   ComputeAndCompareR0<bool>(&builder, true, {});
 }
@@ -194,9 +191,9 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) {
   Shape result_shape = ShapeUtil::MakeShape(F32, {0});
 
   // Create a computation for the reduction.
-  Computation add;
+  XlaComputation add;
   {
-    ComputationBuilder builder(client_, "add");
+    XlaBuilder builder("add");
     auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
     auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
     builder.Add(x, y);
@@ -205,33 +202,34 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) {
 
   // Create a computation for the condition.
   // Repeat until the sum of the result vector is less than 15.5f.
-  Computation condition;
+  XlaComputation condition;
   {
-    ComputationBuilder builder(client_, "condition");
+    XlaBuilder builder("condition");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
                               /*dimensions_to_reduce=*/{0});
-    auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum);
+    builder.Gt(builder.ConstantR0<float>(15.5f), sum);
     condition = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a computation for the body.
   // Add a constant vector of 1.f to the result vector.
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto input = builder.ConstantR1<float>({});
-    auto result = builder.Add(input, prev);
+    builder.Add(input, prev);
     body = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a While node with computations for the condition and the body.
-  ComputationBuilder builder(client_, "while");
+  XlaBuilder builder("while");
   auto init = builder.ConstantR1<float>({});
   auto result = builder.While(condition, body, init);
-  VLOG(2) << "while = " << ShapeUtil::HumanString(
-                               *builder.GetShape(result).ConsumeValueOrDie());
+  VLOG(2) << "while = "
+          << ShapeUtil::HumanString(
+                 builder.GetShape(result).ConsumeValueOrDie());
 
   ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.0001));
 }
@@ -247,9 +245,9 @@ TEST_F(WhileTest, WhileWithVectorResult) {
   Shape result_shape = ShapeUtil::MakeShape(F32, {8});
 
   // Create a computation for the reduction.
-  Computation add;
+  XlaComputation add;
   {
-    ComputationBuilder builder(client_, "add");
+    XlaBuilder builder("add");
     auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
     auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
     builder.Add(x, y);
@@ -258,33 +256,34 @@ TEST_F(WhileTest, WhileWithVectorResult) {
 
   // Create a computation for the condition.
   // Repeat until the sum of the result vector is less than 5.5f.
-  Computation condition;
+  XlaComputation condition;
   {
-    ComputationBuilder builder(client_, "condition");
+    XlaBuilder builder("condition");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
                               /*dimensions_to_reduce=*/{0});
-    auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum);
+    builder.Gt(builder.ConstantR0<float>(15.5f), sum);
     condition = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a computation for the body.
   // Add a constant vector of 1.f to the result vector.
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto input = builder.ConstantR1<float>(8, 0.125f);
-    auto result = builder.Add(input, prev);
+    builder.Add(input, prev);
     body = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a While node with computations for the condition and the body.
-  ComputationBuilder builder(client_, "while");
+  XlaBuilder builder("while");
   auto init = builder.ConstantR1<float>(8, 0.f);
   auto result = builder.While(condition, body, init);
-  VLOG(2) << "while = " << ShapeUtil::HumanString(
-                               *builder.GetShape(result).ConsumeValueOrDie());
+  VLOG(2) << "while = "
+          << ShapeUtil::HumanString(
+                 builder.GetShape(result).ConsumeValueOrDie());
 
   // Individual elements with increase by 1/8 each time through the loop, so
   // the sum will increase by 1.0.  It will first be >15.5 when the elements
@@ -306,9 +305,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
   Shape result_shape = ShapeUtil::MakeShape(F32, {8});
 
   // Create a computation for the reduction.
-  Computation add;
+  XlaComputation add;
   {
-    ComputationBuilder builder(client_, "add");
+    XlaBuilder builder("add");
     auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
     auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
     builder.Add(x, y);
@@ -317,34 +316,34 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
 
   // Create a computation for the condition.
   // Repeat until the sum of the result vector is less than 5.5f.
-  Computation condition;
+  XlaComputation condition;
   {
-    ComputationBuilder builder(client_, "condition");
+    XlaBuilder builder("condition");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
                               /*dimensions_to_reduce=*/{0});
-    auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum);
+    builder.Gt(builder.ConstantR0<float>(15.5f), sum);
     condition = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a computation for the body.
   // Add a constant vector of 1.f to the result vector.
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto input = builder.ConstantR1<float>(8, 0.125f);
-    auto result = builder.Add(input, prev);
+    builder.Add(input, prev);
     body = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a While node with computations for the condition and the body.
-  ComputationBuilder builder(client_, "while");
+  XlaBuilder builder("while");
   auto init = builder.ConstantR1<float>(8, 0.f);
   auto result = builder.While(condition, body, init);
   VLOG(2) << "while = "
           << ShapeUtil::HumanString(
-                 *builder.GetShape(result).ConsumeValueOrDie());
+                 builder.GetShape(result).ConsumeValueOrDie());
   builder.Tuple({result});
 
   // Individual elements with increase by 1/8 each time through the loop, so
@@ -366,9 +365,9 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
   // Create a computation for the condition.
   // Repeat for N iterations.
   const int N = 2;
-  Computation condition;
+  XlaComputation condition;
   {
-    ComputationBuilder builder(client_, "condition");
+    XlaBuilder builder("condition");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     builder.Gt(builder.ConstantR0<int32>(N), iteration);
@@ -377,28 +376,28 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
 
   // Create a computation for the body.
   // Add 1 to the iteration variable and permute the weights.
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     auto w1 = builder.GetTupleElement(prev, 1);
     auto w2 = builder.GetTupleElement(prev, 2);
     auto w3 = builder.GetTupleElement(prev, 3);
-    auto result = builder.Tuple(
+    builder.Tuple(
         {builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2});
     body = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a While node with computations for the condition and the body.
-  ComputationBuilder builder(client_, "while");
+  XlaBuilder builder("while");
   auto init = builder.Tuple(
       {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f),
        builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)});
   auto result = builder.While(condition, body, init);
   VLOG(2) << "result = "
           << ShapeUtil::HumanString(
-                 *builder.GetShape(result).ConsumeValueOrDie());
+                 builder.GetShape(result).ConsumeValueOrDie());
 
   auto expected_counter = Literal::CreateR0<int32>(N);
   auto expected_w1 = Literal::CreateR1<float>({1.0f, 1.0f, 1.0f});
@@ -419,9 +418,9 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
   // Create a computation for the condition.
   // Repeat for N iterations.
   const int N = 2;
-  Computation condition;
+  XlaComputation condition;
   {
-    ComputationBuilder builder(client_, "condition");
+    XlaBuilder builder("condition");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     builder.Gt(builder.ConstantR0<int32>(N), iteration);
@@ -430,21 +429,21 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
 
   // Create a computation for the body.
   // Add 1 to the iteration variable permute the weights.
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     auto w1 = builder.GetTupleElement(prev, 1);
     auto w2 = builder.GetTupleElement(prev, 2);
     auto w3 = builder.GetTupleElement(prev, 3);
-    auto result = builder.Tuple(
+    builder.Tuple(
         {builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2});
     body = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a While node with computations for the condition and the body.
-  ComputationBuilder builder(client_, "while");
+  XlaBuilder builder("while");
   auto init = builder.Tuple(
       {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f),
        builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)});
@@ -455,7 +454,7 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
   auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3));
   VLOG(2) << "result = "
           << ShapeUtil::HumanString(
-                 *builder.GetShape(result).ConsumeValueOrDie());
+                 builder.GetShape(result).ConsumeValueOrDie());
   std::vector<float> expected = {6.f, 6.f, 6.f};
   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
 }
@@ -474,9 +473,9 @@ TEST_F(WhileTest, WhileWithTupleResult) {
 
   // Create a computation for the condition.
   // Repeat for 5 iterations.
-  Computation condition;
+  XlaComputation condition;
   {
-    ComputationBuilder builder(client_, "condition");
+    XlaBuilder builder("condition");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     builder.Gt(builder.ConstantR0<int32>(5), iteration);
@@ -486,26 +485,27 @@ TEST_F(WhileTest, WhileWithTupleResult) {
   // Create a computation for the body.
   // Add 1 to the iteration variable and add a constant vector of 1.0f to
   // the weight variable, both of which are tuple elements.
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     auto weights = builder.GetTupleElement(prev, 1);
     auto input = builder.ConstantR1<float>(10, 1.f);
     auto new_weights = builder.Add(weights, input);
-    auto result = builder.Tuple(
+    builder.Tuple(
         {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
     body = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a While node with computations for the condition and the body.
-  ComputationBuilder builder(client_, "while");
+  XlaBuilder builder("while");
   auto init = builder.Tuple(
       {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
   auto result = builder.While(condition, body, init);
-  VLOG(2) << "while = " << ShapeUtil::HumanString(
-                               *builder.GetShape(result).ConsumeValueOrDie());
+  VLOG(2) << "while = "
+          << ShapeUtil::HumanString(
+                 builder.GetShape(result).ConsumeValueOrDie());
 
   auto expected_counter = Literal::CreateR0<int32>(5);
   auto expected_data = Literal::CreateR1<float>(
@@ -523,9 +523,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) {
 
   // Create a computation for the condition.
   // Repeat for 5 iterations.
-  Computation condition;
+  XlaComputation condition;
   {
-    ComputationBuilder builder(client_, "condition");
+    XlaBuilder builder("condition");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     builder.Gt(builder.ConstantR0<int32>(5), iteration);
@@ -534,27 +534,27 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) {
 
   // Create a computation for the body.
   // Add 1 to the iteration variable and or the predicate with true
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     auto pred = builder.GetTupleElement(prev, 1);
     auto new_pred = builder.Or(pred, builder.ConstantR0<bool>(true));
-    auto result = builder.Tuple(
+    builder.Tuple(
         {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_pred});
     body = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a While node with computations for the condition and the body.
-  ComputationBuilder builder(client_, "while");
+  XlaBuilder builder("while");
   auto init = builder.Tuple({builder.ConstantR0<int32>(0),
                              builder.Ne(builder.ConstantR0<bool>(false),
                                         builder.ConstantR0<bool>(true))});
   auto result = builder.While(condition, body, init);
   VLOG(2) << "while = "
           << ShapeUtil::HumanString(
-                 *builder.GetShape(result).ConsumeValueOrDie());
+                 builder.GetShape(result).ConsumeValueOrDie());
 
   auto expected_counter = Literal::CreateR0<int32>(5);
   auto expected_predicate = Literal::CreateR0<bool>(true);
@@ -570,9 +570,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
 
   // Create a computation for the condition.
   // Repeat for 5 iterations.
-  Computation condition;
+  XlaComputation condition;
   {
-    ComputationBuilder builder(client_, "condition");
+    XlaBuilder builder("condition");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     builder.Gt(builder.ConstantR0<int32>(5), iteration);
@@ -582,25 +582,24 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
   // Create a computation for the body.
   // Add 1 to the iteration variable and set the other tuple element to a
   // constant.
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
-    auto result =
-        builder.Tuple({builder.Add(iteration, builder.ConstantR0<int32>(1)),
-                       builder.ConstantR0<int32>(7)});
+    builder.Tuple({builder.Add(iteration, builder.ConstantR0<int32>(1)),
+                   builder.ConstantR0<int32>(7)});
     body = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a While node with computations for the condition and the body.
-  ComputationBuilder builder(client_, "while");
+  XlaBuilder builder("while");
   auto init = builder.Tuple(
       {builder.ConstantR0<int32>(0), builder.ConstantR0<int32>(7)});
   auto result = builder.While(condition, body, init);
   VLOG(2) << "while = "
           << ShapeUtil::HumanString(
-                 *builder.GetShape(result).ConsumeValueOrDie());
+                 builder.GetShape(result).ConsumeValueOrDie());
 
   auto expected_counter = Literal::CreateR0<int32>(5);
   auto expected_data = Literal::CreateR0<int32>(7);
@@ -631,20 +630,20 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) {
 
   // Create a computation for the condition.
   // Repeat for 5 iterations.
-  Computation condition;
+  XlaComputation condition;
   const int c1 = 5;
   {
-    ComputationBuilder builder(client_, "condition");
+    XlaBuilder builder("condition");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     builder.Lt(iteration, builder.ConstantR0<int32>(c1));
     TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
   }
 
-  Computation condition2;
+  XlaComputation condition2;
   const int c2 = 7;
   {
-    ComputationBuilder builder(client_, "condition2");
+    XlaBuilder builder("condition2");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     builder.Lt(iteration, builder.ConstantR0<int32>(c2));
@@ -654,34 +653,34 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) {
   // Create a computation for the body.
   // Add 1 to the iteration variable and add a constant vector of 1.0f to
   // the weight variable, both of which are tuple elements.
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     auto weights = builder.GetTupleElement(prev, 1);
     auto input = builder.ConstantR1<float>(10, 1.f);
     auto new_weights = builder.Add(weights, input);
-    auto result = builder.Tuple(
+    builder.Tuple(
         {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
     TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
   }
 
-  Computation body2;
+  XlaComputation body2;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     auto weights = builder.GetTupleElement(prev, 1);
     auto input = builder.ConstantR1<float>(10, 1.f);
     auto new_weights = builder.Add(weights, input);
-    auto result = builder.Tuple(
+    builder.Tuple(
         {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
     TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build());
   }
 
   // Create a While node with computations for the condition and the body.
-  ComputationBuilder builder(client_, "while");
+  XlaBuilder builder("while");
   auto init = builder.Tuple(
       {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
   auto while1 = builder.While(condition, body, init);
@@ -692,11 +691,11 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) {
   auto while_result2 = builder.GetTupleElement(while2, 1);
   VLOG(2) << "while_result2 = "
           << ShapeUtil::HumanString(
-                 *builder.GetShape(while_result2).ConsumeValueOrDie());
+                 builder.GetShape(while_result2).ConsumeValueOrDie());
   auto result = builder.Add(while_result1, while_result2);
   VLOG(2) << "result = "
           << ShapeUtil::HumanString(
-                 *builder.GetShape(result).ConsumeValueOrDie());
+                 builder.GetShape(result).ConsumeValueOrDie());
   const float sum = c1 + c2;
   std::vector<float> expected(10, sum);
   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -710,20 +709,20 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
 
   // Create a computation for the condition.
   // Repeat for 5 iterations.
-  Computation condition;
+  XlaComputation condition;
   const int c1 = 5;
   {
-    ComputationBuilder builder(client_, "condition");
+    XlaBuilder builder("condition");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     builder.Lt(iteration, builder.ConstantR0<int32>(c1));
     TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
   }
 
-  Computation condition2;
+  XlaComputation condition2;
   const int c2 = 7;
   {
-    ComputationBuilder builder(client_, "condition2");
+    XlaBuilder builder("condition2");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     builder.Lt(iteration, builder.ConstantR0<int32>(c2));
@@ -733,21 +732,21 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
   // Create a computation for the body.
   // Add 1 to the iteration variable and add a constant vector of 1.0f to
   // the weight variable, both of which are tuple elements.
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     auto weights = builder.GetTupleElement(prev, 1);
     auto input = builder.ConstantR1<float>(10, 1.f);
     auto new_weights = builder.Add(weights, input);
-    auto result = builder.Tuple(
+    builder.Tuple(
         {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
     TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
   }
 
   // Create a While node with computations for the condition and the body.
-  ComputationBuilder builder(client_, "while");
+  XlaBuilder builder("while");
   auto init = builder.Tuple(
       {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
   auto while1 = builder.While(condition, body, init);
@@ -758,11 +757,11 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
   auto while_result2 = builder.GetTupleElement(while2, 1);
   VLOG(2) << "while_result2 = "
           << ShapeUtil::HumanString(
-                 *builder.GetShape(while_result2).ConsumeValueOrDie());
+                 builder.GetShape(while_result2).ConsumeValueOrDie());
   auto result = builder.Add(while_result1, while_result2);
   VLOG(2) << "result = "
           << ShapeUtil::HumanString(
-                 *builder.GetShape(result).ConsumeValueOrDie());
+                 builder.GetShape(result).ConsumeValueOrDie());
   const float sum = c1 + c2;
   std::vector<float> expected(10, sum);
   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -777,20 +776,20 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
 
   // Create a computation for the condition.
   // Repeat for 5 iterations.
-  Computation condition;
+  XlaComputation condition;
   const int c1 = 5;
   {
-    ComputationBuilder builder(client_, "condition");
+    XlaBuilder builder("condition");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     builder.Lt(iteration, builder.ConstantR0<int32>(c1));
     TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
   }
 
-  Computation condition2;
+  XlaComputation condition2;
   const int c2 = 7;
   {
-    ComputationBuilder builder(client_, "condition2");
+    XlaBuilder builder("condition2");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     builder.Lt(iteration, builder.ConstantR0<int32>(c2));
@@ -800,21 +799,21 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
   // Create a computation for the body.
   // Add 1 to the iteration variable and add a constant vector of 1.0f to
   // the weight variable, both of which are tuple elements.
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     auto weights = builder.GetTupleElement(prev, 1);
     auto input = builder.ConstantR1<float>(10, 1.f);
     auto new_weights = builder.Add(weights, input);
-    auto result = builder.Tuple(
+    builder.Tuple(
         {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
     TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
   }
 
   // Create a While node with computations for the condition and the body.
-  ComputationBuilder builder(client_, "while");
+  XlaBuilder builder("while");
   auto init = builder.Tuple(
       {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
   auto while1 = builder.While(condition, body, init);
@@ -824,11 +823,11 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
   auto while_result2 = builder.GetTupleElement(while2, 1);
   VLOG(2) << "while_result2 = "
           << ShapeUtil::HumanString(
-                 *builder.GetShape(while_result2).ConsumeValueOrDie());
+                 builder.GetShape(while_result2).ConsumeValueOrDie());
   auto result = builder.Add(while_result1, while_result2);
   VLOG(2) << "result = "
           << ShapeUtil::HumanString(
-                 *builder.GetShape(result).ConsumeValueOrDie());
+                 builder.GetShape(result).ConsumeValueOrDie());
   const float sum = c1 + c2;
   std::vector<float> expected(10, sum);
   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -844,9 +843,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
 
   // Create a computation for the condition.
   // Repeat for 5 iterations.
-  Computation condition;
+  XlaComputation condition;
   {
-    ComputationBuilder builder(client_, "condition");
+    XlaBuilder builder("condition");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     builder.Gt(builder.ConstantR0<int32>(5), iteration);
@@ -856,9 +855,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
   // Create a computation for the body.
   // Add 1 to the iteration variable and add a constant vector of 1.0f to
   // the weight variable, both of which are tuple elements.
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, result_shape, "prev");
     // TupleElement 0
     auto iteration = builder.GetTupleElement(prev, 0);
@@ -873,18 +872,18 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
     // UpdateSlice.
     auto out1 = builder.DynamicUpdateSlice(input, update, starts);
 
-    auto result = builder.Tuple({out0, out1});
+    builder.Tuple({out0, out1});
     body = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a While node with computations for the condition and the body.
-  ComputationBuilder builder(client_, "while");
+  XlaBuilder builder("while");
   auto init = builder.Tuple(
       {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
   auto result = builder.While(condition, body, init);
   VLOG(2) << "while = "
           << ShapeUtil::HumanString(
-                 *builder.GetShape(result).ConsumeValueOrDie());
+                 builder.GetShape(result).ConsumeValueOrDie());
 
   auto expected_counter = Literal::CreateR0<int32>(5);
   auto expected_data = Literal::CreateR1<float>(
@@ -915,18 +914,18 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) {
 
   // Create a computation for the condition: repeat for count iterations.
   auto build_condition = [this, v6s32](int count) {
-    ComputationBuilder builder(client_, TestName());
+    XlaBuilder builder(TestName());
     auto prev = builder.Reshape(
         builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}, {1}), {0},
-          {});
+        {});
     builder.Gt(builder.ConstantR0<int32>(count), prev);
     return builder.Build().ConsumeValueOrDie();
   };
 
   // Create a computation for the body: add 1 to the result variable.
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, v6s32, "prev");
     auto inc = builder.ConcatInDim(
         {builder.ConstantR1<int32>({1}),
@@ -934,16 +933,15 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) {
                             builder.ConstantR0<int32>(100),
                             ShapeUtil::MakeShape(S32, {5}))},
         0);
-    auto result = builder.Add(inc, prev);
+    builder.Add(inc, prev);
     body = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a While node with computations for the condition and the body.
   auto while_loop = [this, &body, build_condition](int count) {
-    ComputationBuilder builder(client_, TestName());
+    XlaBuilder builder(TestName());
     auto init = builder.ConstantR1<int32>({0, 0, 0, 0, 0, 0});
-    auto result = builder.While(build_condition(count), body, init);
-    auto shape = builder.GetShape(result).ConsumeValueOrDie();
+    builder.While(build_condition(count), body, init);
     return builder.Build();
   };
 
@@ -1107,9 +1105,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
   auto inner_result_shape = ShapeUtil::MakeTupleShape(
       {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})});
 
-  Computation inner_condition;
+  XlaComputation inner_condition;
   {
-    ComputationBuilder builder(client_, "inner_condition");
+    XlaBuilder builder("inner_condition");
     auto params = builder.Parameter(0, inner_result_shape, "prev");
     auto i = builder.GetTupleElement(params, 0);
     builder.Lt(i, builder.ConstantR0<int32>(7));
@@ -1118,9 +1116,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
 
   // Creates a computation for the outer loop condition:
   // repeat while result < 30.
-  Computation outer_condition;
+  XlaComputation outer_condition;
   {
-    ComputationBuilder builder(client_, "outer_condition");
+    XlaBuilder builder("outer_condition");
     auto prev = builder.Parameter(0, outer_result_shape, "prev");
     builder.Lt(prev, builder.ConstantR0<int32>(30));
     outer_condition = builder.Build().ConsumeValueOrDie();
@@ -1128,34 +1126,33 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
 
   // Creates a computation for the inner loop body: add 1 to `i`, and add 2 to
   // `result`.
-  Computation inner_body;
+  XlaComputation inner_body;
   {
-    ComputationBuilder builder(client_, "inner_body");
+    XlaBuilder builder("inner_body");
     auto params = builder.Parameter(0, inner_result_shape, "prev");
     auto i = builder.GetTupleElement(params, 0);
     auto result = builder.GetTupleElement(params, 1);
     i = builder.Add(builder.ConstantR0<int32>(1), i);
     result = builder.Add(builder.ConstantR0<int32>(2), result);
-    auto output = builder.Tuple({i, result});
+    builder.Tuple({i, result});
     inner_body = builder.Build().ConsumeValueOrDie();
   }
 
   // Creates a computation for the outer loop: run the inner loop with i = 0.
-  Computation outer_body;
+  XlaComputation outer_body;
   {
-    ComputationBuilder builder(client_, "outer_body");
+    XlaBuilder builder("outer_body");
     auto prev = builder.Parameter(0, outer_result_shape, "prev");
     auto init = builder.Tuple({builder.ConstantR0<int32>(0), prev});
     auto result = builder.While(inner_condition, inner_body, init);
-    auto output = builder.GetTupleElement(result, 1);
+    builder.GetTupleElement(result, 1);
     outer_body = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a While node with computations for the condition and the body.
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto init = builder.ConstantR0<int32>(0);
-  auto result = builder.While(outer_condition, outer_body, init);
-  auto shape = builder.GetShape(result).ConsumeValueOrDie();
+  builder.While(outer_condition, outer_body, init);
 
   ComputeAndCompareR0<int32>(&builder, 42, {});
 }
@@ -1170,18 +1167,18 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) {
   auto result_shape = ShapeUtil::MakeShape(S32, {});
 
   // Create a computation for the condition: repeat for 5 iterations.
-  Computation condition_callee;
+  XlaComputation condition_callee;
   {
-    ComputationBuilder builder(client_, "condition_callee");
+    XlaBuilder builder("condition_callee");
     auto prev = builder.Parameter(0, result_shape, "prev");
     builder.Tuple({builder.Gt(builder.ConstantR0<int32>(5), prev)});
 
     condition_callee = builder.Build().ConsumeValueOrDie();
   }
 
-  Computation condition;
+  XlaComputation condition;
   {
-    ComputationBuilder builder(client_, "condition");
+    XlaBuilder builder("condition");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto result = builder.Call(condition_callee, {prev});
     builder.GetTupleElement(result, 0);
@@ -1189,20 +1186,19 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) {
   }
 
   // Create a computation for the body: add 1 to the result variable.
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, result_shape, "prev");
     auto input = builder.ConstantR0<int32>(1);
-    auto result = builder.Add(input, prev);
+    builder.Add(input, prev);
     body = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a While node with computations for the condition and the body.
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto init = builder.ConstantR0<int32>(0);
-  auto result = builder.While(condition, body, init);
-  auto shape = builder.GetShape(result).ConsumeValueOrDie();
+  builder.While(condition, body, init);
 
   ComputeAndCompareR0<int32>(&builder, 5, {});
 }
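
The test above is, stripped of the Call indirection, the basic counter-loop pattern these migrations follow: each sub-computation gets its own standalone XlaBuilder (no Client* argument), the last operation recorded in a builder becomes that computation's root (which is why the unused `auto result =` bindings disappear), and the entry builder ties condition and body together with While. A minimal self-contained sketch under those assumptions; the helper name and include list are illustrative, not part of this change:

#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {

// Illustrative helper: builds an entry computation equivalent to
//   x = 0; while (x < limit) x = x + 1;
// over a scalar S32, mirroring the calls used by the migrated tests.
XlaComputation BuildCounterWhileLoop(int32 limit) {
  const Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});

  // Condition: repeat while prev < limit.
  XlaBuilder cond_builder("condition");
  {
    auto prev = cond_builder.Parameter(0, scalar_s32, "prev");
    cond_builder.Lt(prev, cond_builder.ConstantR0<int32>(limit));
  }
  XlaComputation condition = cond_builder.Build().ConsumeValueOrDie();

  // Body: add 1 to the loop-carried value; the Add becomes the root.
  XlaBuilder body_builder("body");
  {
    auto prev = body_builder.Parameter(0, scalar_s32, "prev");
    body_builder.Add(prev, body_builder.ConstantR0<int32>(1));
  }
  XlaComputation body = body_builder.Build().ConsumeValueOrDie();

  // Entry computation: While wires the two sub-computations together.
  XlaBuilder builder("while");
  auto init = builder.ConstantR0<int32>(0);
  builder.While(condition, body, init);
  return builder.Build().ConsumeValueOrDie();
}

}  // namespace xla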
@@ -1214,28 +1210,28 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
       {scalar_s32, matrix_shape, matrix_shape, matrix_shape});
 
   // Create a computation for the condition: repeat for 5 iterations.
-  Computation condition;
+  XlaComputation condition;
   {
-    ComputationBuilder builder(client_, "condition");
+    XlaBuilder builder("condition");
     auto state = builder.Parameter(0, while_shape, "state");
     builder.Gt(builder.ConstantR0<int32>(5), builder.GetTupleElement(state, 0));
     TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
   }
 
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client_, "body");
+    XlaBuilder builder("body");
     auto state = builder.Parameter(0, while_shape, "state");
     auto indvar = builder.GetTupleElement(state, 0);
     auto input_0 = builder.GetTupleElement(state, 1);
     auto input_1 = builder.GetTupleElement(state, 2);
     auto output = builder.Tanh(builder.Dot(input_0, input_1));
     auto indvar_next = builder.Add(indvar, builder.ConstantR0<int32>(1));
-    auto tuple_result = builder.Tuple({indvar_next, input_0, input_1, output});
+    builder.Tuple({indvar_next, input_0, input_1, output});
     TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
   }
 
-  ComputationBuilder builder(client_, TestName());
+  XlaBuilder builder(TestName());
   auto matrix_input = builder.Parameter(0, matrix_shape, "matrix");
   auto init = builder.Tuple(
       {builder.ConstantR0<int32>(0), matrix_input, matrix_input, matrix_input});
@@ -1268,9 +1264,9 @@ void BM_WhileLoop(int num_iters) {
 
   // Create while condition computation with 'loop_limit'.
   const int32 loop_limit = 100;
-  Computation condition;
+  XlaComputation condition;
   {
-    ComputationBuilder builder(client, "condition");
+    XlaBuilder builder("condition");
     auto prev = builder.Parameter(0, loop_state_shape, "prev");
     auto iteration = builder.GetTupleElement(prev, 0);
     builder.Lt(iteration, builder.ConstantR0<int32>(loop_limit));
@@ -1278,9 +1274,9 @@ void BM_WhileLoop(int num_iters) {
   }
 
   // Create while body computation with unit loop increment.
-  Computation body;
+  XlaComputation body;
   {
-    ComputationBuilder builder(client, "body");
+    XlaBuilder builder("body");
     auto prev = builder.Parameter(0, loop_state_shape, "prev");
     // TupleElement 0
     auto iteration = builder.GetTupleElement(prev, 0);
@@ -1294,12 +1290,12 @@ void BM_WhileLoop(int num_iters) {
     auto starts = builder.ConstantR1<int32>({0, 0, 0});
     // UpdateSlice.
     auto out1 = builder.DynamicUpdateSlice(input, update, starts);
-    auto result = builder.Tuple({out0, out1});
+    builder.Tuple({out0, out1});
     body = builder.Build().ConsumeValueOrDie();
   }
 
   // Create a While instruction.
-  ComputationBuilder builder(client, "while");
+  XlaBuilder builder("while");
   auto zero = builder.ConstantR0<float>(0.0);
   auto input = builder.Broadcast(zero, {seq_len, 1024, 1024});
   auto init = builder.Tuple({builder.ConstantR0<int32>(0), input});