}
}
-StatusOr<std::vector<Shape>> GetOperandShapes(
- tensorflow::gtl::ArraySlice<XlaOp> operands) {
- std::vector<Shape> operand_shapes;
- for (const XlaOp& operand : operands) {
- TF_ASSIGN_OR_RETURN(const Shape& shape, operand.GetShape());
- operand_shapes.push_back(shape);
- }
- return operand_shapes;
-}
-
} // namespace
StatusOr<Shape> XlaBuilder::GetShape(const XlaOp& op) const {
return instr->shape();
}
-StatusOr<Shape> XlaOp::GetShape() const {
- if (builder_ == nullptr) {
- return InvalidArgument(
- "cannot GetShape for an invalid XlaOp with handle %lld", handle());
+StatusOr<std::vector<Shape>> XlaBuilder::GetOperandShapes(
+ tensorflow::gtl::ArraySlice<XlaOp> operands) const {
+ std::vector<Shape> operand_shapes;
+ for (const XlaOp& operand : operands) {
+ TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
+ operand_shapes.push_back(shape);
}
- return builder_->GetShape(*this);
+ return operand_shapes;
}
XlaBuilder::XlaBuilder(const string& computation_name)
const XlaOp& operand) {
TF_RETURN_IF_ERROR(first_error_);
- TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape());
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
CHECK(ShapeUtil::IsScalar(operand_shape) ||
ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape));
XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) {
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
- TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape());
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferUnaryOpShape(unop, operand_shape));
return AddInstruction(std::move(instr), unop, {operand});
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
- TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape());
- TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape());
+ TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
+ TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferBinaryOpShape(
binop, lhs_shape, rhs_shape, broadcast_dimensions));
updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs;
}
- TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, updated_lhs.GetShape());
+ TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, GetShape(updated_lhs));
if (!ShapeUtil::SameDimensions(instr.shape(), updated_lhs_shape)) {
TF_ASSIGN_OR_RETURN(updated_lhs,
AddBroadcastSequence(instr.shape(), updated_lhs));
}
- TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, updated_rhs.GetShape());
+ TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, GetShape(updated_rhs));
if (!ShapeUtil::SameDimensions(instr.shape(), updated_rhs_shape)) {
TF_ASSIGN_OR_RETURN(updated_rhs,
AddBroadcastSequence(instr.shape(), updated_rhs));
const XlaOp& ehs) {
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
- TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape());
- TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape());
- TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, ehs.GetShape());
+ TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
+ TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
+ TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, GetShape(ehs));
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferTernaryOpShape(
triop, lhs_shape, rhs_shape, ehs_shape));
XlaOp XlaBuilder::Broadcast(
const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
- TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape());
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
const Shape& shape,
ShapeInference::InferBroadcastShape(operand_shape, broadcast_sizes));
tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> new_sizes) {
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
- TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape());
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(const Shape& shape,
ShapeInference::InferReshapeShape(
operand_shape, dimensions, new_sizes));
XlaOp XlaBuilder::Reshape(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> new_sizes) {
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
- TF_ASSIGN_OR_RETURN(auto shape, operand.GetShape());
+ TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand));
std::vector<int64> dimensions(shape.dimensions_size());
std::iota(dimensions.begin(), dimensions.end(), 0);
return Reshape(operand, dimensions, new_sizes);
const tensorflow::gtl::ArraySlice<int64> fft_length) {
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
- TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape());
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferFftShape(operand_shape, fft_type, fft_length));
tensorflow::gtl::ArraySlice<int64> permutation) {
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
- TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape());
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferTransposeShape(operand_shape, permutation));
const XlaOp& op) const {
TF_RETURN_IF_ERROR(first_error_);
+ if (op.builder_ == nullptr) {
+ return InvalidArgument(
+ "invalid XlaOp with handle %lld; the builder of this op has been freed",
+ op.handle());
+ }
if (op.builder_ != this) {
- return InvalidArgument("invalid XlaOp with handle %lld", op.handle());
+ return InvalidArgument(
+ "XlaOp with handle %lld was built by builder '%s', but is being used "
+ "in builder '%s'",
+ op.handle(), op.builder_->name().c_str(), this->name().c_str());
}
- TF_RET_CHECK(op.builder_ == this);
if (op.handle() >= instructions_.size() || op.handle() < 0) {
return InvalidArgument("no XlaOp value %lld", op.handle());
}