From 022d1fe0ecb46a2a7b77f8b99de8c273ef804c82 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 9 Apr 2018 22:04:04 -0700
Subject: [PATCH] [XLA] Redesign: implement XlaBuilder::IsConstant,
 XlaBuilder::BuildConstantSubGraph, and Client::ComputeConstant(XlaComputation...).

- Since the builder no longer holds a client, we moved ComputeConstant to the
  client side so that it can communicate with the service side. We add
  XlaBuilder::BuildConstantSubGraph, which is only responsible for building a
  subgraph that is a compile-time constant.

- Before this change, every XlaBuilder had a unique id. Now that the builder
  also builds constant subgraphs, we instead give every XlaComputation being
  built a globally unique id, and uniquify instruction names only when
  actually building the XlaComputation.
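The intended end-to-end flow, as exercised by the updated
compute_constant_test (a minimal sketch; assumes a Client* `client` obtained
from the client library, with error handling via TF_ASSIGN_OR_RETURN):

    XlaBuilder b("add_constants");
    XlaOp sum = b.Add(b.ConstantR0<int32>(1), b.ConstantR0<int32>(2));

    // Builder side: extract the compile-time-constant subgraph rooted at
    // `sum`.
    TF_ASSIGN_OR_RETURN(XlaComputation subgraph, b.BuildConstantSubGraph(sum));

    // Client side: have the service evaluate the parameter-free subgraph.
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> value,
                        client->ComputeConstant(subgraph));
    // `value` now holds the int32 literal 3.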
PiperOrigin-RevId: 192236997
---
 tensorflow/compiler/xla/client/client.cc           |  28 ++++
 tensorflow/compiler/xla/client/client.h            |  21 +++
 tensorflow/compiler/xla/client/xla_client/BUILD    |   1 +
 .../compiler/xla/client/xla_client/xla_builder.cc  | 183 ++++++++++++++++++---
 .../compiler/xla/client/xla_client/xla_builder.h   |  69 +++-----
 tensorflow/compiler/xla/service/service.cc         |  44 +++++
 tensorflow/compiler/xla/service/service.h          |   3 +
 tensorflow/compiler/xla/service_interface.h        |   4 +
 tensorflow/compiler/xla/tests/BUILD                |   2 +
 .../compiler/xla/tests/compute_constant_test.cc    | 100 ++++++-----
 tensorflow/compiler/xla/xla.proto                  |   5 +
 11 files changed, 345 insertions(+), 115 deletions(-)

diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index 3f45167..f0f9429 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -193,6 +193,34 @@ StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
   return Transfer(*data, shape_with_output_layout);
 }
 
+StatusOr<std::unique_ptr<Literal>> Client::ComputeConstant(
+    const XlaComputation& computation, const Layout* output_layout) const {
+  ComputeConstantGraphRequest request;
+  *request.mutable_computation() = computation.proto();
+  if (output_layout != nullptr) {
+    *request.mutable_output_layout() = *output_layout;
+  }
+
+  ComputeConstantResponse response;
+
+  VLOG(2) << "making compute-constant-graph request";
+  Status s = stub_->ComputeConstantGraph(&request, &response);
+  VLOG(2) << "done with request";
+
+  if (!s.ok()) {
+    return s;
+  }
+
+  VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";
+
+  if (!response.has_literal()) {
+    return InternalError(
+        "no computed literal in the ComputeConstantGraph response");
+  }
+  return Literal::CreateFromProto(response.literal());
+}
+
 StatusOr<Computation> Client::LoadSnapshot(const SessionModule& module) {
   LoadComputationSnapshotRequest request;
   *request.mutable_module() = module;
diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h
index 05d707d..14c685d 100644
--- a/tensorflow/compiler/xla/client/client.h
+++ b/tensorflow/compiler/xla/client/client.h
@@ -194,6 +194,27 @@ class Client {
       const ExecutionOptions* execution_options = nullptr,
       ExecutionProfile* execution_profile = nullptr);
 
+  // Computes the value of the given computation using a non-optimized
+  // interpreter on the host.
+  //
+  // The computation must not depend on any parameters, or on stateful
+  // operators such as `RngNormal` or `Infeed`.
+  //
+  // This functionality can be useful when translating a computation into XLA
+  // where something that looked dynamic is required by XLA to be specified as
+  // a constant. E.g. the source computation (outside of XLA) may include a
+  // dynamic computation of the shape of something and ComputeConstant lets
+  // you determine what the value of that computation is in the case where the
+  // value can be determined at compile time.
+  //
+  // If output_layout is non-null, then the output of the computation will be
+  // stored using that layout.
+  //
+  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
+  StatusOr<std::unique_ptr<Literal>> ComputeConstant(
+      const XlaComputation& computation,
+      const Layout* output_layout = nullptr) const;
+
   // Unregister the memory for the given GlobalData on the device.
   Status Unregister(const GlobalData& data);
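To make the "dynamic shape" use case in the comment above concrete, a
hypothetical frontend lowering might look like the following (`batch` and
`beam_width` are illustrative compile-time constants, not names from this
change):

    XlaBuilder b("infer_output_dim");
    XlaOp dim =
        b.Mul(b.ConstantR0<int32>(batch), b.ConstantR0<int32>(beam_width));
    TF_ASSIGN_OR_RETURN(XlaComputation subgraph, b.BuildConstantSubGraph(dim));
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
                        client->ComputeConstant(subgraph));
    // The concrete value can now be used as a static dimension of a Shape.
    int32 output_dim = literal->Get<int32>({});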
diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD
index b1dba16..31fa124 100644
--- a/tensorflow/compiler/xla/client/xla_client/BUILD
+++ b/tensorflow/compiler/xla/client/xla_client/BUILD
@@ -44,6 +44,7 @@ cc_library(
     hdrs = ["xla_builder.h"],
     deps = [
         ":xla_computation",
+        "//tensorflow/compiler/xla:execution_options_util",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 170dd59..a01be28 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -17,12 +17,15 @@ limitations under the License.
 
 #include <functional>
 #include <numeric>
+#include <queue>
 #include <string>
 #include <utility>
 
+#include "tensorflow/compiler/xla/execution_options_util.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/shape_inference.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/mutex.h"
 
@@ -82,7 +85,7 @@ StatusOr<Shape> XlaOp::GetShape() const {
 }
 
 XlaBuilder::XlaBuilder(const string& computation_name)
-    : name_(computation_name), unique_id_(GetUniqueId()) {}
+    : name_(computation_name) {}
 
 XlaBuilder::~XlaBuilder() {}
 
@@ -111,10 +114,11 @@ XlaOp XlaBuilder::NoteErrorOrReturn(
   return op.ConsumeValueOrDie();
 }
 
-StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) {
+StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) const {
   TF_RETURN_IF_ERROR(first_error_);
 
   TF_RET_CHECK(root_id != nullptr);
+
   ProgramShape program_shape;
 
   // Not all instructions can be roots. Walk backwards from the last added
@@ -155,9 +159,56 @@ StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) const {
   return program_shape;
 }
 
-StatusOr<ProgramShape> XlaBuilder::GetProgramShape() {
-  int64 root_id;
-  return GetProgramShape(&root_id);
+StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const {
+  int64 root;
+  return GetProgramShape(&root);
+}
+
+void XlaBuilder::IsConstantVisitor(const int64 op_handle,
+                                   std::set<int64>* visited,
+                                   bool* is_constant) const {
+  if (visited->count(op_handle) != 0 || !*is_constant) {
+    return;
+  }
+
+  CHECK(op_handle < instructions_.size() && op_handle >= 0);
+
+  const HloInstructionProto& instr = instructions_[op_handle];
+  const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
+  switch (opcode) {
+    default:
+      for (const int64 operand_id : instr.operand_ids()) {
+        IsConstantVisitor(operand_id, visited, is_constant);
+      }
+      // TODO(b/32495713): We aren't checking the called computations.
+      break;
+
+    // Non-functional ops.
+    case HloOpcode::kRng:
+    case HloOpcode::kCrossReplicaSum:
+      // TODO(b/33009255): Implement constant folding for cross replica sum.
+    case HloOpcode::kInfeed:
+    case HloOpcode::kOutfeed:
+    case HloOpcode::kHostCompute:
+    case HloOpcode::kCall:
+      // TODO(b/32495713): We aren't checking the to_apply computation itself,
+      // so we conservatively say that computations containing the Call op
+      // cannot be constant. We cannot set is_functional=false in other
+      // similar cases since we're already relying on IsConstant to return
+      // true.
+    case HloOpcode::kCustomCall:
+    case HloOpcode::kWhile:
+      // TODO(b/32495713): We aren't checking the condition and body
+      // computations themselves.
+    case HloOpcode::kSend:
+    case HloOpcode::kRecv:
+    case HloOpcode::kParameter:
+      *is_constant = false;
+      break;
+  }
+  if (!*is_constant) {
+    VLOG(1) << "Non-constant: " << instr.name();
+  }
+  visited->insert(op_handle);
 }
 
 XlaComputation XlaBuilder::BuildAndNoteError() {
@@ -180,21 +231,24 @@ StatusOr<XlaComputation> XlaBuilder::Build() {
   }
 
   HloComputationProto entry;
+  entry.set_id(GetUniqueId());  // Give the computation a globally unique id.
+  entry.set_name(StrCat(name_, entry.id()));  // Ensure that the name is unique.
+
   {
     int64 root_id;
-    ProgramShape program_shape;
-    TF_ASSIGN_OR_RETURN(program_shape, GetProgramShape(&root_id));
-    entry.mutable_program_shape()->Swap(&program_shape);
+    TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(),
+                        GetProgramShape(&root_id));
     entry.set_root_id(root_id);
   }
 
   for (auto& instruction : instructions_) {
+    // Ensures that the instruction names are unique among the whole graph.
+    const string& new_name =
+        StrCat(instruction.name(), ".", entry.id(), ".", instruction.id());
+    instruction.set_name(new_name);
     entry.add_instructions()->Swap(&instruction);
   }
 
-  entry.set_id(unique_id_);
-  entry.set_name(StrCat(name_, entry.id()));  // Ensure that the name is unique.
   XlaComputation computation(entry.id());
   HloModuleProto* module = computation.mutable_proto();
   module->set_name(entry.name());
@@ -417,11 +471,10 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
                             const string& name) {
   return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
-    if (parameter_numbers_.find(parameter_number) != parameter_numbers_.end()) {
+    if (!parameter_numbers_.insert(parameter_number).second) {
       return InvalidArgument("parameter %lld already registered",
                              parameter_number);
     }
-    parameter_numbers_.insert(parameter_number);
     instr.set_parameter_number(parameter_number);
     instr.set_name(name);
     *instr.mutable_shape() = shape;
@@ -1262,15 +1315,98 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
   });
 }
 
-StatusOr<bool> XlaBuilder::IsConstant(const XlaOp& operand,
-                                      int64 num_parameters) {
-  return Unimplemented("IsConstant is not implemented.");
+StatusOr<bool> XlaBuilder::IsConstant(const XlaOp& operand) const {
+  TF_RETURN_IF_ERROR(first_error_);
+
+  // Verify that the handle is valid.
+  TF_RETURN_IF_ERROR(LookUpInstruction(operand).status());
+
+  bool is_constant = true;
+  std::set<int64> visited;
+  IsConstantVisitor(operand.handle(), &visited, &is_constant);
+  return is_constant;
 }
 
-StatusOr<std::unique_ptr<Literal>> XlaBuilder::ComputeConstant(
-    const XlaOp& operand, const Layout* output_layout,
-    tensorflow::gtl::ArraySlice<Literal> parameters) {
-  return Unimplemented("ComputeConstant is not implemented");
+StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
+    const XlaOp& root_op) const {
+  TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
+  if (!is_constant) {
+    auto op_status = LookUpInstruction(root_op);
+    string op_string =
+        op_status.ok() ? op_status.ValueOrDie()->name() : "<unknown operation>";
+    return InvalidArgument(
+        "Operand to BuildConstantSubGraph depends on a parameter.\n\n"
+        "  op requested for constant subgraph: %s\n\n"
+        "This is an internal error that typically happens when the XLA user "
+        "(e.g. TensorFlow) is attempting to determine a value that must be a "
+        "compile-time constant (e.g. an array dimension) but it is not "
+        "capable of being evaluated at XLA compile time.\n\n"
+        "Please file a usability bug with the framework being used (e.g. "
+        "TensorFlow).",
+        op_string.c_str());
+  }
+
+  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
+                      LookUpInstruction(root_op));
+  TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
+  if (!CanBeRoot(opcode)) {
+    return InvalidArgument("the operand with opcode %s cannot be root",
+                           root->opcode().c_str());
+  }
+
+  HloComputationProto entry;
+  entry.set_id(GetUniqueId());  // Give the computation a globally unique id.
+  entry.set_name(StrCat(name_, entry.id(), "_compute_constant"));
+  entry.set_root_id(root->id());
+  ProgramShape* program_shape = entry.mutable_program_shape();
+  *program_shape->mutable_result() = root->shape();
+
+  // We use std::set to keep the instruction ids in ascending order (which is
+  // also a valid dependency order). The related ops will be added to the
+  // subgraph in the same order.
+  std::set<int64> related_ops;
+  tensorflow::gtl::FlatSet<int64> related_calls;  // Related computations.
+  std::queue<int64> worklist;
+  worklist.push(root->id());
+  related_ops.insert(root->id());
+  while (!worklist.empty()) {
+    int64 node = worklist.front();
+    worklist.pop();
+    for (int64 id : instructions_[node].operand_ids()) {
+      if (related_ops.insert(id).second) {
+        worklist.push(id);
+      }
+    }
+    for (int64 called_id : instructions_[node].called_computation_ids()) {
+      related_calls.insert(called_id);
+    }
+  }
+
+  // Add related ops to the computation.
+  for (int64 id : related_ops) {
+    auto* instr = entry.add_instructions();
+    *instr = instructions_[id];
+    // Ensures that the instruction names are unique among the graph.
+    const string& new_name =
+        StrCat(instr->name(), ".", entry.id(), ".", instr->id());
+    instr->set_name(new_name);
+  }
+
+  XlaComputation computation(entry.id());
+  HloModuleProto* module = computation.mutable_proto();
+  module->set_name(entry.name());
+  module->set_id(entry.id());
+  module->set_entry_computation_name(entry.name());
+  module->set_entry_computation_id(entry.id());
+  *module->mutable_program_shape() = *program_shape;
+  for (auto& e : embedded_) {
+    if (related_calls.find(e.second.id()) != related_calls.end()) {
+      *module->add_computations() = e.second;
+    }
+  }
+  *module->add_computations() = std::move(entry);
+
+  return std::move(computation);
 }
 
 std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder(
@@ -1281,10 +1417,6 @@ std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder(
   return sub_builder;
 }
 
-Status XlaBuilder::SetReturnValue(const XlaOp& operand) {
-  return Unimplemented("SetReturnValue is not implemented.");
-}
-
 /* static */ ConvolutionDimensionNumbers
 XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
   ConvolutionDimensionNumbers dimension_numbers;
@@ -1364,10 +1496,7 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(
   instr.set_id(handle);
   instr.set_opcode(HloOpcodeString(opcode));
   if (instr.name().empty()) {
-    instr.set_name(StrCat(instr.opcode(), ".", unique_id_, ".", handle));
-  } else {
-    // Append the handle to make sure the name is unique.
-    instr.set_name(StrCat(instr.name(), ".", unique_id_, ".", handle));
+    instr.set_name(StrCat(instr.opcode()));
   }
   for (const auto& operand : operands) {
     if (operand.builder_ == nullptr) {
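Net effect of the naming changes above: an instruction keeps its plain opcode
name while the graph is being built, and is only uniquified when a computation
is materialized, by appending the computation id and the instruction id. A
sketch with invented ids for illustration:

    XlaBuilder b("fold");
    XlaOp x = b.Add(b.ConstantR0<float>(1.0f), b.ConstantR0<float>(2.0f));
    // Inside the builder the add instruction is named just "add".
    // If Build() assigns the computation the unique id 5 and the instruction
    // has id 2, the emitted proto names it "add.5.2".
    XlaComputation computation = b.Build().ValueOrDie();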
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index 0673b86..d747691 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -687,11 +687,12 @@ class XlaBuilder {
   XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
 
   // Returns true if 'operand' is a compile-time constant. A compile-time
-  // constant does not depend on parameters with index greater than or equal
-  // to `num_parameters`, or on stateful operators such as `RngNormal` or
-  // `Infeed`. Unlike `ComputeConstant`, `IsConstant` tests whether a
-  // computation is a compile-time constant without evaluating the computation.
-  StatusOr<bool> IsConstant(const XlaOp& operand, int64 num_parameters = 0);
+  // constant does not depend on any parameters, or on stateful operators such
+  // as `RngNormal` or `Infeed`.
+  //
+  // This tests whether a computation is a compile-time constant without
+  // evaluating the computation.
+  StatusOr<bool> IsConstant(const XlaOp& operand) const;
 
   // Normalizes operand across spatial and batch dimensions for each feature.
   //
@@ -731,47 +732,14 @@ class XlaBuilder {
                            const XlaOp& grad_output, float epsilon,
                            int64 feature_index);
 
-  // Computes the value of a constant indicated by an XlaOp using a
-  // non-optimized interpreter on the host.
-  //
-  // The operand must represent a constant value, which in this case means
-  // that it must not statically depend on any parameter of the computation
-  // that is being built other than the ones specified on the parameter list.
-  // The parameters in the list will be indexed by their parameter id property
-  // so the number of parameters specified should be at least as many as the
-  // largest used parameter index.
-  //
-  // `IsConstant` can be used to test whether a computation is a compile-time
-  // constant without evaluating it. `ComputeConstant` only succeeds for
-  // computations where `IsConstant` returns true.
-  //
-  // This functionality can be useful when translating a computation into XLA
-  // where something that looked dynamic is required by XLA to be specified as
-  // a constant. E.g. the source computation (outside of XLA) may include a
-  // dynamic computation of the shape of something and ComputeConstant lets
-  // you determine what the value of that computation is in the case where the
-  // value can be determined at compile time.
-  //
-  // If output_layout is non-null, then the output of the computation will be
-  // stored using that layout.
-  StatusOr<std::unique_ptr<Literal>> ComputeConstant(
-      const XlaOp& operand, const Layout* output_layout = nullptr,
-      tensorflow::gtl::ArraySlice<Literal> parameters = {});
-
   // Returns a new XlaBuilder whose resultant Computation is used only by this
   // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
   // behavior as the parent.
   std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
 
-  // Modifies the computation being built so that executions of it will return
-  // the value associated with operand, rather than the last expression
-  // enqueued on the XlaBuilder. Any subsequent operations added to the
-  // XlaBuilder will not have any effect unless SetReturnValue is called again.
-  Status SetReturnValue(const XlaOp& operand);
-
   // Builds the computation with the requested operations, or returns a non-ok
-  // status.
+  // status. Note that all ops that have been enqueued will be moved to the
+  // computation being returned.
   StatusOr<XlaComputation> Build();
 
   // Builds the computation with the requested operations, or notes an error in
@@ -784,6 +752,12 @@ class XlaBuilder {
   // instead.
   XlaComputation BuildAndNoteError();
 
+  // Returns a subgraph that roots on the given root. If the root is not a
+  // compile-time constant (see `IsConstant`), returns an error.
+  //
+  // This will copy the needed ops/computations to the subgraph.
+  StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op) const;
+
   // Returns the first error that was encountered while building the
   // computation. When an error is encountered, by default we return a vacuous
   // XlaOp and inform the user of the error that occurred while
@@ -796,7 +770,7 @@ class XlaBuilder {
   StatusOr<Shape> GetShape(const XlaOp& op) const;
 
   // Returns the (inferred) result for the current computation's shape.
-  StatusOr<ProgramShape> GetProgramShape();
+  StatusOr<ProgramShape> GetProgramShape() const;
 
  private:
   StatusOr<XlaOp> AddInstruction(
@@ -851,10 +825,17 @@ class XlaBuilder {
 
   // Returns the (inferred) result for the program shape for the current
   // computation and fills the root_id in the pointer.
-  StatusOr<ProgramShape> GetProgramShape(int64* root_id);
+  StatusOr<ProgramShape> GetProgramShape(int64* root_id) const;
+
+  // A visitor which checks whether an operation is a compile-time constant,
+  // meaning that it doesn't depend on any parameters, or on any stateful
+  // operation such as `RngNormal` or `Infeed`. The visitor walks the
+  // computation starting at a given operation and sets is_constant to false
+  // iff a parameter or stateful operation is encountered.
+  void IsConstantVisitor(const int64 op_handle, std::set<int64>* visited,
+                         bool* is_constant) const;
 
-  string name_;      // Name to use for the built computation.
-  int64 unique_id_;  // The unique id for the built computation.
+  string name_;  // Name to use for the built computation.
 
   // The first error encountered while building the computation.
   // This is OK until the first error is encountered.
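A sketch of the semantics declared above (assumes an XlaBuilder `b`; return
values unwrapped for brevity):

    XlaOp c = b.Mul(b.ConstantR0<int32>(6), b.ConstantR0<int32>(7));
    XlaOp p = b.Parameter(0, ShapeUtil::MakeShape(S32, {}), "p");
    XlaOp mixed = b.Add(c, p);

    b.IsConstant(c).ValueOrDie();      // true: reaches only constants.
    b.IsConstant(mixed).ValueOrDie();  // false: transitively reaches a
                                       // Parameter.

    b.BuildConstantSubGraph(c);      // OK: a subgraph computing 6 * 7.
    b.BuildConstantSubGraph(mixed);  // InvalidArgument: depends on a
                                     // parameter.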
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index ec883a6..70af1c4 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -1544,6 +1544,50 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
 
   // Since the shape_with_output_layout option in ExecutionOption is
   // non-effective to the Evaluator results, explicit relayout here.
+  //
+  // TODO(b/77824332): Make HloEvaluator take care of the re-layout.
   if (arg->has_output_layout()) {
     result_literal = result_literal->Relayout(arg->output_layout());
   }
+  *result->mutable_literal() = result_literal->ToProto();
+
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::ComputeConstantGraph(
+    const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) {
+  if (!arg->has_computation()) {
+    return InvalidArgument("computation may not be empty");
+  }
+  if (!arg->computation().has_program_shape()) {
+    return InvalidArgument("program shape may not be empty");
+  }
+  if (arg->computation().program_shape().parameters_size() != 0) {
+    return InvalidArgument(
+        "constant computation may not depend on any parameters.");
+  }
+
+  ProgramShape program_shape = arg->computation().program_shape();
+  TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
+  if (arg->has_output_layout()) {
+    TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
+        arg->output_layout(), program_shape.result()));
+  }
+
+  HloModuleConfig config(program_shape);
+
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
+                      HloModule::CreateFromProto(arg->computation(), config));
+
+  HloEvaluator evaluator;
+  TF_ASSIGN_OR_RETURN(auto result_literal,
+                      evaluator.Evaluate<std::unique_ptr<Literal>>(
+                          *module, /*arg_literals=*/{}));
+
+  // Since the requested result layout has no effect on the Evaluator's
+  // output, explicitly relayout here.
+  //
+  // TODO(b/77824332): Make HloEvaluator take care of the re-layout.
   if (arg->has_output_layout()) {
     result_literal = result_literal->Relayout(arg->output_layout());
   }
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 9fa72c1..e399f1a 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -206,6 +206,9 @@ class Service : public ServiceInterface {
   // Computes the value of a constant expression.
   tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg,
                                      ComputeConstantResponse* result) override;
+  tensorflow::Status ComputeConstantGraph(
+      const ComputeConstantGraphRequest* arg,
+      ComputeConstantResponse* result) override;
 
   // Returns the shape (with layout) of an array associated with a given data
   // handle.
diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h
index 32aae64..5b44c26 100644
--- a/tensorflow/compiler/xla/service_interface.h
+++ b/tensorflow/compiler/xla/service_interface.h
@@ -112,6 +112,10 @@ class ServiceInterface {
   virtual tensorflow::Status ComputeConstant(
       const ComputeConstantRequest* arg, ComputeConstantResponse* result) = 0;
 
+  virtual tensorflow::Status ComputeConstantGraph(
+      const ComputeConstantGraphRequest* arg,
+      ComputeConstantResponse* result) = 0;
+
   // Methods used by Computation.
   virtual tensorflow::Status SnapshotComputation(
       const SnapshotComputationRequest* arg,
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 8ecb421..6c43014 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1551,6 +1551,8 @@ xla_test(
         "//tensorflow/compiler/xla/client:computation",
         "//tensorflow/compiler/xla/client:computation_builder",
         "//tensorflow/compiler/xla/client:global_data",
+        "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+        "//tensorflow/compiler/xla/client/xla_client:xla_computation",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:test_utils",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index e5a03b4..c15d808 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -21,6 +21,8 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/computation_builder.h"
 #include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -31,6 +33,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/tests/test_macros.h"
 #include "tensorflow/compiler/xla/tests/test_utils.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -71,28 +74,35 @@ class ComputeConstantTest : public ::testing::Test {
   }
 
   StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
-      Client* client, const ComputationDataHandle& operand,
-      ComputationBuilder* builder, Layout* output_layout = nullptr,
-      tensorflow::gtl::ArraySlice<Literal> parameters = {}) {
-    TF_ASSIGN_OR_RETURN(auto computed, builder->ComputeConstant(
-                                           operand, output_layout, parameters));
+      Client* client, const XlaOp& operand, XlaBuilder* builder,
+      Layout* output_layout = nullptr) {
+    TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand));
+    TF_ASSIGN_OR_RETURN(auto computed,
+                        client->ComputeConstant(subgraph, output_layout));
     return std::move(computed);
   }
 
   template <typename Scalar>
+  StatusOr<Scalar> ComputeConstantScalar(Client* client, const XlaOp& operand,
+                                         XlaBuilder* builder) {
+    TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand,
+                                                             builder, nullptr));
+    return literal->Get<Scalar>({});
+  }
+
+  template <typename Scalar>
   StatusOr<Scalar> ComputeConstantScalar(
       Client* client, const ComputationDataHandle& operand,
      ComputationBuilder* builder,
      tensorflow::gtl::ArraySlice<Literal> parameters = {}) {
-    TF_ASSIGN_OR_RETURN(
-        auto literal,
-        ComputeConstantLiteral(client, operand, builder, nullptr, parameters));
+    TF_ASSIGN_OR_RETURN(auto literal,
+                        builder->ComputeConstant(
+                            operand, /*output_layout=*/nullptr, parameters));
     return literal->Get<Scalar>({});
   }
 
-  bool IsConstant(const ComputationDataHandle& operand,
-                  ComputationBuilder* builder, int64 num_parameters = 0) {
-    StatusOr<bool> result = builder->IsConstant(operand, num_parameters);
+  bool IsConstant(const XlaOp& operand, XlaBuilder* builder) {
+    StatusOr<bool> result = builder->IsConstant(operand);
     EXPECT_TRUE(result.ok()) << result.status();
     return result.ok() ? result.ValueOrDie() : false;
   }
 
@@ -103,7 +113,7 @@ class ComputeConstantTest : public ::testing::Test {
 TEST_F(ComputeConstantTest, ScalarInt32Literal) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
-    ComputationBuilder b(client, TestName());
+    XlaBuilder b(TestName());
 
     auto computation = b.ConstantR0<int32>(42);
     EXPECT_TRUE(IsConstant(computation, &b));
@@ -116,7 +126,7 @@ TEST_F(ComputeConstantTest, ScalarFloatAdd) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
-    ComputationBuilder b(client, TestName());
+    XlaBuilder b(TestName());
 
     auto computation =
         b.Add(b.ConstantR0<float>(42.5f), b.ConstantR0<float>(1.5f));
     EXPECT_TRUE(IsConstant(computation, &b));
@@ -130,7 +140,7 @@ TEST_F(ComputeConstantTest, ScalarRng) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
-    ComputationBuilder b(client, TestName());
+    XlaBuilder b(TestName());
 
     auto computation =
         b.RngUniform(b.ConstantR0<float>(1.1f), b.ConstantR0<float>(2.1f),
                      ShapeUtil::MakeShape(F32, {}));
@@ -151,19 +161,21 @@ TEST_F(ComputeConstantTest, Param) {
     std::vector<Literal> arguments;
     arguments.push_back(std::move(*Literal::CreateR0<float>(42.5f)));
 
-    EXPECT_TRUE(IsConstant(computation, &b, arguments.size()));
-
-    auto value =
-        ComputeConstantScalar<float>(client, computation, &b, arguments);
-    ASSERT_TRUE(value.ok()) << value.status();
-    EXPECT_EQ(value.ValueOrDie(), 44.0f);
+    TF_ASSERT_OK_AND_ASSIGN(bool is_constant,
+                            b.IsConstant(computation, arguments.size()));
+    EXPECT_TRUE(is_constant);
+
+    TF_ASSERT_OK_AND_ASSIGN(
+        auto value,
+        ComputeConstantScalar<float>(client, computation, &b, arguments));
+    EXPECT_EQ(value, 44.0f);
   }
 }
 
 TEST_F(ComputeConstantTest, DirectParamMissing) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
-    ComputationBuilder b(client, TestName());
+    XlaBuilder b(TestName());
 
     auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param");
     EXPECT_FALSE(IsConstant(computation, &b));
@@ -177,7 +189,7 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
-    ComputationBuilder b(client, TestName());
+    XlaBuilder b(TestName());
 
     auto computation =
         b.Add(b.ConstantR0<float>(1.0f),
               b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"));
@@ -195,7 +207,7 @@ TEST_F(ComputeConstantTest, UnrelatedParam) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
-    ComputationBuilder b(client, TestName());
+    XlaBuilder b(TestName());
 
     auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0");
     auto constant_4 =
@@ -212,64 +224,64 @@ TEST_F(ComputeConstantTest, UnrelatedParam) {
 
     EXPECT_TRUE(IsConstant(constant_13, &b));
 
-    auto value = ComputeConstantScalar<float>(client, constant_13, &b);
-    ASSERT_TRUE(value.ok()) << value.status();
-    EXPECT_EQ(value.ValueOrDie(), 13.0f);
+    TF_ASSERT_OK_AND_ASSIGN(
+        auto value, ComputeConstantScalar<float>(client, constant_13, &b));
+    EXPECT_EQ(value, 13.0f);
   }
 }
 
 TEST_F(ComputeConstantTest, NonScalarAdd) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
-    ComputationBuilder b(client, TestName());
+    XlaBuilder b(TestName());
 
     auto computation =
         b.Add(b.ConstantR1<int32>({1, 2}), b.ConstantR1<int32>({3, 4}));
     EXPECT_TRUE(IsConstant(computation, &b));
 
-    auto computed = ComputeConstantLiteral(client, computation, &b);
-    ASSERT_TRUE(computed.ok()) << computed.status();
+    TF_ASSERT_OK_AND_ASSIGN(auto computed,
+                            ComputeConstantLiteral(client, computation, &b));
 
     std::unique_ptr<Literal> expected_literal =
        Literal::CreateR1<int32>({4, 6});
-    LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
+    LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
   }
 }
 
 TEST_F(ComputeConstantTest, IntegerDivide) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
-    ComputationBuilder b(client, TestName());
+    XlaBuilder b(TestName());
 
     auto computation = b.Div(b.ConstantR0<int32>(15), b.ConstantR0<int32>(3));
     EXPECT_TRUE(IsConstant(computation, &b));
 
-    auto computed = ComputeConstantLiteral(client, computation, &b);
-    ASSERT_TRUE(computed.ok()) << computed.status();
+    TF_ASSERT_OK_AND_ASSIGN(auto computed,
+                            ComputeConstantLiteral(client, computation, &b));
 
     std::unique_ptr<Literal> expected_literal = Literal::CreateR0<int32>(5);
-    LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
+    LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
   }
 }
 
 XLA_TEST_F(ComputeConstantTest, Layout) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
-    ComputationBuilder b(client, TestName());
+    XlaBuilder b(TestName());
 
     std::vector<std::vector<int64>> layouts = {{0, 1}, {1, 0}};
     for (const std::vector<int64>& layout : layouts) {
       auto layout_proto = LayoutUtil::MakeLayout(layout);
-      auto computed = ComputeConstantLiteral(
-          client,
-          b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}),
-                b.ConstantR2<int32>({{10, 20}, {30, 40}})),
-          &b, &layout_proto);
-      ASSERT_TRUE(computed.ok()) << computed.status();
+      TF_ASSERT_OK_AND_ASSIGN(
+          auto computed, ComputeConstantLiteral(
+                             client,
+                             b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}),
+                                   b.ConstantR2<int32>({{10, 20}, {30, 40}})),
+                             &b, &layout_proto));
 
       std::unique_ptr<Literal> expected_literal =
           Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}},
                                              LayoutUtil::MakeLayout(layout));
-      LiteralTestUtil::AssertEqualShapesAndLayouts(
-          expected_literal->shape(), computed.ValueOrDie()->shape());
-      LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
+      LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(),
+                                                   computed->shape());
+      LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
     }
   }
 }
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index f9943f7..b4cbdf3 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -417,6 +417,11 @@ message ComputeConstantRequest {
   repeated LiteralProto parameters = 4;
 }
 
+message ComputeConstantGraphRequest {
+  HloModuleProto computation = 1;
+  Layout output_layout = 2;
+}
+
 message ComputeConstantResponse {
   // A LiteralProto is returned directly for this request, instead of a
   // ComputationDataHandle.
-- 
2.7.4
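At the wire level, Client::ComputeConstant above amounts to filling the new
request message; a minimal sketch mirroring that code (assumes a
ServiceInterface* `stub` and a parameter-free XlaComputation `computation`):

    ComputeConstantGraphRequest request;
    *request.mutable_computation() = computation.proto();
    // Optional: request a particular layout for the returned literal.
    *request.mutable_output_layout() = LayoutUtil::MakeLayout({1, 0});

    ComputeConstantResponse response;
    TF_RETURN_IF_ERROR(stub->ComputeConstantGraph(&request, &response));
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
                        Literal::CreateFromProto(response.literal()));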