}
void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
- UnimplementedOp();
+ NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = ShapeUtil::MakeNil();
+ *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto();
+ return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
+ });
}
XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true,
}
}
+ if (instruction->opcode() == HloOpcode::kTrace) {
+ TF_RET_CHECK(instruction->operands().size() == 1)
+ << "Trace instruction should have 1 operand but sees "
+ << instruction->operands().size();
+ instruction->mutable_operand(0)->set_tracing(instruction.get());
+ }
+
TF_RET_CHECK(!proto.name().empty());
instruction->name_ = proto.name();
WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()));
instruction->operands_.push_back(operand);
instruction->literal_ = Literal::CreateR1U8(tag);
+ operand->set_tracing(instruction.get());
return instruction;
}
HloInstruction* operand = lookup_instruction(trace_request.operand());
hlo_instruction = add_instruction(
HloInstruction::CreateTrace(trace_request.tag(), operand));
- operand->set_tracing(hlo_instruction);
break;
}