[XLA] Simplify XlaBuilder: extract common add instruction logic.
author A. Unique TensorFlower <gardener@tensorflow.org>
Wed, 21 Mar 2018 01:36:33 +0000 (18:36 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 21 Mar 2018 01:39:18 +0000 (18:39 -0700)
PiperOrigin-RevId: 189848174

tensorflow/compiler/xla/client/xla_client/xla_builder.cc
tensorflow/compiler/xla/client/xla_client/xla_builder.h

index 8829fc6..82b61d4 100644 (file)
@@ -51,21 +51,16 @@ bool CanBeRoot(HloOpcode opcode) {
   }
 }
 
-void SetOpcode(HloInstructionProto* instr, HloOpcode opcode) {
-  instr->set_opcode(HloOpcodeString(opcode));
-}
-
 }  // namespace
 
-StatusOr<std::unique_ptr<Shape>> XlaBuilder::GetShape(const XlaOp& op) const {
+StatusOr<Shape> XlaBuilder::GetShape(const XlaOp& op) const {
   TF_ASSIGN_OR_RETURN(auto instr, LookUpInstruction(op));
-  return MakeUnique<Shape>(instr->shape());
+  return instr->shape();
 }
 
 StatusOr<Shape> XlaOp::GetShape() const {
   TF_RET_CHECK(builder_ != nullptr);
-  TF_ASSIGN_OR_RETURN(auto shape, builder_->GetShape(*this));
-  return *shape;
+  return builder_->GetShape(*this);
 }
 
 XlaBuilder::XlaBuilder(const string& computation_name)
@@ -158,49 +153,41 @@ XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs,
                       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
   auto op = [&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
-    SetOpcode(&instr, HloOpcode::kAdd);
-    TF_ASSIGN_OR_RETURN(const auto* lhs_instr, LookUpInstruction(lhs));
-    TF_ASSIGN_OR_RETURN(const auto* rhs_instr, LookUpInstruction(rhs));
-    TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
-                        ShapeInference::InferBinaryOpShape(
-                            HloOpcode::kAdd, lhs_instr->shape(),
-                            rhs_instr->shape(), broadcast_dimensions));
-    instr.add_operand_ids(lhs_instr->id());
-    instr.add_operand_ids(rhs_instr->id());
-    return AddInstruction(std::move(instr));
+    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape());
+    TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape());
+    TF_ASSIGN_OR_RETURN(
+        *instr.mutable_shape(),
+        ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, lhs_shape,
+                                           rhs_shape, broadcast_dimensions));
+    return AddInstruction(std::move(instr), HloOpcode::kAdd, {lhs, rhs});
   };
   return NoteErrorOrReturn(op());
 }
 
 XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) {
   HloInstructionProto instr;
-  SetOpcode(&instr, HloOpcode::kConstant);
   *instr.mutable_shape() = literal.shape();
   *instr.mutable_literal() = literal.ToProto();
-  return AddInstruction(std::move(instr));
+  return AddInstruction(std::move(instr), HloOpcode::kConstant);
 }
 
 XlaOp XlaBuilder::Call(const XlaComputation& computation,
                        tensorflow::gtl::ArraySlice<XlaOp> operands) {
   auto op = [&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
-    SetOpcode(&instr, HloOpcode::kCall);
-    std::vector<const Shape*> operand_shapes;
+    std::vector<const Shape*> operand_shape_ptrs;
+    std::vector<Shape> operand_shapes;
     for (const auto& operand : operands) {
-      TF_ASSIGN_OR_RETURN(const auto* input, LookUpInstruction(operand));
-      operand_shapes.push_back(&input->shape());
+      TF_ASSIGN_OR_RETURN(const Shape& shape, operand.GetShape());
+      operand_shapes.push_back(shape);
     }
+    c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
+                [](const Shape& shape) { return &shape; });
     TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
                         ShapeInference::InferCallShape(
-                            operand_shapes,
+                            operand_shape_ptrs,
                             /*to_apply=*/computation.GetProgramShape()));
 
-    // Add input operands.
-    for (const auto& operand : operands) {
-      TF_ASSIGN_OR_RETURN(auto operand_instr, LookUpInstruction(operand));
-      instr.add_operand_ids(operand_instr->id());
-    }
-
     // Add called computation.
     instr.add_called_computation_ids(
         computation.proto().entry_computation_id());
@@ -208,7 +195,7 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation,
       embedded_.insert({e.id(), e});
     }
 
-    return AddInstruction(std::move(instr));
+    return AddInstruction(std::move(instr), HloOpcode::kCall, operands);
   };
   return NoteErrorOrReturn(op());
 }
@@ -217,7 +204,6 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
                             const string& name) {
   auto op = [&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
-    SetOpcode(&instr, HloOpcode::kParameter);
     if (parameter_numbers_.find(parameter_number) != parameter_numbers_.end()) {
       return InvalidArgument("parameter %lld already registered",
                              parameter_number);
@@ -226,20 +212,25 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
     instr.set_parameter_number(parameter_number);
     instr.set_name(name);
     *instr.mutable_shape() = shape;
-    return AddInstruction(std::move(instr));
+    return AddInstruction(std::move(instr), HloOpcode::kParameter);
   };
   return NoteErrorOrReturn(op());
 }
 
-XlaOp XlaBuilder::AddInstruction(HloInstructionProto&& instr) {
+XlaOp XlaBuilder::AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
+                                 tensorflow::gtl::ArraySlice<XlaOp> operands) {
   const int64 handle = instructions_.size();
   instr.set_id(handle);
+  instr.set_opcode(HloOpcodeString(opcode));
   if (instr.name().empty()) {
     instr.set_name(StrCat(instr.opcode(), ".", handle));
   } else {
     // Append the handle to make sure the name is unique.
     instr.set_name(StrCat(instr.name(), ".", handle));
   }
+  for (const auto& operand : operands) {
+    instr.add_operand_ids(operand.handle());
+  }
   instructions_.push_back(instr);
 
   XlaOp op(handle, this);
index 7632bd2..f1d10ec 100644 (file)
@@ -26,6 +26,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -157,14 +158,15 @@ class XlaBuilder {
   XlaOp ConstantR0(NativeT value);
 
   // Returns the shape of the given op.
-  StatusOr<std::unique_ptr<Shape>> GetShape(const XlaOp& op) const;
+  StatusOr<Shape> GetShape(const XlaOp& op) const;
 
   // Builds the computation with the requested operations, or returns a non-ok
   // status.
   StatusOr<XlaComputation> Build();
 
  private:
-  XlaOp AddInstruction(HloInstructionProto&& instr);
+  XlaOp AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
+                       tensorflow::gtl::ArraySlice<XlaOp> operands = {});
 
   // Notes that the error occurred by:
   // * storing it internally and capturing a backtrace if it's the first error