From: A. Unique TensorFlower Date: Mon, 2 Apr 2018 22:27:24 +0000 (-0700) Subject: [XLA] Set trace for the operand of a trace instruction when creating the instruction... X-Git-Tag: tflite-v0.1.7~39^2^2~120 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=11c0faed23ec32c3f1532f5154dd3c7bb38847d5;p=platform%2Fupstream%2Ftensorflow.git [XLA] Set trace for the operand of a trace instruction when creating the instruction directly or creating from proto. Also implement XlaBuidler::Trace. PiperOrigin-RevId: 191357376 --- diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 04091ec..ec23621 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -543,7 +543,12 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, } void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { - UnimplementedOp(); + NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeNil(); + *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto(); + return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand}); + }); } XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a2a2c1e..fcf9ebf 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -98,6 +98,13 @@ StatusOr> HloInstruction::CreateFromProto( } } + if (instruction->opcode() == HloOpcode::kTrace) { + TF_RET_CHECK(instruction->operands().size() == 1) + << "Trace instruction should have 1 operand but sees " + << instruction->operands().size(); + instruction->mutable_operand(0)->set_tracing(instruction.get()); + } + TF_RET_CHECK(!proto.name().empty()); instruction->name_ = proto.name(); @@ -170,6 +177,7 @@ StatusOr> HloInstruction::CreateFromProto( WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil())); instruction->operands_.push_back(operand); instruction->literal_ = Literal::CreateR1U8(tag); + operand->set_tracing(instruction.get()); return instruction; } diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index fcdb2e0..532f7fd 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -3491,7 +3491,6 @@ void ComputationLowerer::Visit( HloInstruction* operand = lookup_instruction(trace_request.operand()); hlo_instruction = add_instruction( HloInstruction::CreateTrace(trace_request.tag(), operand)); - operand->set_tracing(hlo_instruction); break; }