}
}
-void SetOpcode(HloInstructionProto* instr, HloOpcode opcode) {
- instr->set_opcode(HloOpcodeString(opcode));
-}
-
} // namespace
-StatusOr<std::unique_ptr<Shape>> XlaBuilder::GetShape(const XlaOp& op) const {
+StatusOr<Shape> XlaBuilder::GetShape(const XlaOp& op) const {
TF_ASSIGN_OR_RETURN(auto instr, LookUpInstruction(op));
- return MakeUnique<Shape>(instr->shape());
+ return instr->shape();
}
StatusOr<Shape> XlaOp::GetShape() const {
TF_RET_CHECK(builder_ != nullptr);
- TF_ASSIGN_OR_RETURN(auto shape, builder_->GetShape(*this));
- return *shape;
+ return builder_->GetShape(*this);
}
XlaBuilder::XlaBuilder(const string& computation_name)
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
auto op = [&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
- SetOpcode(&instr, HloOpcode::kAdd);
- TF_ASSIGN_OR_RETURN(const auto* lhs_instr, LookUpInstruction(lhs));
- TF_ASSIGN_OR_RETURN(const auto* rhs_instr, LookUpInstruction(rhs));
- TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
- ShapeInference::InferBinaryOpShape(
- HloOpcode::kAdd, lhs_instr->shape(),
- rhs_instr->shape(), broadcast_dimensions));
- instr.add_operand_ids(lhs_instr->id());
- instr.add_operand_ids(rhs_instr->id());
- return AddInstruction(std::move(instr));
+ TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape());
+ TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape());
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, lhs_shape,
+ rhs_shape, broadcast_dimensions));
+ return AddInstruction(std::move(instr), HloOpcode::kAdd, {lhs, rhs});
};
return NoteErrorOrReturn(op());
}
XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) {
HloInstructionProto instr;
- SetOpcode(&instr, HloOpcode::kConstant);
*instr.mutable_shape() = literal.shape();
*instr.mutable_literal() = literal.ToProto();
- return AddInstruction(std::move(instr));
+ return AddInstruction(std::move(instr), HloOpcode::kConstant);
}
XlaOp XlaBuilder::Call(const XlaComputation& computation,
tensorflow::gtl::ArraySlice<XlaOp> operands) {
auto op = [&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
- SetOpcode(&instr, HloOpcode::kCall);
- std::vector<const Shape*> operand_shapes;
+ std::vector<const Shape*> operand_shape_ptrs;
+ std::vector<Shape> operand_shapes;
for (const auto& operand : operands) {
- TF_ASSIGN_OR_RETURN(const auto* input, LookUpInstruction(operand));
- operand_shapes.push_back(&input->shape());
+ TF_ASSIGN_OR_RETURN(const Shape& shape, operand.GetShape());
+ operand_shapes.push_back(shape);
}
+ c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
+ [](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferCallShape(
- operand_shapes,
+ operand_shape_ptrs,
/*to_apply=*/computation.GetProgramShape()));
- // Add input operands.
- for (const auto& operand : operands) {
- TF_ASSIGN_OR_RETURN(auto operand_instr, LookUpInstruction(operand));
- instr.add_operand_ids(operand_instr->id());
- }
-
// Add called computation.
instr.add_called_computation_ids(
computation.proto().entry_computation_id());
embedded_.insert({e.id(), e});
}
- return AddInstruction(std::move(instr));
+ return AddInstruction(std::move(instr), HloOpcode::kCall, operands);
};
return NoteErrorOrReturn(op());
}
const string& name) {
auto op = [&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
- SetOpcode(&instr, HloOpcode::kParameter);
if (parameter_numbers_.find(parameter_number) != parameter_numbers_.end()) {
return InvalidArgument("parameter %lld already registered",
parameter_number);
instr.set_parameter_number(parameter_number);
instr.set_name(name);
*instr.mutable_shape() = shape;
- return AddInstruction(std::move(instr));
+ return AddInstruction(std::move(instr), HloOpcode::kParameter);
};
return NoteErrorOrReturn(op());
}
-XlaOp XlaBuilder::AddInstruction(HloInstructionProto&& instr) {
+XlaOp XlaBuilder::AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
+ tensorflow::gtl::ArraySlice<XlaOp> operands) {
const int64 handle = instructions_.size();
instr.set_id(handle);
+ instr.set_opcode(HloOpcodeString(opcode));
if (instr.name().empty()) {
instr.set_name(StrCat(instr.opcode(), ".", handle));
} else {
// Append the handle to make sure the name is unique.
instr.set_name(StrCat(instr.name(), ".", handle));
}
+ for (const auto& operand : operands) {
+ instr.add_operand_ids(operand.handle());
+ }
instructions_.push_back(instr);
XlaOp op(handle, this);