Change GetInstructionCallContext to take an opcode instead of an
authorMark Heffernan <meheff@google.com>
Fri, 6 Apr 2018 17:28:18 +0000 (10:28 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 17:34:28 +0000 (10:34 -0700)
HloInstruction.
This enables use of the function without an actual instruction (eg, if you just
have an HloProto).

PiperOrigin-RevId: 191905914

tensorflow/compiler/xla/service/call_graph.cc
tensorflow/compiler/xla/service/call_graph.h
tensorflow/compiler/xla/service/flatten_call_graph.cc

index 13eb02c..a8053d1 100644 (file)
@@ -51,8 +51,8 @@ std::ostream& operator<<(std::ostream& out, const CallContext& context) {
   return out;
 }
 
-CallContext GetInstructionCallContext(const HloInstruction* instruction) {
-  switch (instruction->opcode()) {
+CallContext GetInstructionCallContext(HloOpcode opcode) {
+  switch (opcode) {
     case HloOpcode::kCall:
     case HloOpcode::kConditional:
     case HloOpcode::kWhile:
@@ -101,7 +101,7 @@ void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) {
 
 void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
   CHECK_EQ(instruction->parent(), computation());
-  const CallContext context = GetInstructionCallContext(instruction);
+  const CallContext context = GetInstructionCallContext(instruction->opcode());
   if (!instruction->called_computations().empty()) {
     CHECK(context == CallContext::kSequential ||
           context == CallContext::kParallel);
index 688c408..97d3811 100644 (file)
@@ -53,7 +53,7 @@ enum class CallContext {
 string CallContextToString(CallContext context);
 std::ostream& operator<<(std::ostream& out, const CallContext& context);
 
-CallContext GetInstructionCallContext(const HloInstruction* instruction);
+CallContext GetInstructionCallContext(HloOpcode opcode);
 
 // Represents an HLO instruction which calls one or more computations.
 class CallSite {
index 2b6caa1..85409b3 100644 (file)
@@ -93,7 +93,7 @@ Status FlattenNode(const CallGraphNode& node) {
       auto current = worklist.back();
       worklist.pop_back();
       for (auto* instruction : current->instructions()) {
-        if (GetInstructionCallContext(instruction) !=
+        if (GetInstructionCallContext(instruction->opcode()) !=
             CallContext::kSequential) {
           continue;
         }