[XLA] Remove XlaOp::GetShape. It only works when the builder of the XlaOp is not...
author     A. Unique TensorFlower <gardener@tensorflow.org>
           Thu, 17 May 2018 01:13:00 +0000 (18:13 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Thu, 17 May 2018 01:15:48 +0000 (18:15 -0700)
PiperOrigin-RevId: 196921647
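
For callers migrating off the removed accessor, shape queries now go through
the op's builder. A minimal before/after sketch (the variable names are
illustrative, not part of this change):

    XlaBuilder b("example");
    XlaOp x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3}), "x");

    // Before (removed by this change): the op resolved its own shape, which
    // only worked while its builder was still alive.
    //   StatusOr<Shape> x_shape = x.GetShape();

    // After: ask the builder that created the op.
    StatusOr<Shape> x_shape = b.GetShape(x);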

tensorflow/compiler/xla/client/xla_client/xla_builder.cc
tensorflow/compiler/xla/client/xla_client/xla_builder.h
tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc

tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 2c6b6c6..ae50631 100644
@@ -57,16 +57,6 @@ bool CanBeRoot(HloOpcode opcode) {
   }
 }
 
-StatusOr<std::vector<Shape>> GetOperandShapes(
-    tensorflow::gtl::ArraySlice<XlaOp> operands) {
-  std::vector<Shape> operand_shapes;
-  for (const XlaOp& operand : operands) {
-    TF_ASSIGN_OR_RETURN(const Shape& shape, operand.GetShape());
-    operand_shapes.push_back(shape);
-  }
-  return operand_shapes;
-}
-
 }  // namespace
 
 StatusOr<Shape> XlaBuilder::GetShape(const XlaOp& op) const {
@@ -76,12 +66,14 @@ StatusOr<Shape> XlaBuilder::GetShape(const XlaOp& op) const {
   return instr->shape();
 }
 
-StatusOr<Shape> XlaOp::GetShape() const {
-  if (builder_ == nullptr) {
-    return InvalidArgument(
-        "cannot GetShape for an invalid XlaOp with handle %lld", handle());
+StatusOr<std::vector<Shape>> XlaBuilder::GetOperandShapes(
+    tensorflow::gtl::ArraySlice<XlaOp> operands) const {
+  std::vector<Shape> operand_shapes;
+  for (const XlaOp& operand : operands) {
+    TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
+    operand_shapes.push_back(shape);
   }
-  return builder_->GetShape(*this);
+  return operand_shapes;
 }
 
 XlaBuilder::XlaBuilder(const string& computation_name)
@@ -286,7 +278,7 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
                                                  const XlaOp& operand) {
   TF_RETURN_IF_ERROR(first_error_);
 
-  TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape());
+  TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
 
   CHECK(ShapeUtil::IsScalar(operand_shape) ||
         ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape));
@@ -325,7 +317,7 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
 XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) {
   return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
-    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape());
+    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
     TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
                         ShapeInference::InferUnaryOpShape(unop, operand_shape));
     return AddInstruction(std::move(instr), unop, {operand});
@@ -337,8 +329,8 @@ XlaOp XlaBuilder::BinaryOp(
     tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
   return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
-    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape());
-    TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape());
+    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
+    TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
     TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
                         ShapeInference::InferBinaryOpShape(
                             binop, lhs_shape, rhs_shape, broadcast_dimensions));
@@ -374,12 +366,12 @@ XlaOp XlaBuilder::BinaryOp(
       updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs;
     }
 
-    TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, updated_lhs.GetShape());
+    TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, GetShape(updated_lhs));
     if (!ShapeUtil::SameDimensions(instr.shape(), updated_lhs_shape)) {
       TF_ASSIGN_OR_RETURN(updated_lhs,
                           AddBroadcastSequence(instr.shape(), updated_lhs));
     }
-    TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, updated_rhs.GetShape());
+    TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, GetShape(updated_rhs));
     if (!ShapeUtil::SameDimensions(instr.shape(), updated_rhs_shape)) {
       TF_ASSIGN_OR_RETURN(updated_rhs,
                           AddBroadcastSequence(instr.shape(), updated_rhs));
@@ -393,9 +385,9 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
                             const XlaOp& ehs) {
   return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
-    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape());
-    TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape());
-    TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, ehs.GetShape());
+    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
+    TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
+    TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, GetShape(ehs));
     TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
                         ShapeInference::InferTernaryOpShape(
                             triop, lhs_shape, rhs_shape, ehs_shape));
@@ -485,7 +477,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
 XlaOp XlaBuilder::Broadcast(
     const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
   return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape());
+    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
     TF_ASSIGN_OR_RETURN(
         const Shape& shape,
         ShapeInference::InferBroadcastShape(operand_shape, broadcast_sizes));
@@ -633,7 +625,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
                           tensorflow::gtl::ArraySlice<int64> dimensions,
                           tensorflow::gtl::ArraySlice<int64> new_sizes) {
   return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape());
+    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
     TF_ASSIGN_OR_RETURN(const Shape& shape,
                         ShapeInference::InferReshapeShape(
                             operand_shape, dimensions, new_sizes));
@@ -647,7 +639,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
 XlaOp XlaBuilder::Reshape(const XlaOp& operand,
                           tensorflow::gtl::ArraySlice<int64> new_sizes) {
   return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    TF_ASSIGN_OR_RETURN(auto shape, operand.GetShape());
+    TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand));
     std::vector<int64> dimensions(shape.dimensions_size());
     std::iota(dimensions.begin(), dimensions.end(), 0);
     return Reshape(operand, dimensions, new_sizes);
@@ -1002,7 +994,7 @@ XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type,
                       const tensorflow::gtl::ArraySlice<int64> fft_length) {
   return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
-    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape());
+    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
     TF_ASSIGN_OR_RETURN(
         *instr.mutable_shape(),
         ShapeInference::InferFftShape(operand_shape, fft_type, fft_length));
@@ -1233,7 +1225,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand,
                             tensorflow::gtl::ArraySlice<int64> permutation) {
   return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
-    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape());
+    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
     TF_ASSIGN_OR_RETURN(
         *instr.mutable_shape(),
         ShapeInference::InferTransposeShape(operand_shape, permutation));
@@ -1956,11 +1948,18 @@ StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
     const XlaOp& op) const {
   TF_RETURN_IF_ERROR(first_error_);
 
+  if (op.builder_ == nullptr) {
+    return InvalidArgument(
+        "invalid XlaOp with handle %lld; the builder of this op is freed",
+        op.handle());
+  }
   if (op.builder_ != this) {
-    return InvalidArgument("invalid XlaOp with handle %lld", op.handle());
+    return InvalidArgument(
+        "XlaOp with handle %lld is built by builder '%s', but is trying to use "
+        "it in builder '%s'",
+        op.handle(), op.builder_->name().c_str(), this->name().c_str());
   }
 
-  TF_RET_CHECK(op.builder_ == this);
   if (op.handle() >= instructions_.size() || op.handle() < 0) {
     return InvalidArgument("no XlaOp value %lld", op.handle());
   }
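
For reference, a short sketch of what now triggers each diagnostic in
LookUpInstruction (patterned on the OperandFromWrongBuilder test updated
below; the variable names are illustrative):

    XlaBuilder b1("b1");
    XlaBuilder main_builder("main");
    XlaOp p0 = b1.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0");

    XlaOp uninitialized;                   // default-constructed: builder_ == nullptr
    main_builder.GetShape(uninitialized);  // -> "invalid XlaOp ... has been freed"
    main_builder.GetShape(p0);             // -> "was built by builder 'b1', but is
                                           //     being used by builder 'main'"
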
tensorflow/compiler/xla/client/xla_client/xla_builder.h
index e580703..d802e43 100644
@@ -55,8 +55,6 @@ class XlaOp {
   XlaOp() : handle_(0), builder_(nullptr) {}
   ~XlaOp() {}
 
-  StatusOr<Shape> GetShape() const;
-
   const XlaBuilder* builder() const { return builder_; }
 
   bool operator==(const XlaOp& rhs) const {
@@ -853,6 +851,10 @@ class XlaBuilder {
   // computation and fills the root_id in the pointer.
   StatusOr<ProgramShape> GetProgramShape(int64* root_id) const;
 
+  // Returns shapes for the operands.
+  StatusOr<std::vector<Shape>> GetOperandShapes(
+      tensorflow::gtl::ArraySlice<XlaOp> operands) const;
+
   // A visitor which checks whether an operation is a compile-time constant,
   // meaning that it doesn't depend on any parameters, or on any stateful
   // operation such as `RngNormal` or `Infeed`. The visitor walks the
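
As a usage sketch, an n-ary op builder inside XlaBuilder could resolve all
operand shapes in one call before shape inference; NaryOp and the elided
inference step here are hypothetical, not part of this change:

    XlaOp XlaBuilder::NaryOp(HloOpcode opcode,
                             tensorflow::gtl::ArraySlice<XlaOp> operands) {
      return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
        HloInstructionProto instr;
        // Resolves each operand's shape via the builder, which also verifies
        // that every operand belongs to this builder.
        TF_ASSIGN_OR_RETURN(std::vector<Shape> operand_shapes,
                            GetOperandShapes(operands));
        // ... infer the result shape from operand_shapes into instr ...
        return AddInstruction(std::move(instr), opcode, operands);
      });
    }
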
tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc
index ce98456..2df3ea3 100644
@@ -76,7 +76,7 @@ TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) {
   auto y = b.Parameter(1, y_shape, "y");
   auto add = b.Add(x, y, /*broadcast_dimensions=*/{0, 1});
 
-  TF_ASSERT_OK_AND_ASSIGN(auto add_shape, add.GetShape());
+  TF_ASSERT_OK_AND_ASSIGN(auto add_shape, b.GetShape(add));
   EXPECT_TRUE(ShapeUtil::Equal(add_shape, x_shape));
 
   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
@@ -188,8 +188,10 @@ TEST_F(XlaBuilderTest, OperandFromWrongBuilder) {
   builder.Add(p0, p0);
   auto statusor = builder.Build();
   ASSERT_FALSE(statusor.ok());
-  EXPECT_THAT(statusor.status().error_message(),
-              HasSubstr("Do not add XlaOp from builder b1 to builder main"));
+  EXPECT_THAT(
+      statusor.status().error_message(),
+      HasSubstr(
+          "built by builder 'b1', but is trying to use it in builder 'main'"));
 }
 
 TEST_F(XlaBuilderTest, ReshapeDefaultOrder) {