// simply users-1 for each instruction. By subtracting 1, we're saying that
// instructions with no users or a single user don't count; instructions with
// lots of fan-out will be visited earlier.
//
// NOTE(review): this chunk appears to be merge-damaged — the leading '+'/'-'
// characters below look like leftover unified-diff markers, and the bare
// 'continue;' at the top of the loop has lost its enclosing 'if (...) {'
// line. Restore this region from upstream before relying on it; the
// comments below describe the intent that is visible from the surviving
// code.
//
// Running total of logical-buffer bytes defined by every instruction
// visited so far, used below to cap the per-instruction totals.
+ int64 cumulative_total_size = 0;
// extra_users[hlo]: users beyond the first, accumulated transitively over
// hlo and its (unique) operands — a fan-out measure for visit ordering.
tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users;
// total_sizes[hlo]: bytes of logical buffers defined by hlo plus the
// (possibly over-counted — see the std::min cap below) totals of its
// unique operands.
tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes;
for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) {
// NOTE(review): dangling 'continue;' — the guarding condition line is
// missing from this chunk; presumably it skipped instructions handled
// elsewhere. TODO: confirm against upstream.
continue;
}
// Count only fan-out beyond a single user; 0 users and 1 user both score 0.
extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1;
- total_sizes[hlo] = SumLogicalBufferSizes(
+ int64 logical_buffer_size = SumLogicalBufferSizes(
points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);
+ total_sizes[hlo] = logical_buffer_size;
+ cumulative_total_size += logical_buffer_size;
// De-duplicate operands so an instruction that uses the same operand
// twice does not double-count that operand's contribution.
tensorflow::gtl::FlatSet<const HloInstruction*> unique_operands(
hlo->operands().begin(), hlo->operands().end());
for (const HloInstruction* operand : unique_operands) {
extra_users[hlo] += extra_users[operand];
total_sizes[hlo] += total_sizes[operand];
}
// Summing operand totals over-counts shared subgraphs (a diamond-shaped
// dependency counts its common ancestor once per path), which can grow
// without bound. Cap at the true cumulative total seen so far —
// presumably to keep the heuristic bounded / avoid int64 overflow;
// confirm rationale against upstream.
total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size);
}
// Post-order over the computation visits every instruction exactly once,
// so both maps must end up with one entry per instruction.
CHECK_EQ(extra_users.size(), computation.instruction_count());
CHECK_EQ(total_sizes.size(), computation.instruction_count());