From 711e2f503039bd8a277928ef8a2b3740ae2bfa4b Mon Sep 17 00:00:00 2001 From: Mark Heffernan Date: Fri, 6 Apr 2018 10:28:18 -0700 Subject: [PATCH] Change GetInstructionCallContext to take an opcode instead of an HloInstruction. This enables use of the function without an actual instruction (eg, if you just have an HloProto). PiperOrigin-RevId: 191905914 --- tensorflow/compiler/xla/service/call_graph.cc | 6 +++--- tensorflow/compiler/xla/service/call_graph.h | 2 +- tensorflow/compiler/xla/service/flatten_call_graph.cc | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 13eb02c..a8053d1 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -51,8 +51,8 @@ std::ostream& operator<<(std::ostream& out, const CallContext& context) { return out; } -CallContext GetInstructionCallContext(const HloInstruction* instruction) { - switch (instruction->opcode()) { +CallContext GetInstructionCallContext(HloOpcode opcode) { + switch (opcode) { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kWhile: @@ -101,7 +101,7 @@ void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) { void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) { CHECK_EQ(instruction->parent(), computation()); - const CallContext context = GetInstructionCallContext(instruction); + const CallContext context = GetInstructionCallContext(instruction->opcode()); if (!instruction->called_computations().empty()) { CHECK(context == CallContext::kSequential || context == CallContext::kParallel); diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index 688c408..97d3811 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -53,7 +53,7 @@ enum class CallContext { string CallContextToString(CallContext context); std::ostream& operator<<(std::ostream& out, const CallContext& context); -CallContext GetInstructionCallContext(const HloInstruction* instruction); +CallContext GetInstructionCallContext(HloOpcode opcode); // Represents an HLO instruction which calls one or more computations. class CallSite { diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.cc b/tensorflow/compiler/xla/service/flatten_call_graph.cc index 2b6caa1..85409b3 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph.cc @@ -93,7 +93,7 @@ Status FlattenNode(const CallGraphNode& node) { auto current = worklist.back(); worklist.pop_back(); for (auto* instruction : current->instructions()) { - if (GetInstructionCallContext(instruction) != + if (GetInstructionCallContext(instruction->opcode()) != CallContext::kSequential) { continue; } -- 2.7.4