[XLA] Set trace for the operand of a trace instruction when creating the instruction...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 2 Apr 2018 22:27:24 +0000 (15:27 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 2 Apr 2018 22:29:48 +0000 (15:29 -0700)
PiperOrigin-RevId: 191357376

tensorflow/compiler/xla/client/xla_client/xla_builder.cc
tensorflow/compiler/xla/service/hlo_instruction.cc
tensorflow/compiler/xla/service/user_computation.cc

index 04091ec..ec23621 100644 (file)
@@ -543,7 +543,12 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand,
 }
 
 void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
-  UnimplementedOp();
+  NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+    *instr.mutable_shape() = ShapeUtil::MakeNil();
+    *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto();
+    return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
+  });
 }
 
 XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true,
index a2a2c1e..fcf9ebf 100644 (file)
@@ -98,6 +98,13 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
     }
   }
 
+  if (instruction->opcode() == HloOpcode::kTrace) {
+    TF_RET_CHECK(instruction->operands().size() == 1)
+        << "Trace instruction should have 1 operand but sees "
+        << instruction->operands().size();
+    instruction->mutable_operand(0)->set_tracing(instruction.get());
+  }
+
   TF_RET_CHECK(!proto.name().empty());
   instruction->name_ = proto.name();
 
@@ -170,6 +177,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
       WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()));
   instruction->operands_.push_back(operand);
   instruction->literal_ = Literal::CreateR1U8(tag);
+  operand->set_tracing(instruction.get());
   return instruction;
 }
 
index fcdb2e0..532f7fd 100644 (file)
@@ -3491,7 +3491,6 @@ void ComputationLowerer::Visit(
       HloInstruction* operand = lookup_instruction(trace_request.operand());
       hlo_instruction = add_instruction(
           HloInstruction::CreateTrace(trace_request.tag(), operand));
-      operand->set_tracing(hlo_instruction);
       break;
     }