return out;
}
-CallContext GetInstructionCallContext(const HloInstruction* instruction) {
- switch (instruction->opcode()) {
+CallContext GetInstructionCallContext(HloOpcode opcode) {
+ switch (opcode) {
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kWhile:
void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
CHECK_EQ(instruction->parent(), computation());
- const CallContext context = GetInstructionCallContext(instruction);
+ const CallContext context = GetInstructionCallContext(instruction->opcode());
if (!instruction->called_computations().empty()) {
CHECK(context == CallContext::kSequential ||
context == CallContext::kParallel);
string CallContextToString(CallContext context);
std::ostream& operator<<(std::ostream& out, const CallContext& context);
-CallContext GetInstructionCallContext(const HloInstruction* instruction);
+CallContext GetInstructionCallContext(HloOpcode opcode);
// Represents an HLO instruction which calls one or more computations.
class CallSite {
auto current = worklist.back();
worklist.pop_back();
for (auto* instruction : current->instructions()) {
- if (GetInstructionCallContext(instruction) !=
+ if (GetInstructionCallContext(instruction->opcode()) !=
CallContext::kSequential) {
continue;
}