[XLA] Redesign: cleanup client_library_test_base.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 4 May 2018 02:45:59 +0000 (19:45 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 4 May 2018 17:40:57 +0000 (10:40 -0700)
PiperOrigin-RevId: 195357555

tensorflow/compiler/xla/tests/client_library_test_base.cc
tensorflow/compiler/xla/tests/client_library_test_base.h

index 22660c3..c09e7ea 100644 (file)
@@ -94,28 +94,14 @@ string ClientLibraryTestBase::TestName() const {
   return ::testing::UnitTest::GetInstance()->current_test_info()->name();
 }
 
-template <typename BuilderT>
 StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
-    BuilderT* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+    XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
   // Build the computation, as a convenience.
   TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
   return client_->Execute(computation, arguments, &execution_options_);
 }
 
 StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
-    const Computation& computation,
-    tensorflow::gtl::ArraySlice<GlobalData*> arguments,
-    const Shape* shape_with_output_layout) {
-  ExecutionOptions execution_options = execution_options_;
-  if (shape_with_output_layout != nullptr) {
-    *execution_options.mutable_shape_with_output_layout() =
-        *shape_with_output_layout;
-  }
-  return client_->ExecuteAndTransfer(computation, arguments,
-                                     &execution_options);
-}
-
-StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
     const XlaComputation& computation,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments,
     const Shape* shape_with_output_layout) {
@@ -128,17 +114,6 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
                                      &execution_options);
 }
 
-template <>
-StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
-    ComputationBuilder* builder,
-    tensorflow::gtl::ArraySlice<GlobalData*> arguments,
-    const Shape* shape_with_output_layout) {
-  // Build the computation, as a convenience.
-  TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
-  return ExecuteAndTransfer(computation, arguments, shape_with_output_layout);
-}
-
-template <>
 StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
     XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments,
     const Shape* shape_with_output_layout) {
@@ -162,18 +137,6 @@ ClientLibraryTestBase::ExecuteAndTransferReference(
                                          &execution_options);
 }
 
-std::unique_ptr<GlobalData> ClientLibraryTestBase::ExecuteOrDie(
-    ComputationBuilder* builder,
-    tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
-  return Execute(builder, arguments).ConsumeValueOrDie();
-}
-
-std::unique_ptr<Literal> ClientLibraryTestBase::ExecuteAndTransferOrDie(
-    ComputationBuilder* builder,
-    tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
-  return ExecuteAndTransfer(builder, arguments).ConsumeValueOrDie();
-}
-
 string ClientLibraryTestBase::ExecuteToString(
     XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
   auto computation_status = builder->Build();
@@ -191,32 +154,6 @@ string ClientLibraryTestBase::ExecuteToString(
   }
 }
 
-string ClientLibraryTestBase::ExecuteToString(
-    ComputationBuilder* builder,
-    tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
-  auto computation_status = builder->Build();
-  if (!computation_status.ok()) {
-    return computation_status.status().ToString();
-  }
-  auto computation = computation_status.ConsumeValueOrDie();
-
-  auto result =
-      client_->ExecuteAndTransfer(computation, arguments, &execution_options_);
-  if (!result.ok()) {
-    return result.status().ToString();
-  } else {
-    return result.ValueOrDie()->ToString();
-  }
-}
-
-void ClientLibraryTestBase::ComputeAndCompareR1(
-    ComputationBuilder* builder, const tensorflow::core::Bitmap& expected,
-    tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
-  std::unique_ptr<Literal> expected_literal = Literal::CreateR1(expected);
-  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
-                                                  arguments);
-}
-
 void ClientLibraryTestBase::ComputeAndCompareR1(
     XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
@@ -225,18 +162,16 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
                                                   arguments);
 }
 
-template <typename BuilderT>
 void ClientLibraryTestBase::ComputeAndCompareLiteral(
-    BuilderT* builder, const Literal& expected,
+    XlaBuilder* builder, const Literal& expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments,
     const Shape* shape_with_layout) {
   EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
                                                   shape_with_layout));
 }
 
-template <typename BuilderT>
 void ClientLibraryTestBase::ComputeAndCompareLiteral(
-    BuilderT* builder, const Literal& expected,
+    XlaBuilder* builder, const Literal& expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
     const Shape* shape_with_layout) {
   EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
@@ -245,7 +180,7 @@ void ClientLibraryTestBase::ComputeAndCompareLiteral(
 
 tensorflow::Status
 ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
-    const xla::Computation& computation, const Literal& expected,
+    const xla::XlaComputation& computation, const Literal& expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments,
     const std::function<void(const Literal& actual,
                              const string& error_message)>& verify_output) {
@@ -271,7 +206,7 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
 
 tensorflow::Status
 ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
-    const xla::Computation& computation, const Literal& expected,
+    const xla::XlaComputation& computation, const Literal& /*expected*/,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments,
     const std::function<void(const Literal& actual,
                              const string& error_message)>& verify_output,
@@ -334,28 +269,8 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
   return choose(0);
 }
 
-tensorflow::Status
-ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
-    const xla::XlaComputation& /*computation*/, const Literal& /*expected*/,
-    tensorflow::gtl::ArraySlice<GlobalData*> /*arguments*/,
-    const std::function<void(const Literal& actual,
-                             const string& error_message)>& /*verify_output*/) {
-  return Unimplemented("not yet implemented for XlaComputation");
-}
-
-tensorflow::Status
-ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
-    const xla::XlaComputation& /*computation*/, const Literal& /*expected*/,
-    tensorflow::gtl::ArraySlice<GlobalData*> /*arguments*/,
-    const std::function<void(const Literal& actual,
-                             const string& error_message)>& /*verify_output*/,
-    const Shape* /*output_with_layout*/) {
-  return Unimplemented("not yet implemented for XlaComputation");
-}
-
-template <typename BuilderT>
 tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
-    BuilderT* builder, const Literal& expected,
+    XlaBuilder* builder, const Literal& expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments_passed_in,
     const Shape* shape_with_layout) {
   std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
@@ -412,9 +327,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
   return tensorflow::Status::OK();
 }
 
-template <typename BuilderT>
 tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
-    BuilderT* builder, const Literal& expected,
+    XlaBuilder* builder, const Literal& expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments_passed_in,
     ErrorSpec error, const Shape* shape_with_layout) {
   std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
@@ -484,9 +398,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
   EXPECT_EQ(expected, actual->GetR1U8AsString());
 }
 
-template <typename BuilderT>
 void ClientLibraryTestBase::ComputeAndCompareTuple(
-    BuilderT* builder, const Literal& expected,
+    XlaBuilder* builder, const Literal& expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
   auto actual_status = ExecuteAndTransfer(builder, arguments);
   EXPECT_IS_OK(actual_status.status());
@@ -497,9 +410,8 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
   LiteralTestUtil::ExpectEqual(expected, *actual);
 }
 
-template <typename BuilderT>
 void ClientLibraryTestBase::ComputeAndCompareTuple(
-    BuilderT* builder, const Literal& expected,
+    XlaBuilder* builder, const Literal& expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
   auto actual_status = ExecuteAndTransfer(builder, arguments);
   EXPECT_IS_OK(actual_status.status());
@@ -511,60 +423,6 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
 }
 
 void ClientLibraryTestBase::ComputeAndCompare(
-    ComputationBuilder* builder, const ComputationDataHandle& operand,
-    tensorflow::gtl::ArraySlice<Literal> arguments) {
-  auto status_or_data = ComputeValueAndReference(builder, operand, arguments);
-  EXPECT_IS_OK(status_or_data);
-  if (!status_or_data.ok()) {
-    return;
-  }
-  std::unique_ptr<Literal> reference, result;
-  std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
-  LiteralTestUtil::ExpectEqual(*reference, *result);
-}
-
-void ClientLibraryTestBase::ComputeAndCompare(
-    ComputationBuilder* builder, const ComputationDataHandle& operand,
-    tensorflow::gtl::ArraySlice<Literal> arguments, ErrorSpec error) {
-  auto status_or_data = ComputeValueAndReference(builder, operand, arguments);
-  EXPECT_IS_OK(status_or_data);
-  if (!status_or_data.ok()) {
-    return;
-  }
-  std::unique_ptr<Literal> reference, result;
-  std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
-  LiteralTestUtil::ExpectNear(*reference, *result, error);
-}
-
-StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
-ClientLibraryTestBase::ComputeValueAndReference(
-    ComputationBuilder* builder, const ComputationDataHandle& operand,
-    tensorflow::gtl::ArraySlice<Literal> arguments) {
-  // Transfer the arguments to the executor service. We put the unique_ptr's
-  // into a vector to keep the data alive on the service until the end of this
-  // function.
-  std::vector<std::unique_ptr<GlobalData>> argument_data;
-  for (const auto& arg : arguments) {
-    TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg));
-    argument_data.push_back(std::move(data));
-  }
-
-  // Create raw pointers to the GlobalData for the rest of the call stack.
-  std::vector<GlobalData*> argument_data_ptr;
-  std::transform(
-      argument_data.begin(), argument_data.end(),
-      std::back_inserter(argument_data_ptr),
-      [](const std::unique_ptr<GlobalData>& data) { return data.get(); });
-
-  TF_ASSIGN_OR_RETURN(
-      auto reference,
-      builder->ComputeConstant(operand, /*output_layout=*/nullptr, arguments));
-  TF_ASSIGN_OR_RETURN(auto result,
-                      ExecuteAndTransfer(builder, argument_data_ptr));
-  return std::make_pair(std::move(reference), std::move(result));
-}
-
-void ClientLibraryTestBase::ComputeAndCompare(
     XlaBuilder* builder, tensorflow::gtl::ArraySlice<Literal> arguments) {
   auto status_or_data = ComputeValueAndReference(builder, arguments);
   EXPECT_IS_OK(status_or_data);
@@ -651,8 +509,8 @@ XlaComputation ClientLibraryTestBase::CreateScalarMax() {
   return computation_status.ConsumeValueOrDie();
 }
 
-Computation ClientLibraryTestBase::CreateScalarReluSensitivity() {
-  ComputationBuilder builder(client_, "relu_sensitivity");
+XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() {
+  XlaBuilder builder("relu_sensitivity");
   auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
   auto activation = builder.Parameter(0, shape, "activation");
   auto backprop = builder.Parameter(1, shape, "backprop");
@@ -693,14 +551,6 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols,
   return array;
 }
 
-ComputationDataHandle ClientLibraryTestBase::AddParam(
-    const Literal& argument, ComputationBuilder* builder) {
-  ComputationDataHandle data_handle;
-  arguments_.push_back(CreateParameterAndTransferLiteral(
-      arguments_.size(), argument, "", builder, &data_handle));
-  return data_handle;
-}
-
 XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
                                       XlaBuilder* builder) {
   XlaOp data_handle;
@@ -709,59 +559,10 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
   return data_handle;
 }
 
-ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral(
-    const Literal& literal, ComputationBuilder* builder) {
-  return builder->ConstantLiteral(
-      use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal);
-}
-
 XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
                                                        XlaBuilder* builder) {
   return builder->ConstantLiteral(
       use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal);
 }
 
-template void ClientLibraryTestBase::ComputeAndCompareLiteral(
-    ComputationBuilder* builder, const Literal& expected,
-    tensorflow::gtl::ArraySlice<GlobalData*> arguments,
-    const Shape* shape_with_layout);
-
-template void ClientLibraryTestBase::ComputeAndCompareLiteral(
-    XlaBuilder* builder, const Literal& expected,
-    tensorflow::gtl::ArraySlice<GlobalData*> arguments,
-    const Shape* shape_with_layout);
-
-template void ClientLibraryTestBase::ComputeAndCompareLiteral(
-    ComputationBuilder* builder, const Literal& expected,
-    tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
-    const Shape* shape_with_layout);
-
-template void ClientLibraryTestBase::ComputeAndCompareLiteral(
-    XlaBuilder* builder, const Literal& expected,
-    tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
-    const Shape* shape_with_layout);
-
-template void ClientLibraryTestBase::ComputeAndCompareTuple(
-    ComputationBuilder* builder, const Literal& expected,
-    tensorflow::gtl::ArraySlice<GlobalData*> arguments);
-
-template void ClientLibraryTestBase::ComputeAndCompareTuple(
-    XlaBuilder* builder, const Literal& expected,
-    tensorflow::gtl::ArraySlice<GlobalData*> arguments);
-
-template void ClientLibraryTestBase::ComputeAndCompareTuple(
-    ComputationBuilder* builder, const Literal& expected,
-    tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
-
-template void ClientLibraryTestBase::ComputeAndCompareTuple(
-    XlaBuilder* builder, const Literal& expected,
-    tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
-
-template StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
-    ComputationBuilder* builder,
-    tensorflow::gtl::ArraySlice<GlobalData*> arguments);
-
-template StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
-    XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments);
-
 }  // namespace xla
index 32eea7c..e58979a 100644 (file)
@@ -25,10 +25,9 @@ limitations under the License.
 #include "tensorflow/compiler/xla/array3d.h"
 #include "tensorflow/compiler/xla/array4d.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
 #include "tensorflow/compiler/xla/client/global_data.h"
 #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -91,21 +90,11 @@ class ClientLibraryTestBase : public ::testing::Test {
   // Convenience methods for building and running a computation with the member
   // execution options. Modify execution_options_ in your test if you want to
   // customize the options.
-  template <typename BuilderT>
   StatusOr<std::unique_ptr<GlobalData>> Execute(
-      BuilderT* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+      XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments);
 
-  // TODO(b/74197823): Remove the template type 'BuilderT' in all methods once
-  // the migration to XlaBuilder is complete.
-
-  template <typename BuilderT>
   StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
-      BuilderT* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments,
-      const Shape* shape_with_output_layout = nullptr);
-
-  StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
-      const Computation& computation,
-      tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+      XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments,
       const Shape* shape_with_output_layout = nullptr);
 
   StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
@@ -121,101 +110,90 @@ class ClientLibraryTestBase : public ::testing::Test {
       tensorflow::gtl::ArraySlice<GlobalData*> arguments,
       const Shape* shape_with_output_layout = nullptr);
 
-  // Convenience OrDie variants of above methods.
-  std::unique_ptr<GlobalData> ExecuteOrDie(
-      ComputationBuilder* builder,
-      tensorflow::gtl::ArraySlice<GlobalData*> arguments);
-  std::unique_ptr<Literal> ExecuteAndTransferOrDie(
-      ComputationBuilder* builder,
-      tensorflow::gtl::ArraySlice<GlobalData*> arguments);
-
   // Run a computation and return its value as a string. If an error
   // occurs, then instead return the error as a string.
   string ExecuteToString(XlaBuilder* builder,
                          tensorflow::gtl::ArraySlice<GlobalData*> arguments);
-  string ExecuteToString(ComputationBuilder* builder,
-                         tensorflow::gtl::ArraySlice<GlobalData*> arguments);
 
   // Convenience methods for building and running a computation, transferring
   // the result, and comparing it to the expected value(s). Methods are
   // templated on the native host type which maps to specific XLA types (See
-  // ComputationBuilder/XlaBuilder for details). For each rank, two forms are
+  // XlaBuilder for details). For each rank, two forms are
   // provided: one for floating point types with an ErrorSpec parameter, and one
   // for integral types without the ErrorSpec parameter.
-  template <typename NativeT, typename BuilderT>
-  void ComputeAndCompareR0(BuilderT* builder, NativeT expected,
+  template <typename NativeT>
+  void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
                            tensorflow::gtl::ArraySlice<GlobalData*> arguments);
-  template <typename NativeT, typename BuilderT>
-  void ComputeAndCompareR0(BuilderT* builder, NativeT expected,
+  template <typename NativeT>
+  void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
                            tensorflow::gtl::ArraySlice<GlobalData*> arguments,
                            ErrorSpec error);
 
-  template <typename NativeT, typename BuilderT>
-  void ComputeAndCompareR1(BuilderT* builder,
+  template <typename NativeT>
+  void ComputeAndCompareR1(XlaBuilder* builder,
                            tensorflow::gtl::ArraySlice<NativeT> expected,
                            tensorflow::gtl::ArraySlice<GlobalData*> arguments);
-  template <typename NativeT, typename BuilderT>
-  void ComputeAndCompareR1(BuilderT* builder,
+  template <typename NativeT>
+  void ComputeAndCompareR1(XlaBuilder* builder,
                            tensorflow::gtl::ArraySlice<NativeT> expected,
                            tensorflow::gtl::ArraySlice<GlobalData*> arguments,
                            ErrorSpec error);
 
   // As above, but uses a bitmap to hold the predicate vector to avoid
   // deficiencies of vector<bool>.
-  void ComputeAndCompareR1(ComputationBuilder* builder,
-                           const tensorflow::core::Bitmap& expected,
-                           tensorflow::gtl::ArraySlice<GlobalData*> arguments);
   void ComputeAndCompareR1(XlaBuilder* builder,
                            const tensorflow::core::Bitmap& expected,
                            tensorflow::gtl::ArraySlice<GlobalData*> arguments);
 
-  template <typename NativeT, typename BuilderT>
-  void ComputeAndCompareR2(BuilderT* builder, const Array2D<NativeT>& expected,
+  template <typename NativeT>
+  void ComputeAndCompareR2(XlaBuilder* builder,
+                           const Array2D<NativeT>& expected,
                            tensorflow::gtl::ArraySlice<GlobalData*> arguments);
-  template <typename NativeT, typename BuilderT>
-  void ComputeAndCompareR2(BuilderT* builder, const Array2D<NativeT>& expected,
+  template <typename NativeT>
+  void ComputeAndCompareR2(XlaBuilder* builder,
+                           const Array2D<NativeT>& expected,
                            tensorflow::gtl::ArraySlice<GlobalData*> arguments,
                            ErrorSpec error);
 
-  template <typename NativeT, typename BuilderT>
-  void ComputeAndCompareR3(BuilderT* builder, const Array3D<NativeT>& expected,
+  template <typename NativeT>
+  void ComputeAndCompareR3(XlaBuilder* builder,
+                           const Array3D<NativeT>& expected,
                            tensorflow::gtl::ArraySlice<GlobalData*> arguments);
-  template <typename NativeT, typename BuilderT>
-  void ComputeAndCompareR3(BuilderT* builder, const Array3D<NativeT>& expected,
+  template <typename NativeT>
+  void ComputeAndCompareR3(XlaBuilder* builder,
+                           const Array3D<NativeT>& expected,
                            tensorflow::gtl::ArraySlice<GlobalData*> arguments,
                            ErrorSpec error);
 
-  template <typename NativeT, typename BuilderT>
-  void ComputeAndCompareR4(BuilderT* builder, const Array4D<NativeT>& expected,
+  template <typename NativeT>
+  void ComputeAndCompareR4(XlaBuilder* builder,
+                           const Array4D<NativeT>& expected,
                            tensorflow::gtl::ArraySlice<GlobalData*> arguments);
-  template <typename NativeT, typename BuilderT>
-  void ComputeAndCompareR4(BuilderT* builder, const Array4D<NativeT>& expected,
+  template <typename NativeT>
+  void ComputeAndCompareR4(XlaBuilder* builder,
+                           const Array4D<NativeT>& expected,
                            tensorflow::gtl::ArraySlice<GlobalData*> arguments,
                            ErrorSpec error);
 
   // Build and run the computation and compare the result with the given
   // literal. shape_with_layout indicates the result layout to request when
   // calling Execute.
-  template <typename BuilderT>
   void ComputeAndCompareLiteral(
-      BuilderT* builder, const Literal& expected,
+      XlaBuilder* builder, const Literal& expected,
       tensorflow::gtl::ArraySlice<GlobalData*> arguments,
       const Shape* shape_with_layout = nullptr);
-  template <typename BuilderT>
   void ComputeAndCompareLiteral(
-      BuilderT* builder, const Literal& expected,
+      XlaBuilder* builder, const Literal& expected,
       tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
       const Shape* shape_with_layout = nullptr);
 
   // ComputeAndCompare variant which returns an error status.
-  template <typename BuilderT>
   tensorflow::Status ComputeAndCompareLiteralWithStatus(
-      BuilderT* builder, const Literal& expected,
+      XlaBuilder* builder, const Literal& expected,
       tensorflow::gtl::ArraySlice<GlobalData*> arguments,
       const Shape* shape_with_layout = nullptr);
-  template <typename BuilderT>
   tensorflow::Status ComputeAndCompareLiteralWithStatus(
-      BuilderT* builder, const Literal& expected,
+      XlaBuilder* builder, const Literal& expected,
       tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
       const Shape* shape_with_layout = nullptr);
 
@@ -227,26 +205,14 @@ class ClientLibraryTestBase : public ::testing::Test {
 
   // Convenience method for running a built computation, transferring the
   // result, and comparing it to the expected tuple literal.
-  template <typename BuilderT>
   void ComputeAndCompareTuple(
-      BuilderT* builder, const Literal& expected,
+      XlaBuilder* builder, const Literal& expected,
       tensorflow::gtl::ArraySlice<GlobalData*> arguments);
-  template <typename BuilderT>
   void ComputeAndCompareTuple(
-      BuilderT* builder, const Literal& expected,
+      XlaBuilder* builder, const Literal& expected,
       tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
 
   // Convenience method for running a built computation and comparing the result
-  // with the HloEvaluator.
-  void ComputeAndCompare(ComputationBuilder* builder,
-                         const ComputationDataHandle& operand,
-                         tensorflow::gtl::ArraySlice<Literal> arguments);
-  void ComputeAndCompare(ComputationBuilder* builder,
-                         const ComputationDataHandle& operand,
-                         tensorflow::gtl::ArraySlice<Literal> arguments,
-                         ErrorSpec error);
-
-  // Convenience method for running a built computation and comparing the result
   // with the reference result.
   void ComputeAndCompare(XlaBuilder* builder,
                          tensorflow::gtl::ArraySlice<Literal> arguments);
@@ -257,7 +223,7 @@ class ClientLibraryTestBase : public ::testing::Test {
   // Create scalar operations for use in reductions.
   XlaComputation CreateScalarRelu();
   XlaComputation CreateScalarMax();
-  Computation CreateScalarReluSensitivity();
+  XlaComputation CreateScalarReluSensitivity();
 
   // Special case convenience functions for creating filled arrays.
 
@@ -297,35 +263,26 @@ class ClientLibraryTestBase : public ::testing::Test {
   // server, then stores into "data_handle" the global handle for that
   // parameter. When the use_bfloat16 flag is set but the literal has F32
   // elements, the literal will be converted to BF16 before being transferred.
-  template <typename BuilderT, typename HandleT>
   std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral(
       int64 parameter_number, const Literal& literal, const string& name,
-      BuilderT* builder, HandleT* data_handle);
+      XlaBuilder* builder, XlaOp* data_handle);
 
   // As above, but the caller can specify the device that the literal is
   // transferred to. If device_handle is nullptr, the literal will be
   // transferred to the default device.
-  template <typename BuilderT, typename HandleT>
   std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral(
       int64 parameter_number, const Literal& literal, const string& name,
-      const DeviceHandle* device_handle, BuilderT* builder,
-      HandleT* data_handle);
+      const DeviceHandle* device_handle, XlaBuilder* builder,
+      XlaOp* data_handle);
 
   // Creates a parameter instruction and sets the value that will be passed to
   // the computation as specified. This function must be used for all parameters
   // or none and no parameters must be passed when invoking the computation if
   // using this mechanism. If using this mechanism, then each parameter must be
   // set exactly once. The first added parameter gets index 0, then 1 and so on.
-  ComputationDataHandle AddParam(const Literal& argument,
-                                 ComputationBuilder* builder);
   XlaOp AddParam(const Literal& argument, XlaBuilder* builder);
 
   template <class T>
-  ComputationDataHandle AddParam(const Array<T>& argument,
-                                 ComputationBuilder* builder) {
-    return AddParam(*Literal::CreateFromArray(argument), builder);
-  }
-  template <class T>
   XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) {
     return AddParam(*Literal::CreateFromArray(argument), builder);
   }
@@ -333,18 +290,11 @@ class ClientLibraryTestBase : public ::testing::Test {
   // Creates a constant instruction with the given literal. When the
   // use_bfloat16 flag is set but the literal has F32 elements, the elements
   // will be converted to BF16s.
-  ComputationDataHandle CreateConstantFromLiteral(const Literal& literal,
-                                                  ComputationBuilder* builder);
   XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder);
 
   // Creates a constant instruction with the given array. When the use_bfloat16
   // flag is set but the array has float elements, the elements will be
   // converted to bfloat16s.
-  template <typename NativeT>
-  ComputationDataHandle CreateConstantFromArray(const Array<NativeT>& array,
-                                                ComputationBuilder* builder) {
-    return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder);
-  }
 
   template <typename NativeT>
   XlaOp CreateConstantFromArray(const Array<NativeT>& array,
@@ -354,13 +304,6 @@ class ClientLibraryTestBase : public ::testing::Test {
 
   // Same as CreateConstantFromArray, but for scalars.
   template <typename NativeT>
-  ComputationDataHandle CreateConstantFromScalar(NativeT value,
-                                                 ComputationBuilder* builder) {
-    return CreateConstantFromLiteral(*Literal::CreateR0<NativeT>(value),
-                                     builder);
-  }
-
-  template <typename NativeT>
   XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) {
     return CreateConstantFromLiteral(*Literal::CreateR0<NativeT>(value),
                                      builder);
@@ -374,12 +317,12 @@ class ClientLibraryTestBase : public ::testing::Test {
   //
   // When the use_bfloat16 flag is set but NativeT is float, the data will be
   // converted to bfloat16.
-  template <typename NativeT, typename BuilderT, typename HandleT>
+  template <typename NativeT>
   std::unique_ptr<GlobalData> CreateR0Parameter(NativeT value,
                                                 int64 parameter_number,
                                                 const string& name,
-                                                BuilderT* builder,
-                                                HandleT* data_handle);
+                                                XlaBuilder* builder,
+                                                XlaOp* data_handle);
 
   // Creates a parameter instruction that wraps the given values and then stores
   // into "data_handle" the global handle for that parameter.
@@ -389,10 +332,10 @@ class ClientLibraryTestBase : public ::testing::Test {
   //
   // When the use_bfloat16 flag is set but NativeT is float, the data will be
   // converted to bfloat16.
-  template <typename NativeT, typename BuilderT, typename HandleT>
+  template <typename NativeT>
   std::unique_ptr<GlobalData> CreateR1Parameter(
       tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
-      const string& name, BuilderT* builder, HandleT* data_handle);
+      const string& name, XlaBuilder* builder, XlaOp* data_handle);
 
   // Creates a parameter instruction that wraps the given constant array
   // "array_2d" and then stores to "data_handle" the global handle for that
@@ -403,10 +346,10 @@ class ClientLibraryTestBase : public ::testing::Test {
   //
   // When the use_bfloat16 flag is set but NativeT is float, the data will be
   // converted to bfloat16.
-  template <typename NativeT, typename BuilderT, typename HandleT>
+  template <typename NativeT>
   std::unique_ptr<GlobalData> CreateR2Parameter(
       const Array2D<NativeT>& array_2d, int64 parameter_number,
-      const string& name, BuilderT* builder, HandleT* data_handle);
+      const string& name, XlaBuilder* builder, XlaOp* data_handle);
 
   // Creates a parameter instruction that wraps the given constant array
   // "array_3d" and then stores to "data_handle" the global handle for that
@@ -417,10 +360,10 @@ class ClientLibraryTestBase : public ::testing::Test {
   //
   // When the use_bfloat16 flag is set but NativeT is float, the data will be
   // converted to bfloat16.
-  template <typename NativeT, typename BuilderT, typename HandleT>
+  template <typename NativeT>
   std::unique_ptr<GlobalData> CreateR3Parameter(
       const Array3D<NativeT>& array_3d, int64 parameter_number,
-      const string& name, BuilderT* builder, HandleT* data_handle);
+      const string& name, XlaBuilder* builder, XlaOp* data_handle);
 
   // Getter and setter for the use_bfloat16 flag, which indicates whether to run
   // tests with all float-type input/output converted to bfloat16.
@@ -435,21 +378,6 @@ class ClientLibraryTestBase : public ::testing::Test {
   ExecutionOptions execution_options_;
 
  private:
-  // Build and run the computation with all permutations of output layouts.
-  tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts(
-      const xla::Computation& computation, const Literal& expected,
-      tensorflow::gtl::ArraySlice<GlobalData*> arguments,
-      const std::function<void(const Literal& actual,
-                               const string& error_message)>& verify_output);
-  // Build and run the computation with all permutations of layouts of all input
-  // arguments.
-  tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts(
-      const xla::Computation& computation, const Literal& expected,
-      tensorflow::gtl::ArraySlice<GlobalData*> arguments,
-      const std::function<void(const Literal& actual,
-                               const string& error_message)>& verify_output,
-      const Shape* output_with_layout = nullptr);
-
   tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts(
       const xla::XlaComputation& computation, const Literal& expected,
       tensorflow::gtl::ArraySlice<GlobalData*> arguments,
@@ -463,13 +391,6 @@ class ClientLibraryTestBase : public ::testing::Test {
       const Shape* output_with_layout = nullptr);
 
   // Executes the computation and calculates the expected reference value using
-  // the HloEvaluator. Returns two literals in the order of (expected, actual).
-  StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
-  ComputeValueAndReference(ComputationBuilder* builder,
-                           const ComputationDataHandle& operand,
-                           tensorflow::gtl::ArraySlice<Literal> arguments);
-
-  // Executes the computation and calculates the expected reference value using
   // the reference client. Returns two literals in the order of (expected,
   // actual).
   StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
@@ -484,9 +405,9 @@ class ClientLibraryTestBase : public ::testing::Test {
   std::vector<std::unique_ptr<GlobalData>> arguments_;
 };
 
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
 void ClientLibraryTestBase::ComputeAndCompareR0(
-    BuilderT* builder, NativeT expected,
+    XlaBuilder* builder, NativeT expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
   std::unique_ptr<Literal> expected_literal =
       Literal::CreateR0<NativeT>(expected);
@@ -494,9 +415,9 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
                                                   arguments);
 }
 
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
 void ClientLibraryTestBase::ComputeAndCompareR0(
-    BuilderT* builder, NativeT expected,
+    XlaBuilder* builder, NativeT expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
   static_assert(std::is_same<NativeT, float>::value ||
                     std::is_same<NativeT, double>::value ||
@@ -510,9 +431,9 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
                                                   arguments, error);
 }
 
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
 void ClientLibraryTestBase::ComputeAndCompareR1(
-    BuilderT* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
+    XlaBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
   std::unique_ptr<Literal> expected_literal =
       Literal::CreateR1<NativeT>(expected);
@@ -520,9 +441,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
                                                   arguments);
 }
 
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
 void ClientLibraryTestBase::ComputeAndCompareR1(
-    BuilderT* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
+    XlaBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
   static_assert(std::is_same<NativeT, float>::value ||
                     std::is_same<NativeT, double>::value ||
@@ -536,9 +457,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
                                                   arguments, error);
 }
 
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
 void ClientLibraryTestBase::ComputeAndCompareR2(
-    BuilderT* builder, const Array2D<NativeT>& expected,
+    XlaBuilder* builder, const Array2D<NativeT>& expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
   std::unique_ptr<Literal> expected_literal =
       Literal::CreateR2FromArray2D<NativeT>(expected);
@@ -546,9 +467,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
                                                   arguments);
 }
 
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
 void ClientLibraryTestBase::ComputeAndCompareR2(
-    BuilderT* builder, const Array2D<NativeT>& expected,
+    XlaBuilder* builder, const Array2D<NativeT>& expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
   static_assert(std::is_same<NativeT, float>::value ||
                     std::is_same<NativeT, double>::value ||
@@ -562,9 +483,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
                                                   arguments, error);
 }
 
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
 void ClientLibraryTestBase::ComputeAndCompareR3(
-    BuilderT* builder, const Array3D<NativeT>& expected,
+    XlaBuilder* builder, const Array3D<NativeT>& expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
   std::unique_ptr<Literal> expected_literal =
       Literal::CreateR3FromArray3D<NativeT>(expected);
@@ -572,9 +493,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
                                                   arguments);
 }
 
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
 void ClientLibraryTestBase::ComputeAndCompareR3(
-    BuilderT* builder, const Array3D<NativeT>& expected,
+    XlaBuilder* builder, const Array3D<NativeT>& expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
   static_assert(std::is_same<NativeT, float>::value ||
                     std::is_same<NativeT, double>::value ||
@@ -588,9 +509,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
                                                   arguments, error);
 }
 
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
 void ClientLibraryTestBase::ComputeAndCompareR4(
-    BuilderT* builder, const Array4D<NativeT>& expected,
+    XlaBuilder* builder, const Array4D<NativeT>& expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
   std::unique_ptr<Literal> expected_literal =
       Literal::CreateR4FromArray4D<NativeT>(expected);
@@ -598,9 +519,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
                                                   arguments);
 }
 
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
 void ClientLibraryTestBase::ComputeAndCompareR4(
-    BuilderT* builder, const Array4D<NativeT>& expected,
+    XlaBuilder* builder, const Array4D<NativeT>& expected,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
   static_assert(std::is_same<NativeT, float>::value ||
                     std::is_same<NativeT, double>::value ||
@@ -614,10 +535,10 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
                                                   arguments, error);
 }
 
-template <typename NativeT, typename BuilderT, typename HandleT>
+template <typename NativeT>
 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
     NativeT value, int64 parameter_number, const string& name,
-    BuilderT* builder, HandleT* data_handle) {
+    XlaBuilder* builder, XlaOp* data_handle) {
   std::unique_ptr<Literal> literal = Literal::CreateR0(value);
   if (use_bfloat16_ && literal->shape().element_type() == F32) {
     literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
@@ -628,10 +549,10 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
   return data;
 }
 
-template <typename NativeT, typename BuilderT, typename HandleT>
+template <typename NativeT>
 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
     tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
-    const string& name, BuilderT* builder, HandleT* data_handle) {
+    const string& name, XlaBuilder* builder, XlaOp* data_handle) {
   std::unique_ptr<Literal> literal = Literal::CreateR1(values);
   if (use_bfloat16_ && literal->shape().element_type() == F32) {
     literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
@@ -642,10 +563,10 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
   return data;
 }
 
-template <typename NativeT, typename BuilderT, typename HandleT>
+template <typename NativeT>
 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
     const Array2D<NativeT>& array_2d, int64 parameter_number,
-    const string& name, BuilderT* builder, HandleT* data_handle) {
+    const string& name, XlaBuilder* builder, XlaOp* data_handle) {
   std::unique_ptr<Literal> literal = Literal::CreateR2FromArray2D(array_2d);
   if (use_bfloat16_ && literal->shape().element_type() == F32) {
     literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
@@ -656,10 +577,10 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
   return data;
 }
 
-template <typename NativeT, typename BuilderT, typename HandleT>
+template <typename NativeT>
 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
     const Array3D<NativeT>& array_3d, int64 parameter_number,
-    const string& name, BuilderT* builder, HandleT* data_handle) {
+    const string& name, XlaBuilder* builder, XlaOp* data_handle) {
   std::unique_ptr<Literal> literal = Literal::CreateR3FromArray3D(array_3d);
   if (use_bfloat16_ && literal->shape().element_type() == F32) {
     literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
@@ -695,23 +616,21 @@ std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
   return result;
 }
 
-template <typename BuilderT, typename HandleT>
 std::unique_ptr<GlobalData>
 ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number,
                                                          const Literal& literal,
                                                          const string& name,
-                                                         BuilderT* builder,
-                                                         HandleT* data_handle) {
+                                                         XlaBuilder* builder,
+                                                         XlaOp* data_handle) {
   return CreateParameterAndTransferLiteral(parameter_number, literal, name,
                                            nullptr, builder, data_handle);
 }
 
-template <typename BuilderT, typename HandleT>
 std::unique_ptr<GlobalData>
 ClientLibraryTestBase::CreateParameterAndTransferLiteral(
     int64 parameter_number, const Literal& literal, const string& name,
-    const DeviceHandle* device_handle, BuilderT* builder,
-    HandleT* data_handle) {
+    const DeviceHandle* device_handle, XlaBuilder* builder,
+    XlaOp* data_handle) {
   const Literal* param_literal = &literal;
   std::unique_ptr<Literal> converted_literal;
   if (use_bfloat16_) {