From 213a98d893105945540e0169faa124ac7e1200ba Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 May 2018 17:03:03 -0700 Subject: [PATCH] [XLA] Redesign: deprecate ComputationBuilder. PiperOrigin-RevId: 195335330 --- tensorflow/compiler/xla/client/computation.h | 2 + .../compiler/xla/client/computation_builder.h | 2 + tensorflow/compiler/xla/client/lib/BUILD | 5 +- tensorflow/compiler/xla/client/lib/arithmetic.cc | 90 +------------------- tensorflow/compiler/xla/client/lib/arithmetic.h | 55 ++---------- tensorflow/compiler/xla/client/lib/testing.cc | 16 ++-- tensorflow/compiler/xla/client/lib/testing.h | 1 - tensorflow/compiler/xla/service/BUILD | 10 +-- tensorflow/compiler/xla/service/cpu/BUILD | 4 +- .../compiler/xla/service/cpu/sample_harness.cc | 10 +-- .../compiler/xla/service/hlo_cost_analysis_test.cc | 73 +++++++--------- .../compiler/xla/service/hlo_evaluator_test.cc | 10 +-- .../xla/service/hlo_tfgraph_builder_test.cc | 1 - .../compiler/xla/service/transpose_folding_test.cc | 10 +-- .../xla/service/zero_sized_hlo_elimination_test.cc | 1 - tensorflow/compiler/xla/tests/BUILD | 18 ---- .../xla/tests/local_client_aot_test_helper.cc | 19 +++-- .../compiler/xla/tests/set_return_value_test.cc | 98 ---------------------- .../compiler/xla/tests/vector_ops_simple_test.cc | 3 +- 19 files changed, 85 insertions(+), 343 deletions(-) delete mode 100644 tensorflow/compiler/xla/tests/set_return_value_test.cc diff --git a/tensorflow/compiler/xla/client/computation.h b/tensorflow/compiler/xla/client/computation.h index a53fc9e..9a1bcde 100644 --- a/tensorflow/compiler/xla/client/computation.h +++ b/tensorflow/compiler/xla/client/computation.h @@ -30,6 +30,8 @@ namespace xla { // Wraps a ComputationHandle protobuf with a lifetime. Computation is // movable and not copyable to capture the same kind of unique // ownership that std::unique_ptr represents. +// +// TODO(b/74197823): Deprecated. Use XlaComputation instead. class Computation { public: // Creates a null Computation. diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 9431c2c..ac1eb91 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -48,6 +48,8 @@ namespace xla { // deferred from being handled until Build() is called. // // Thread-compatible. +// +// TODO(b/74197823): Deprecated. Use XlaBuilder instead. class ComputationBuilder { public: // client: client in which to build the computation. diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 59c4a53..d49d959 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -22,8 +22,6 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", @@ -43,9 +41,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 63df449..a1d3479 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -17,7 +17,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -27,28 +28,6 @@ limitations under the License. namespace xla { namespace { -using InstructionGenerator = - ComputationDataHandle (*)(ComputationBuilder*, const ComputationDataHandle&, - const ComputationDataHandle&); - -Computation CreateScalarComputation(const string& name, PrimitiveType type, - ComputationBuilder* builder, - InstructionGenerator generator) { - std::unique_ptr b; - if (type == PRED) { - b = builder->CreateSubBuilder(name); - } else { - b = builder->CreateSubBuilder( - tensorflow::strings::StrCat(name, "_", PrimitiveType_Name(type))); - } - - const Shape scalar = ShapeUtil::MakeShape(type, {}); - auto lhs = b->Parameter(0, scalar, "lhs"); - auto rhs = b->Parameter(1, scalar, "rhs"); - generator(b.get(), lhs, rhs); - return b->BuildAndNoteError(); -} - using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&); XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, @@ -71,71 +50,6 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, } // namespace -Computation CreateScalarAddComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "add", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Add(lhs, rhs); }); -} - -Computation CreateScalarMultiplyComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "mul", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Mul(lhs, rhs); }); -} - -Computation CreateScalarGeComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "ge", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Ge(lhs, rhs); }); -} - -Computation CreateScalarMaxComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "max", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Max(lhs, rhs); }); -} - -Computation CreateScalarMinComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "min", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Min(lhs, rhs); }); -} - -Computation CreateScalarAndComputation(ComputationBuilder* builder) { - return CreateScalarComputation( - "and", PRED, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->And(lhs, rhs); }); -} - -Computation CreateScalarOrComputation(ComputationBuilder* builder) { - return CreateScalarComputation( - "or", PRED, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Or(lhs, rhs); }); -} - -StatusOr Any(const ComputationDataHandle& predicates, - ComputationBuilder* builder) { - auto f = builder->ConstantR0(false); - Computation logical_or = CreateScalarOrComputation(builder); - TF_ASSIGN_OR_RETURN(std::unique_ptr predicates_shape, - builder->GetShape(predicates)); - std::vector all_dimensions(ShapeUtil::Rank(*predicates_shape)); - std::iota(all_dimensions.begin(), all_dimensions.end(), 0); - return builder->Reduce(predicates, f, logical_or, all_dimensions); -} - XlaComputation CreateScalarAddComputation(PrimitiveType type, XlaBuilder* builder) { return CreateScalarComputation( diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index f4d3fc8..64b6b7d 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -18,8 +18,6 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -27,74 +25,31 @@ limitations under the License. namespace xla { // Creates a scalar add computation and returns it. -Computation CreateScalarAddComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar multiply computation and returns it. -Computation CreateScalarMultiplyComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar ge computation and returns it. -Computation CreateScalarGeComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar max computation and returns it. -Computation CreateScalarMaxComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar min computation and returns it. -Computation CreateScalarMinComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar logical AND computation and returns it. -Computation CreateScalarAndComputation(ComputationBuilder* builder); - -// Creates a scalar logical OR computation and returns it. -Computation CreateScalarOrComputation(ComputationBuilder* builder); - -// Returns whether any predicate in "predicates" is set. -// -// Note: if predicates is zero-sized, Any() vacuously returns false. -StatusOr Any(const ComputationDataHandle& predicates, - ComputationBuilder* builder); - -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// -// Creates a scalar add computation and returns it. XlaComputation CreateScalarAddComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar multiply computation and returns it. XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar ge computation and returns it. XlaComputation CreateScalarGeComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar max computation and returns it. XlaComputation CreateScalarMaxComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar min computation and returns it. XlaComputation CreateScalarMinComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar logical AND computation and returns it. XlaComputation CreateScalarAndComputation(XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// // Creates a scalar logical OR computation and returns it. XlaComputation CreateScalarOrComputation(XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// // Returns whether any predicate in "predicates" is set. // // Note: if predicates is zero-sized, Any() vacuously returns false. diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 311dc4b..9cd87f7 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -15,8 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/testing.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -46,16 +45,14 @@ int64 DataSizeOfShape(const Shape& shape) { return total_size; } -// Create a ComputationDataHandle for an op what generates fake data with the -// given shape. -ComputationDataHandle BuildFakeDataOpOnDevice(const Shape& shape, - ComputationBuilder* builder) { +// Creates a XlaOp for an op what generates fake data with the given shape. +XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { if (ShapeUtil::IsArray(shape)) { return builder->Broadcast( builder->ConstantLiteral(Literal::One(shape.element_type())), AsInt64Slice(shape.dimensions())); } - std::vector parts; + std::vector parts; for (const Shape& s : shape.tuple_shapes()) { parts.push_back(BuildFakeDataOpOnDevice(s, builder)); } @@ -64,11 +61,10 @@ ComputationDataHandle BuildFakeDataOpOnDevice(const Shape& shape, std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, Client* client) { - ComputationBuilder b( - client, + XlaBuilder b( tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape))); BuildFakeDataOpOnDevice(shape, &b); - Computation computation = b.Build().ConsumeValueOrDie(); + XlaComputation computation = b.Build().ConsumeValueOrDie(); auto execution_options = CreateDefaultExecutionOptions(); *execution_options.mutable_shape_with_output_layout() = shape; diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index 1dc2622..9e06141 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -20,7 +20,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 0b8b22b..9c362d8 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -233,7 +233,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1669,10 +1669,10 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -2406,7 +2406,6 @@ tf_cc_test( srcs = ["hlo_tfgraph_builder_test.cc"], deps = [ ":hlo_tfgraph_builder", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", @@ -2475,7 +2474,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -2512,6 +2511,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index cb81e41..7e6d58c 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -365,10 +365,10 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index b3f4609..167aa4a 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -19,10 +19,10 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -48,13 +48,13 @@ int main(int argc, char** argv) { client->TransferToServer(*param1_literal).ConsumeValueOrDie(); // Build computation. - xla::ComputationBuilder builder(client, ""); + xla::XlaBuilder builder(""); auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); auto add = builder.Add(p1, p0, {0}); - xla::StatusOr computation_status = builder.Build(); - xla::Computation computation = computation_status.ConsumeValueOrDie(); + xla::StatusOr computation_status = builder.Build(); + xla::XlaComputation computation = computation_status.ConsumeValueOrDie(); // Execute and transfer result of computation. xla::ExecutionProfile profile; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 81cc7c4..16fdda8 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -20,16 +20,13 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/logging.h" @@ -58,11 +55,10 @@ class HloCostAnalysisTest : public ::testing::Test { // whitebox accesses to the user computation built from the client, // as shown in the BuildHloGraph functions below. service_(static_cast(ClientLibrary::GetXlaService( - static_cast(client_)->platform()))), - computation_tracker_(service_->computation_tracker()) { + static_cast(client_)->platform()))) { // Create a computation for a unary user function: x => exp(x + 0.5) { - ComputationBuilder builder(client_, "add_and_exp"); + XlaBuilder builder("add_and_exp"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto half = builder.ConstantR0(0.5); builder.Exp(builder.Add(x, half)); @@ -73,7 +69,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary user function: (x, y) => x + y { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -84,7 +80,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a sigmoid function: x => 1 / (1 + exp(-x)) { - ComputationBuilder builder(client_, "sigmoid"); + XlaBuilder builder("sigmoid"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto one = builder.ConstantR0(1.0); builder.Div(one, builder.Add(one, builder.Exp(builder.Neg(x)))); @@ -95,7 +91,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary max function: (x, y) => max (x, y) { - ComputationBuilder builder(client_, "max"); + XlaBuilder builder("max"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Max(x, y); @@ -106,7 +102,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary GT function: (x, y) => x > y { - ComputationBuilder builder(client_, "gt"); + XlaBuilder builder("gt"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Gt(x, y); @@ -117,35 +113,30 @@ class HloCostAnalysisTest : public ::testing::Test { } // Build HLO graph from the given builder and return the HLO module. - std::unique_ptr BuildHloGraph(ComputationBuilder* builder) { + std::unique_ptr BuildHloGraph(XlaBuilder* builder) { auto computation_status = builder->Build(); TF_CHECK_OK(computation_status.status()); auto computation = computation_status.ConsumeValueOrDie(); - auto user_computation_status = - computation_tracker_.Resolve(computation.handle()); - TF_CHECK_OK(user_computation_status.status()); - auto user_computation = user_computation_status.ConsumeValueOrDie(); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - return std::move( - computation_tracker_.BuildHloModule(versioned_handle, HloModuleConfig()) - .ValueOrDie()); + auto config = HloModule::CreateModuleConfigFromProto(computation.proto(), + DebugOptions()) + .ConsumeValueOrDie(); + return HloModule::CreateFromProto(computation.proto(), config) + .ConsumeValueOrDie(); } Client* client_; Service* service_; - const ComputationTracker& computation_tracker_; // User computations used for higher order operations (e.g., Map, Reduce). - Computation add_; - Computation add_and_exp_; - Computation sigmoid_; - Computation max_; - Computation gt_; + XlaComputation add_; + XlaComputation add_and_exp_; + XlaComputation sigmoid_; + XlaComputation max_; + XlaComputation gt_; }; TEST_F(HloCostAnalysisTest, MatrixMultiply) { - ComputationBuilder builder(client_, "matrix_multiply"); + XlaBuilder builder("matrix_multiply"); auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs"); auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs"); auto result = builder.Dot(lhs, rhs); @@ -167,7 +158,7 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) { } TEST_F(HloCostAnalysisTest, Map) { - ComputationBuilder builder(client_, "map"); + XlaBuilder builder("map"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in"); auto result = builder.Map({input}, add_and_exp_, {0}); @@ -184,7 +175,7 @@ TEST_F(HloCostAnalysisTest, Map) { } TEST_F(HloCostAnalysisTest, Convolution) { - ComputationBuilder builder(client_, "convolution"); + XlaBuilder builder("convolution"); auto input = builder.Parameter( 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, @@ -213,7 +204,7 @@ TEST_F(HloCostAnalysisTest, Convolution) { } TEST_F(HloCostAnalysisTest, Reduce) { - ComputationBuilder builder(client_, "reduce"); + XlaBuilder builder("reduce"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); auto result = @@ -231,7 +222,7 @@ TEST_F(HloCostAnalysisTest, Reduce) { } TEST_F(HloCostAnalysisTest, ReduceWindow) { - ComputationBuilder builder(client_, "reduce_window"); + XlaBuilder builder("reduce_window"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); auto result = builder.ReduceWindow(input, builder.ConstantR0(0), add_, @@ -248,7 +239,7 @@ TEST_F(HloCostAnalysisTest, ReduceWindow) { } TEST_F(HloCostAnalysisTest, SelectAndScatter) { - ComputationBuilder builder(client_, "select_and_scatter"); + XlaBuilder builder("select_and_scatter"); auto operand = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); auto source = @@ -269,7 +260,7 @@ TEST_F(HloCostAnalysisTest, SelectAndScatter) { } TEST_F(HloCostAnalysisTest, Broadcast) { - ComputationBuilder b(client_, "broadcast"); + XlaBuilder b("broadcast"); b.Broadcast(b.ConstantR0(42), {10, 7}); auto hlo_module = BuildHloGraph(&b); HloCostAnalysis analysis(ShapeSize); @@ -280,7 +271,7 @@ TEST_F(HloCostAnalysisTest, Broadcast) { // Calculates the computation cost of a graph with more than one HLO node. TEST_F(HloCostAnalysisTest, FullyConnectedForward) { - ComputationBuilder builder(client_, "fully_connected_forward"); + XlaBuilder builder("fully_connected_forward"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "input"); auto weight = @@ -305,7 +296,7 @@ TEST_F(HloCostAnalysisTest, FullyConnectedForward) { TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { HloCostAnalysis conv_analysis(ShapeSize); { - ComputationBuilder builder(client_, "conv_looking_matmul"); + XlaBuilder builder("conv_looking_matmul"); auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), "input"); auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), @@ -318,7 +309,7 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { HloCostAnalysis matmul_analysis(ShapeSize); { - ComputationBuilder builder(client_, "matmul"); + XlaBuilder builder("matmul"); auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64}), "input"); auto rhs = @@ -427,7 +418,7 @@ TEST_F(FusionCostAnalysis, NoLayout) { TEST_F(HloCostAnalysisTest, TupleCost) { HloCostAnalysis analysis(ShapeSize); { - ComputationBuilder builder(client_, "matmul"); + XlaBuilder builder("matmul"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {123}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {42}), "y"); auto tuple = builder.Tuple({x, y}); @@ -443,7 +434,7 @@ TEST_F(HloCostAnalysisTest, TupleCost) { } TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { - ComputationBuilder builder(client_, "BaseDilatedConvolution"); + XlaBuilder builder("BaseDilatedConvolution"); auto input = builder.Parameter( 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, @@ -458,7 +449,7 @@ TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { auto result = builder.ConvGeneralDilated( input, kernel, /*window_strides=*/{1, 1}, /*padding=*/{{1, 1}, {1, 1}}, /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11}, - ComputationBuilder::CreateDefaultConvDimensionNumbers(2)); + XlaBuilder::CreateDefaultConvDimensionNumbers(2)); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 230147a..cc16446 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -827,7 +827,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); b.AddInstruction(HloInstruction::CreateConvolve( @@ -1046,7 +1046,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); b.AddInstruction(HloInstruction::CreateConvolve( @@ -1109,7 +1109,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); b.AddInstruction(HloInstruction::CreateConvolve( @@ -1180,7 +1180,7 @@ TEST_P(HloEvaluatorTest, *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); b.AddInstruction(HloInstruction::CreateConvolve( diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index f8d98f0..be156d7 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index c7c4160..0319109 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -222,7 +222,7 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 0, 2, 3})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); @@ -275,7 +275,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 3, 0, 2})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); @@ -334,7 +334,7 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { HloInstruction* transpose_x = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 2, 3})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); @@ -398,7 +398,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { HloInstruction* transpose_x = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc index a4e67cc..f533128 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 54cf054..0571ff5 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1934,24 +1934,6 @@ xla_test( ) xla_test( - name = "set_return_value_test", - srcs = ["set_return_value_test.cc"], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - -xla_test( name = "reshape_motion_test", srcs = ["reshape_motion_test.cc"], deps = [ diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 3704ddd..a366afe 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -21,7 +21,8 @@ limitations under the License. #include "llvm/ADT/Triple.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/types.h" @@ -29,27 +30,31 @@ limitations under the License. #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" +namespace { + using xla::string; -xla::Computation Doubler(xla::Client* client) { - xla::ComputationBuilder builder(client, "doubler"); +xla::XlaComputation Doubler() { + xla::XlaBuilder builder("doubler"); auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {}); auto x = builder.Parameter(0, r0f32, "x"); builder.Mul(x, builder.ConstantR0(2.0)); return std::move(builder.Build().ValueOrDie()); } +} // namespace + int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient().ValueOrDie(); - xla::ComputationBuilder builder(client, "aot_test_helper"); + xla::XlaBuilder builder("aot_test_helper"); auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); auto opaque_param = builder.Parameter(0, opaque_shape, "x"); auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {}); auto sum = builder.CustomCall("SumStructElements", {opaque_param}, r0f32); - builder.Call(Doubler(client), {sum}); + builder.Call(Doubler(), {sum}); if (argc != 2) { LOG(FATAL) << "local_client_aot_test_helper TARGET_CPU"; @@ -71,8 +76,8 @@ int main(int argc, char** argv) { llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string)); - xla::Computation computation = builder.Build().ConsumeValueOrDie(); - xla::CompileOnlyClient::AotComputationInstance instance{ + xla::XlaComputation computation = builder.Build().ConsumeValueOrDie(); + xla::CompileOnlyClient::AotXlaComputationInstance instance{ &computation, /*argument_layouts=*/{&opaque_shape}, &r0f32}; xla::cpu::CpuAotCompilationOptions options( diff --git a/tensorflow/compiler/xla/tests/set_return_value_test.cc b/tensorflow/compiler/xla/tests/set_return_value_test.cc deleted file mode 100644 index 29f79ec..0000000 --- a/tensorflow/compiler/xla/tests/set_return_value_test.cc +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "tensorflow/compiler/xla/client/computation_builder.h" -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/tests/client_library_test_base.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/test.h" - -namespace xla { -namespace { - -class SetReturnValueTest : public ClientLibraryTestBase {}; - -TEST_F(SetReturnValueTest, NoSetValue) { - ComputationBuilder builder(client_, "no_set_value"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - - std::vector expected = {1.0, 3.0, 4.0, 0.0, -1.0, - 5.0, 6.0, -2.0, -3.0, 7.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -TEST_F(SetReturnValueTest, SetValue) { - ComputationBuilder builder(client_, "set_value"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - auto builder_status = builder.SetReturnValue(ax); - EXPECT_TRUE(builder_status.ok()); - - std::vector expected = {0.0, 2.0, 3.0, -1.0, -2.0, - 4.0, 5.0, -3.0, -4.0, 6.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -TEST_F(SetReturnValueTest, SetValueAndModify) { - ComputationBuilder builder(client_, "set_value_and_modify"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - auto builder_status = builder.SetReturnValue(ax); - EXPECT_TRUE(builder_status.ok()); - auto aaax = builder.Add(alpha, aax); - - std::vector expected = {0.0, 2.0, 3.0, -1.0, -2.0, - 4.0, 5.0, -3.0, -4.0, 6.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -TEST_F(SetReturnValueTest, SetValueMultipleTimesAndModify) { - ComputationBuilder builder(client_, "set_value_multiple_times_and_modify"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - auto builder_status = builder.SetReturnValue(aax); - EXPECT_TRUE(builder_status.ok()); - auto aaax = builder.Add(alpha, aax); - builder_status = builder.SetReturnValue(ax); - EXPECT_TRUE(builder_status.ok()); - auto aaaax = builder.Add(alpha, aaax); - - std::vector expected = {0.0, 2.0, 3.0, -1.0, -2.0, - 4.0, 5.0, -3.0, -4.0, 6.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index 3dded3f..5cce7a2 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -350,7 +349,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { } XLA_TEST_F(VecOpsSimpleTest, ClampValuesConstantS64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto zero = builder.ConstantR0(0); auto one = builder.ConstantR0(10); auto x = builder.ConstantR1({-3, 3, 9, 13}); -- 2.7.4