[XLA] Redesign: implement XlaBuilder::IsConstant, XlaBuidler::BuildConstantSubGraph...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 10 Apr 2018 05:04:04 +0000 (22:04 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 10 Apr 2018 05:06:41 +0000 (22:06 -0700)
- Since the builder no longer holds a client, we moved the ComputeConstant to the client side so that it can communicate with the service side. Now we add XlaBuilder::BuildConstantSubGraph, which is only responsible for building a subgraph that is compile-time constant.
- Before this change, every XlaBuilder has a unique id. Now since it also builds constant subgraph, we give every XlaComputation being built a global unique id, and uniquify instruction names when actually building the XlaComputation.

PiperOrigin-RevId: 192236997

tensorflow/compiler/xla/client/client.cc
tensorflow/compiler/xla/client/client.h
tensorflow/compiler/xla/client/xla_client/BUILD
tensorflow/compiler/xla/client/xla_client/xla_builder.cc
tensorflow/compiler/xla/client/xla_client/xla_builder.h
tensorflow/compiler/xla/service/service.cc
tensorflow/compiler/xla/service/service.h
tensorflow/compiler/xla/service_interface.h
tensorflow/compiler/xla/tests/BUILD
tensorflow/compiler/xla/tests/compute_constant_test.cc
tensorflow/compiler/xla/xla.proto

index 3f45167..f0f9429 100644 (file)
@@ -193,6 +193,34 @@ StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
   return Transfer(*data, shape_with_output_layout);
 }
 
+StatusOr<std::unique_ptr<Literal>> Client::ComputeConstant(
+    const XlaComputation& computation, const Layout* output_layout) const {
+  ComputeConstantGraphRequest request;
+  *request.mutable_computation() = computation.proto();
+  if (output_layout != nullptr) {
+    *request.mutable_output_layout() = *output_layout;
+  }
+
+  ComputeConstantResponse response;
+
+  VLOG(2) << "making compute-constant-graph request";
+  Status s = stub_->ComputeConstantGraph(&request, &response);
+  VLOG(2) << "done with request";
+
+  if (!s.ok()) {
+    return s;
+  }
+
+  VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";
+
+  if (!response.has_literal()) {
+    return InternalError(
+        "no computed literal in the provided response in ComputeConstantGraph "
+        "request");
+  }
+  return Literal::CreateFromProto(response.literal());
+}
+
 StatusOr<Computation> Client::LoadSnapshot(const SessionModule& module) {
   LoadComputationSnapshotRequest request;
   *request.mutable_module() = module;
index 05d707d..14c685d 100644 (file)
@@ -194,6 +194,27 @@ class Client {
       const ExecutionOptions* execution_options = nullptr,
       ExecutionProfile* execution_profile = nullptr);
 
+  // Computes the value of the given computation using a non-optimized
+  // interpreter on the host.
+  //
+  // The computation must not depend on any parameters, or on stateful operators
+  // such as `RngNormal` or `Infeed`.
+  //
+  // This functionality can be useful when translating a computation into XLA
+  // where something that looked dynamic is required by XLA to be specified as a
+  // constant. E.g. the source computation (outside of XLA) may include a
+  // dynamic computation of the shape of something and ComputeConstant lets you
+  // determine what the value of that computation is in the case where the value
+  // can be determined at compile time.
+  //
+  // If output_layout is non-null, then the output of the computation will be
+  // stored using that layout.
+  //
+  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
+  StatusOr<std::unique_ptr<Literal>> ComputeConstant(
+      const XlaComputation& computation,
+      const Layout* output_layout = nullptr) const;
+
   // Unregister the memory for the given GlobalData on the device.
   Status Unregister(const GlobalData& data);
 
index b1dba16..31fa124 100644 (file)
@@ -44,6 +44,7 @@ cc_library(
     hdrs = ["xla_builder.h"],
     deps = [
         ":xla_computation",
+        "//tensorflow/compiler/xla:execution_options_util",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
index 170dd59..a01be28 100644 (file)
@@ -17,12 +17,15 @@ limitations under the License.
 
 #include <functional>
 #include <numeric>
+#include <queue>
 #include <string>
 #include <utility>
 
+#include "tensorflow/compiler/xla/execution_options_util.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/shape_inference.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/mutex.h"
 
@@ -82,7 +85,7 @@ StatusOr<Shape> XlaOp::GetShape() const {
 }
 
 XlaBuilder::XlaBuilder(const string& computation_name)
-    : name_(computation_name), unique_id_(GetUniqueId()) {}
+    : name_(computation_name) {}
 
 XlaBuilder::~XlaBuilder() {}
 
@@ -111,10 +114,11 @@ XlaOp XlaBuilder::NoteErrorOrReturn(
   return op.ConsumeValueOrDie();
 }
 
-StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) {
+StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) const {
   TF_RETURN_IF_ERROR(first_error_);
 
   TF_RET_CHECK(root_id != nullptr);
+
   ProgramShape program_shape;
 
   // Not all instructions can be roots. Walk backwards from the last added
@@ -155,9 +159,56 @@ StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) {
   return program_shape;
 }
 
-StatusOr<ProgramShape> XlaBuilder::GetProgramShape() {
-  int64 root_id;
-  return GetProgramShape(&root_id);
+StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const {
+  int64 root;
+  return GetProgramShape(&root);
+}
+
+void XlaBuilder::IsConstantVisitor(const int64 op_handle,
+                                   std::set<int64>* visited,
+                                   bool* is_constant) const {
+  if (visited->count(op_handle) != 0 || !*is_constant) {
+    return;
+  }
+
+  CHECK(op_handle < instructions_.size() && op_handle >= 0);
+
+  const HloInstructionProto& instr = instructions_[op_handle];
+  const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
+  switch (opcode) {
+    default:
+      for (const int64 operand_id : instr.operand_ids()) {
+        IsConstantVisitor(operand_id, visited, is_constant);
+      }
+      // TODO(b/32495713): We aren't checking the called computations.
+      break;
+
+    // Non functional ops.
+    case HloOpcode::kRng:
+    case HloOpcode::kCrossReplicaSum:
+      // TODO(b/33009255): Implmement constant folding for cross replica sum.
+    case HloOpcode::kInfeed:
+    case HloOpcode::kOutfeed:
+    case HloOpcode::kHostCompute:
+    case HloOpcode::kCall:
+      // TODO(b/32495713): We aren't checking the to_apply computation itself,
+      // so we conservatively say that computations containing the Call op
+      // cannot be constant.  We cannot set is_functional=false in other similar
+      // cases since we're already relying on IsConstant to return true.
+    case HloOpcode::kCustomCall:
+    case HloOpcode::kWhile:
+      // TODO(b/32495713): We aren't checking the condition and body
+      // computations themselves.
+    case HloOpcode::kSend:
+    case HloOpcode::kRecv:
+    case HloOpcode::kParameter:
+      *is_constant = false;
+      break;
+  }
+  if (!*is_constant) {
+    VLOG(1) << "Non-constant: " << instr.name();
+  }
+  visited->insert(op_handle);
 }
 
 XlaComputation XlaBuilder::BuildAndNoteError() {
@@ -180,21 +231,24 @@ StatusOr<XlaComputation> XlaBuilder::Build() {
   }
 
   HloComputationProto entry;
+  entry.set_id(GetUniqueId());  // Give the computation a global unique id.
+  entry.set_name(StrCat(name_, entry.id()));  // Ensure that the name is unique.
 
   {
     int64 root_id;
-    ProgramShape program_shape;
-    TF_ASSIGN_OR_RETURN(program_shape, GetProgramShape(&root_id));
-    entry.mutable_program_shape()->Swap(&program_shape);
+    TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(),
+                        GetProgramShape(&root_id));
     entry.set_root_id(root_id);
   }
 
   for (auto& instruction : instructions_) {
+    // Ensures that the instruction names are unique among the whole graph.
+    const string& new_name =
+        StrCat(instruction.name(), ".", entry.id(), ".", instruction.id());
+    instruction.set_name(new_name);
     entry.add_instructions()->Swap(&instruction);
   }
 
-  entry.set_id(unique_id_);
-  entry.set_name(StrCat(name_, entry.id()));  // Ensure that the name is unique.
   XlaComputation computation(entry.id());
   HloModuleProto* module = computation.mutable_proto();
   module->set_name(entry.name());
@@ -417,11 +471,10 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
                             const string& name) {
   return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
-    if (parameter_numbers_.find(parameter_number) != parameter_numbers_.end()) {
+    if (!parameter_numbers_.insert(parameter_number).second) {
       return InvalidArgument("parameter %lld already registered",
                              parameter_number);
     }
-    parameter_numbers_.insert(parameter_number);
     instr.set_parameter_number(parameter_number);
     instr.set_name(name);
     *instr.mutable_shape() = shape;
@@ -1262,15 +1315,98 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
   });
 }
 
-StatusOr<bool> XlaBuilder::IsConstant(const XlaOp& operand,
-                                      int64 num_parameters) {
-  return Unimplemented("IsConstant is not implemented.");
+StatusOr<bool> XlaBuilder::IsConstant(const XlaOp& operand) const {
+  TF_RETURN_IF_ERROR(first_error_);
+
+  // Verify that the handle is valid.
+  TF_RETURN_IF_ERROR(LookUpInstruction(operand).status());
+
+  bool is_constant = true;
+  std::set<int64> visited;
+  IsConstantVisitor(operand.handle(), &visited, &is_constant);
+  return is_constant;
 }
 
-StatusOr<std::unique_ptr<Literal>> XlaBuilder::ComputeConstant(
-    const XlaOp& operand, const Layout* output_layout,
-    tensorflow::gtl::ArraySlice<Literal> parameters) {
-  return Unimplemented("ComputeConstant is not implemented");
+StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
+    const XlaOp& root_op) const {
+  TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
+  if (!is_constant) {
+    auto op_status = LookUpInstruction(root_op);
+    string op_string =
+        op_status.ok() ? op_status.ValueOrDie()->name() : "<unknown operation>";
+    return InvalidArgument(
+        "Operand to BuildConstantSubGraph depends on a parameter.\n\n"
+        "  op requested for constant subgraph: %s\n\n"
+        "This is an internal error that typically happens when the XLA user "
+        "(e.g. TensorFlow) is attempting to determine a value that must be a "
+        "compile-time constant (e.g. an array dimension) but it is not capable "
+        "of being evaluated at XLA compile time.\n\n"
+        "Please file a usability bug with the framework being used (e.g. "
+        "TensorFlow).",
+        op_string.c_str());
+  }
+
+  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
+                      LookUpInstruction(root_op));
+  TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode()));
+  if (!CanBeRoot(opcode)) {
+    return InvalidArgument("the operand with opcode %s cannot be root",
+                           root->opcode().c_str());
+  }
+
+  HloComputationProto entry;
+  entry.set_id(GetUniqueId());  // Give the computation a global unique id.
+  entry.set_name(StrCat(name_, entry.id(), "_compute_constant"));
+  entry.set_root_id(root->id());
+  ProgramShape* program_shape = entry.mutable_program_shape();
+  *program_shape->mutable_result() = root->shape();
+
+  // We use std::set to keep the instruction ids in ascending order (which is
+  // also a valid denpendency order). The related ops will be added to the
+  // subgraph in the same order.
+  std::set<int64> related_ops;
+  tensorflow::gtl::FlatSet<int64> related_calls;  // Related computations.
+  std::queue<int64> worklist;
+  worklist.push(root->id());
+  related_ops.insert(root->id());
+  while (!worklist.empty()) {
+    int64 node = worklist.front();
+    worklist.pop();
+    for (int64 id : instructions_[node].operand_ids()) {
+      if (related_ops.insert(id).second) {
+        worklist.push(id);
+      }
+    }
+    for (int64 called_id : instructions_[node].called_computation_ids()) {
+      related_calls.insert(called_id);
+    }
+  }
+
+  // Add related ops to the computation.
+  for (int64 id : related_ops) {
+    auto* instr = entry.add_instructions();
+    *instr = instructions_[id];
+    // Ensures that the instruction names are unique among the graph.
+    const string& new_name =
+        StrCat(instr->name(), ".", entry.id(), ".", instr->id());
+    instr->set_name(new_name);
+  }
+
+  XlaComputation computation(entry.id());
+  HloModuleProto* module = computation.mutable_proto();
+  module->set_name(entry.name());
+  module->set_id(entry.id());
+  module->set_entry_computation_name(entry.name());
+  module->set_entry_computation_id(entry.id());
+  *module->mutable_program_shape() = *program_shape;
+  for (auto& e : embedded_) {
+    if (related_calls.find(e.second.id()) != related_calls.end()) {
+      *module->add_computations() = e.second;
+    }
+  }
+  *module->add_computations() = std::move(entry);
+
+  return std::move(computation);
 }
 
 std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder(
@@ -1281,10 +1417,6 @@ std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder(
   return sub_builder;
 }
 
-Status XlaBuilder::SetReturnValue(const XlaOp& operand) {
-  return Unimplemented("SetReturnValue is not implemented.");
-}
-
 /* static */ ConvolutionDimensionNumbers
 XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
   ConvolutionDimensionNumbers dimension_numbers;
@@ -1364,10 +1496,7 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(
   instr.set_id(handle);
   instr.set_opcode(HloOpcodeString(opcode));
   if (instr.name().empty()) {
-    instr.set_name(StrCat(instr.opcode(), ".", unique_id_, ".", handle));
-  } else {
-    // Append the handle to make sure the name is unique.
-    instr.set_name(StrCat(instr.name(), ".", unique_id_, ".", handle));
+    instr.set_name(StrCat(instr.opcode()));
   }
   for (const auto& operand : operands) {
     if (operand.builder_ == nullptr) {
index 0673b86..d747691 100644 (file)
@@ -687,11 +687,12 @@ class XlaBuilder {
   XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
 
   // Returns true if 'operand' is a compile-time constant. A compile-time
-  // constant does not depend on parameters with index greater than or equal to
-  // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`.
-  // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a
-  // compile-time constant without evaluating the computation.
-  StatusOr<bool> IsConstant(const XlaOp& operand, int64 num_parameters = 0);
+  // constant does not depend on any parameters, or on stateful operators such
+  // as `RngNormal` or `Infeed`.
+  //
+  // This tests whether a computation is a compile-time constant without
+  // evaluating the computation.
+  StatusOr<bool> IsConstant(const XlaOp& operand) const;
 
   // Normalizes operand across spatial and batch dimensions for each feature.
   //
@@ -731,47 +732,14 @@ class XlaBuilder {
                       const XlaOp& grad_output, float epsilon,
                       int64 feature_index);
 
-  // Computes the value of a constant indicated by a XlaOp using a non-optimized
-  // interpreter on the host.
-  //
-  // The operand must represent a constant value, which in this case
-  // means that it must not statically depend on any parameter of the
-  // computation that is being built other then the ones specified on the
-  // parameter list. The parameters in the list will be indexed by their
-  // parameter id property so the number of parameters specified should be at
-  // least as many as the largest used parameter index.
-  //
-  // `IsConstant` can be used to test whether a computation is a compile-time
-  // constant without evaluation it. `ComputeConstant` only succeeds for
-  // computations where `IsConstant` returns true.
-  //
-  // This functionality can be useful when translating a computation
-  // into XLA where something that looked dynamic is required by
-  // XLA to be specified as a constant. E.g. the source
-  // computation (outside of XLA) may include a dynamic
-  // computation of the shape of something and ComputeConstant lets
-  // you determine what the value of that computation is in the case
-  // where the value can be determined at compile time.
-  //
-  // If output_layout is non-null, then the output of the computation
-  // will be stored using that layout.
-  StatusOr<std::unique_ptr<Literal>> ComputeConstant(
-      const XlaOp& operand, const Layout* output_layout = nullptr,
-      tensorflow::gtl::ArraySlice<Literal> parameters = {});
-
   // Returns a new XlaBuilder whose resultant Computation is used only by this
   // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
   // behavior as the parent.
   std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
 
-  // Modifies the computation being built so that executions of it will return
-  // the value associated with operand, rather than the last expression enqueued
-  // on the XlaBuilder. Any subsequent operations added to the XlaBuilder will
-  // not have any effect unless SetReturnValue is called again.
-  Status SetReturnValue(const XlaOp& operand);
-
   // Builds the computation with the requested operations, or returns a non-ok
-  // status.
+  // status. Note that all ops that have been enqueued will be moved to the
+  // computation being returned.
   StatusOr<XlaComputation> Build();
 
   // Builds the computation with the requested operations, or notes an error in
@@ -784,6 +752,12 @@ class XlaBuilder {
   // instead.
   XlaComputation BuildAndNoteError();
 
+  // Returns a subgraph that roots on the given root. If the root is not a
+  // compile-time constant (see `IsConstant`), returns an error.
+  //
+  // This will copy the needed ops/computations to the subgraph.
+  StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op) const;
+
   // Returns the first error that was encountered while building the
   // computation. When an error is encountered, by default we return a vacuous
   // XlaOp and inform the user of the error that occurred while
@@ -796,7 +770,7 @@ class XlaBuilder {
   StatusOr<Shape> GetShape(const XlaOp& op) const;
 
   // Returns the (inferred) result for the current computation's shape.
-  StatusOr<ProgramShape> GetProgramShape();
+  StatusOr<ProgramShape> GetProgramShape() const;
 
  private:
   StatusOr<XlaOp> AddInstruction(
@@ -851,10 +825,17 @@ class XlaBuilder {
 
   // Returns the (inferred) result for the program shape for the current
   // computation and fills the root_id in the pointer.
-  StatusOr<ProgramShape> GetProgramShape(int64* root_id);
+  StatusOr<ProgramShape> GetProgramShape(int64* root_id) const;
+
+  // A visitor which checks whether an operation is a compile-time constant,
+  // meaning that it doesn't depend on any parameters, or on any stateful
+  // operation such as `RngNormal` or `Infeed`. The visitor walks the
+  // computation starting at a given operation and sets is_constant to false iff
+  // a parameter or stateful operation is encountered.
+  void IsConstantVisitor(const int64 op_handle, std::set<int64>* visited,
+                         bool* is_constant) const;
 
-  string name_;      // Name to use for the built computation.
-  int64 unique_id_;  // The unique id for the built computation.
+  string name_;  // Name to use for the built computation.
 
   // The first error encountered while building the computation.
   // This is OK until the first error is encountered.
index ec883a6..70af1c4 100644 (file)
@@ -1544,6 +1544,50 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
 
   // Since the shape_with_output_layout option in ExecutionOption is
   // non-effective to the Evaluator results, explicit relayout here.
+  //
+  // TODO(b/77824332): Make HloEvaluator take care of the re-layout.
+  if (arg->has_output_layout()) {
+    result_literal = result_literal->Relayout(arg->output_layout());
+  }
+  *result->mutable_literal() = result_literal->ToProto();
+
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::ComputeConstantGraph(
+    const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) {
+  if (!arg->has_computation()) {
+    return InvalidArgument("computations may not be empty");
+  }
+  if (!arg->computation().has_program_shape()) {
+    return InvalidArgument("program shape may not be empty");
+  }
+  if (arg->computation().program_shape().parameters_size() != 0) {
+    return InvalidArgument(
+        "constant computation may not depend on any parameters.");
+  }
+
+  ProgramShape program_shape = arg->computation().program_shape();
+  TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
+  if (arg->has_output_layout()) {
+    TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
+        arg->output_layout(), program_shape.result()));
+  }
+
+  HloModuleConfig config(program_shape);
+
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
+                      HloModule::CreateFromProto(arg->computation(), config));
+
+  HloEvaluator evaluator;
+  TF_ASSIGN_OR_RETURN(auto result_literal,
+                      evaluator.Evaluate<std::unique_ptr<Literal>>(
+                          *module, /*arg_literals=*/{}));
+
+  // Since the result layout is non-effective to the Evaluator results, explicit
+  // relayout here.
+  //
+  // TODO(b/77824332): Make HloEvaluator take care of the re-layout.
   if (arg->has_output_layout()) {
     result_literal = result_literal->Relayout(arg->output_layout());
   }
index 9fa72c1..e399f1a 100644 (file)
@@ -206,6 +206,9 @@ class Service : public ServiceInterface {
   // Computes the value of a constant expression.
   tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg,
                                      ComputeConstantResponse* result) override;
+  tensorflow::Status ComputeConstantGraph(
+      const ComputeConstantGraphRequest* arg,
+      ComputeConstantResponse* result) override;
 
   // Returns the shape (with layout) of an array associated with a given data
   // handle.
index 32aae64..5b44c26 100644 (file)
@@ -112,6 +112,10 @@ class ServiceInterface {
   virtual tensorflow::Status ComputeConstant(
       const ComputeConstantRequest* arg, ComputeConstantResponse* result) = 0;
 
+  virtual tensorflow::Status ComputeConstantGraph(
+      const ComputeConstantGraphRequest* arg,
+      ComputeConstantResponse* result) = 0;
+
   // Methods used by Computation.
   virtual tensorflow::Status SnapshotComputation(
       const SnapshotComputationRequest* ag,
index 8ecb421..6c43014 100644 (file)
@@ -1551,6 +1551,8 @@ xla_test(
         "//tensorflow/compiler/xla/client:computation",
         "//tensorflow/compiler/xla/client:computation_builder",
         "//tensorflow/compiler/xla/client:global_data",
+        "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+        "//tensorflow/compiler/xla/client/xla_client:xla_computation",
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
index e5a03b4..c15d808 100644 (file)
@@ -21,6 +21,8 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/computation_builder.h"
 #include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -31,6 +33,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/tests/test_macros.h"
 #include "tensorflow/compiler/xla/tests/test_utils.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -71,28 +74,35 @@ class ComputeConstantTest : public ::testing::Test {
   }
 
   StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
-      Client* client, const ComputationDataHandle& operand,
-      ComputationBuilder* builder, Layout* output_layout = nullptr,
-      tensorflow::gtl::ArraySlice<Literal> parameters = {}) {
-    TF_ASSIGN_OR_RETURN(auto computed, builder->ComputeConstant(
-                                           operand, output_layout, parameters));
+      Client* client, const XlaOp& operand, XlaBuilder* builder,
+      Layout* output_layout = nullptr) {
+    TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand));
+    TF_ASSIGN_OR_RETURN(auto computed,
+                        client->ComputeConstant(subgraph, output_layout));
     return std::move(computed);
   }
 
   template <class Scalar>
+  StatusOr<Scalar> ComputeConstantScalar(Client* client, const XlaOp& operand,
+                                         XlaBuilder* builder) {
+    TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand,
+                                                             builder, nullptr));
+    return literal->Get<Scalar>({});
+  }
+
+  template <class Scalar>
   StatusOr<Scalar> ComputeConstantScalar(
       Client* client, const ComputationDataHandle& operand,
       ComputationBuilder* builder,
       tensorflow::gtl::ArraySlice<Literal> parameters = {}) {
-    TF_ASSIGN_OR_RETURN(
-        auto literal,
-        ComputeConstantLiteral(client, operand, builder, nullptr, parameters));
+    TF_ASSIGN_OR_RETURN(auto literal,
+                        builder->ComputeConstant(
+                            operand, /*output_layout=*/nullptr, parameters));
     return literal->Get<Scalar>({});
   }
 
-  bool IsConstant(const ComputationDataHandle& operand,
-                  ComputationBuilder* builder, int64 num_parameters = 0) {
-    StatusOr<bool> result = builder->IsConstant(operand, num_parameters);
+  bool IsConstant(const XlaOp& operand, XlaBuilder* builder) {
+    StatusOr<bool> result = builder->IsConstant(operand);
     EXPECT_TRUE(result.ok()) << result.status();
     return result.ok() ? result.ValueOrDie() : false;
   }
@@ -103,7 +113,7 @@ class ComputeConstantTest : public ::testing::Test {
 TEST_F(ComputeConstantTest, ScalarInt32Literal) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
-    ComputationBuilder b(client, TestName());
+    XlaBuilder b(TestName());
     auto computation = b.ConstantR0<int32>(42);
     EXPECT_TRUE(IsConstant(computation, &b));
 
@@ -116,7 +126,7 @@ TEST_F(ComputeConstantTest, ScalarInt32Literal) {
 TEST_F(ComputeConstantTest, ScalarFloatAdd) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
-    ComputationBuilder b(client, TestName());
+    XlaBuilder b(TestName());
     auto computation =
         b.Add(b.ConstantR0<float>(42.5f), b.ConstantR0<float>(1.5f));
     EXPECT_TRUE(IsConstant(computation, &b));
@@ -130,7 +140,7 @@ TEST_F(ComputeConstantTest, ScalarFloatAdd) {
 TEST_F(ComputeConstantTest, ScalarRng) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
-    ComputationBuilder b(client, TestName());
+    XlaBuilder b(TestName());
     auto computation =
         b.RngUniform(b.ConstantR0<float>(1.1f), b.ConstantR0<float>(2.1f),
                      ShapeUtil::MakeShape(F32, {}));
@@ -151,19 +161,21 @@ TEST_F(ComputeConstantTest, Param) {
 
     std::vector<Literal> arguments;
     arguments.push_back(std::move(*Literal::CreateR0(42.5f)));
-    EXPECT_TRUE(IsConstant(computation, &b, arguments.size()));
-
-    auto value =
-        ComputeConstantScalar<float>(client, computation, &b, arguments);
-    ASSERT_TRUE(value.ok()) << value.status();
-    EXPECT_EQ(value.ValueOrDie(), 44.0f);
+    TF_ASSERT_OK_AND_ASSIGN(bool is_constant,
+                            b.IsConstant(computation, arguments.size()));
+    EXPECT_TRUE(is_constant);
+
+    TF_ASSERT_OK_AND_ASSIGN(
+        auto value,
+        ComputeConstantScalar<float>(client, computation, &b, arguments));
+    EXPECT_EQ(value, 44.0f);
   }
 }
 
 TEST_F(ComputeConstantTest, DirectParamMissing) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
-    ComputationBuilder b(client, TestName());
+    XlaBuilder b(TestName());
     auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param");
     EXPECT_FALSE(IsConstant(computation, &b));
 
@@ -177,7 +189,7 @@ TEST_F(ComputeConstantTest, DirectParamMissing) {
 TEST_F(ComputeConstantTest, IndirectParamMissing) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
-    ComputationBuilder b(client, TestName());
+    XlaBuilder b(TestName());
     auto computation =
         b.Add(b.ConstantR0<float>(1.0f),
               b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"));
@@ -195,7 +207,7 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) {
 TEST_F(ComputeConstantTest, UnrelatedParam) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
-    ComputationBuilder b(client, TestName());
+    XlaBuilder b(TestName());
 
     auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0");
     auto constant_4 =
@@ -212,64 +224,64 @@ TEST_F(ComputeConstantTest, UnrelatedParam) {
 
     EXPECT_TRUE(IsConstant(constant_13, &b));
 
-    auto value = ComputeConstantScalar<float>(client, constant_13, &b);
-    ASSERT_TRUE(value.ok()) << value.status();
-    EXPECT_EQ(value.ValueOrDie(), 13.0f);
+    TF_ASSERT_OK_AND_ASSIGN(
+        auto value, ComputeConstantScalar<float>(client, constant_13, &b));
+    EXPECT_EQ(value, 13.0f);
   }
 }
 
 TEST_F(ComputeConstantTest, NonScalarAdd) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
-    ComputationBuilder b(client, TestName());
+    XlaBuilder b(TestName());
 
     auto computation =
         b.Add(b.ConstantR1<int32>({1, 2}), b.ConstantR1<int32>({3, 4}));
     EXPECT_TRUE(IsConstant(computation, &b));
 
-    auto computed = ComputeConstantLiteral(client, computation, &b);
-    ASSERT_TRUE(computed.ok()) << computed.status();
+    TF_ASSERT_OK_AND_ASSIGN(auto computed,
+                            ComputeConstantLiteral(client, computation, &b));
     std::unique_ptr<Literal> expected_literal =
         Literal::CreateR1<int32>({4, 6});
-    LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
+    LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
   }
 }
 
 TEST_F(ComputeConstantTest, IntegerDivide) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
-    ComputationBuilder b(client, TestName());
+    XlaBuilder b(TestName());
     auto computation = b.Div(b.ConstantR0<int32>(15), b.ConstantR0<int32>(3));
     EXPECT_TRUE(IsConstant(computation, &b));
 
-    auto computed = ComputeConstantLiteral(client, computation, &b);
-    ASSERT_TRUE(computed.ok()) << computed.status();
+    TF_ASSERT_OK_AND_ASSIGN(auto computed,
+                            ComputeConstantLiteral(client, computation, &b));
     std::unique_ptr<Literal> expected_literal = Literal::CreateR0<int32>(5);
-    LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
+    LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
   }
 }
 
 XLA_TEST_F(ComputeConstantTest, Layout) {
   for (ClientType client_type : client_types) {
     Client* client = ClientOrDie(platform_, client_type);
-    ComputationBuilder b(client, TestName());
+    XlaBuilder b(TestName());
 
     std::vector<std::vector<int64>> layouts = {{0, 1}, {1, 0}};
     for (const std::vector<int64>& layout : layouts) {
       auto layout_proto = LayoutUtil::MakeLayout(layout);
-      auto computed = ComputeConstantLiteral(
-          client,
-          b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}),
-                b.ConstantR2<int32>({{10, 20}, {30, 40}})),
-          &b, &layout_proto);
-      ASSERT_TRUE(computed.ok()) << computed.status();
+      TF_ASSERT_OK_AND_ASSIGN(
+          auto computed, ComputeConstantLiteral(
+                             client,
+                             b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}),
+                                   b.ConstantR2<int32>({{10, 20}, {30, 40}})),
+                             &b, &layout_proto));
 
       std::unique_ptr<Literal> expected_literal =
           Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}},
                                              LayoutUtil::MakeLayout(layout));
-      LiteralTestUtil::AssertEqualShapesAndLayouts(
-          expected_literal->shape(), computed.ValueOrDie()->shape());
-      LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
+      LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(),
+                                                   computed->shape());
+      LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
     }
   }
 }
index f9943f7..b4cbdf3 100644 (file)
@@ -417,6 +417,11 @@ message ComputeConstantRequest {
   repeated LiteralProto parameters = 4;
 }
 
+message ComputeConstantGraphRequest {
+  HloModuleProto computation = 1;
+  Layout output_layout = 2;
+}
+
 message ComputeConstantResponse {
   // A LiteralProto is returned directly for this request, instead of a
   // ComputationDataHandle.