Fix buffer assignment for conditional instruction.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 27 Feb 2018 01:27:20 +0000 (17:27 -0800)
committerGunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
PiperOrigin-RevId: 187107432

tensorflow/compiler/xla/service/buffer_assignment.cc
tensorflow/compiler/xla/service/copy_insertion.cc

index b1e693d..d44d3d7 100644 (file)
@@ -48,6 +48,183 @@ using ::tensorflow::strings::HumanReadableNumBytes;
 using ::tensorflow::strings::Printf;
 using ::tensorflow::strings::StrAppend;
 
+namespace {
+
+template <typename T>
+string ColocatedBufferSetsToString(const T& container, const char* title) {
+  string result;
+  StrAppend(&result, title, "\n");
+  for (const auto& it : container) {
+    StrAppend(&result, "\t", it->ToString(), "\n");
+  }
+  return result;
+}
+
+// Walk the call graph of the HLO module and place each computation into either
+// thread_local_computations or global_computations depending upon whether the
+// computation requires thread-local allocations or global allocations. The
+// elements in thread_local_computations and global_computations are in post
+// order (if computation A has an instruction which calls computation B, then A
+// will appear after B in the vector).
+Status GatherComputationsByAllocationType(
+    const HloModule* module,
+    std::vector<const HloComputation*>* thread_local_computations,
+    std::vector<const HloComputation*>* global_computations) {
+  // Create a worklist of computations paired with whether the allocation must
+  // be thread-local.
+  std::deque<std::pair<const HloComputation*, bool>> worklist;
+  worklist.push_back(std::make_pair(module->entry_computation(),
+                                    /*is_thread_local*/ false));
+
+  // Sets for quickly checking membership. Computations are returned in vectors
+  // for stable iteration.
+  FlatSet<const HloComputation*> thread_local_set;
+  FlatSet<const HloComputation*> global_set;
+
+  while (!worklist.empty()) {
+    auto worklist_front = worklist.front();
+    worklist.pop_front();
+    const HloComputation* computation = worklist_front.first;
+    bool is_thread_local = worklist_front.second;
+    bool in_thread_local_set = thread_local_set.count(computation) > 0;
+    bool in_global_set = global_set.count(computation) > 0;
+
+    // If the computation has already been added to the respective set, then
+    // nothing to do.
+    if ((is_thread_local && in_thread_local_set) ||
+        (!is_thread_local && in_global_set)) {
+      continue;
+    }
+
+    // If the computation has already been added to the other set this is an
+    // error condition because the global call to the computation (eg,
+    // while/call) may return a reference to one of the thread-local buffers to
+    // the calling computation which will become a dangling reference when the
+    // thread-local is deallocated with the call return.
+    if ((is_thread_local && in_global_set) ||
+        (!is_thread_local && in_thread_local_set)) {
+      return InvalidArgument(
+          "computation %s has conflicting allocation requirements (global "
+          "and thread-local)",
+          computation->name().c_str());
+    }
+
+    if (is_thread_local) {
+      thread_local_set.insert(computation);
+    } else {
+      global_set.insert(computation);
+    }
+
+    for (auto* instruction : computation->instructions()) {
+      for (HloComputation* subcomputation :
+           instruction->called_computations()) {
+        switch (instruction->opcode()) {
+          case HloOpcode::kCall:
+          case HloOpcode::kConditional:
+          case HloOpcode::kWhile:
+            // Call and while must be called from a computation with global
+            // allocations as they may return references to buffers inside the
+            // called computation which cannot be thread-local.
+            if (is_thread_local) {
+              return InvalidArgument(
+                  "computation %s cannot contain call/while op because it "
+                  "requires thread-local buffer allocations",
+                  computation->name().c_str());
+            }
+            worklist.push_back(std::make_pair(subcomputation,
+                                              false));  // Not thread local.
+            break;
+          case HloOpcode::kMap:
+          case HloOpcode::kReduce:
+          case HloOpcode::kReduceWindow:
+          case HloOpcode::kSelectAndScatter:
+          case HloOpcode::kFusion:
+            // Map/reduce etc computations are always thread-local.
+            worklist.push_back(std::make_pair(subcomputation,
+                                              true));  // Thread local.
+            break;
+          default:
+            return InternalError(
+                "Unexpected calling opcode: %s",
+                HloOpcodeString(instruction->opcode()).c_str());
+        }
+      }
+    }
+  }
+
+  // Add the computations to the vectors in post order.
+  for (auto* computation : module->MakeComputationPostOrder()) {
+    if (thread_local_set.count(computation) > 0) {
+      thread_local_computations->push_back(computation);
+    } else if (global_set.count(computation) > 0) {
+      global_computations->push_back(computation);
+    }
+    // If the computation is not reachable from the entry computation, then it
+    // will not appear in either thread_local_set or global_set. We don't bother
+    // assigning buffers for these.
+  }
+  return Status::OK();
+}
+
+// Checks that points-to set of 'instruction' is unambiguous and distinct
+// (ensured by CopyInsertion), then adds the buffer from the points-to set at
+// 'index' to 'colocated_set'.
+const LogicalBuffer* AddBufferToColocatedSet(
+    const HloInstruction* instruction, const ShapeIndex& index,
+    const TuplePointsToAnalysis& points_to_analysis,
+    std::vector<const LogicalBuffer*>* colocated_set) {
+  // CopyInsertion ensures root points-to set is unambiguous and distinct.
+  const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
+  DCHECK(!points_to.IsAmbiguous());
+  colocated_set->push_back(points_to.element(index)[0]);
+  return colocated_set->back();
+}
+
+// Given the interference map of a graph (the list of interfering node indices
+// for each node), perform graph coloring such that interfering nodes are
+// assigned to different colors. Returns the assigned color of the nodes, where
+// the colors are represented as integer values [0, color_count).
+std::vector<int64> ColorInterferenceGraph(
+    const std::vector<std::vector<int64>>& interference_map) {
+  const int64 node_count = interference_map.size();
+
+  // Sort the nodes such that we assign nodes with more interference first. This
+  // relies on the common heuristic of assigning the most constrained node
+  // first, but it would be good to investigate other ordering heuristics too.
+  std::vector<int64> nodes(node_count);
+  std::iota(nodes.begin(), nodes.end(), 0);
+  std::sort(nodes.begin(), nodes.end(),
+            [&interference_map](const int64 i, const int64 j) {
+              return interference_map[i].size() > interference_map[j].size();
+            });
+
+  const int64 kColorUnassigned = -1;
+  std::vector<int64> assigned_colors(node_count, kColorUnassigned);
+  for (int64 node : nodes) {
+    // Mark the colors that are already assigned to the neighbors.
+    std::vector<bool> available_colors(node_count, true);
+    for (int64 neighbor : interference_map[node]) {
+      int64 color = assigned_colors[neighbor];
+      if (color != kColorUnassigned) {
+        available_colors[color] = false;
+      }
+    }
+
+    // Find the color that is not yet assigned to the neighbors.
+    int64 color = kColorUnassigned;
+    for (color = 0; color < available_colors.size(); ++color) {
+      if (available_colors[color]) {
+        break;
+      }
+    }
+    CHECK_NE(color, kColorUnassigned);
+    assigned_colors[node] = color;
+  }
+  return assigned_colors;
+}
+
+}  // namespace
+
 size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const {
   uint64 h = std::hash<int64>()(s.index());
   h = tensorflow::Hash64Combine(h, std::hash<int64>()(s.offset()));
@@ -523,116 +700,6 @@ BufferAssignmentProto BufferAssignment::ToProto() const {
   return proto;
 }
 
-namespace {
-
-// Walk the call graph of the HLO module and place each computation into either
-// thread_local_computations or global_computations depending upon whether the
-// computation requires thread-local allocations or global allocations. The
-// elements in thread_local_computations and global_computations are in post
-// order (if computation A has an instruction which calls computation B, then A
-// will appear after B in the vector).
-Status GatherComputationsByAllocationType(
-    const HloModule* module,
-    std::vector<const HloComputation*>* thread_local_computations,
-    std::vector<const HloComputation*>* global_computations) {
-  // Create a worklist of computations paired with whether the allocation must
-  // be thread-local.
-  std::deque<std::pair<const HloComputation*, bool>> worklist;
-  worklist.push_back(std::make_pair(module->entry_computation(),
-                                    /*is_thread_local*/ false));
-
-  // Sets for quickly checking membership. Computations are returned in vectors
-  // for stable iteration.
-  FlatSet<const HloComputation*> thread_local_set;
-  FlatSet<const HloComputation*> global_set;
-
-  while (!worklist.empty()) {
-    auto worklist_front = worklist.front();
-    worklist.pop_front();
-    const HloComputation* computation = worklist_front.first;
-    bool is_thread_local = worklist_front.second;
-    bool in_thread_local_set = thread_local_set.count(computation) > 0;
-    bool in_global_set = global_set.count(computation) > 0;
-
-    // If the computation has already been added to the respective set, then
-    // nothing to do.
-    if ((is_thread_local && in_thread_local_set) ||
-        (!is_thread_local && in_global_set)) {
-      continue;
-    }
-
-    // If the computation has already been added to the other set this is an
-    // error condition because the global call to the computation (eg,
-    // while/call) may return a reference to one of the thread-local buffers to
-    // the calling computation which will become a dangling reference when the
-    // thread-local is deallocated with the call return.
-    if ((is_thread_local && in_global_set) ||
-        (!is_thread_local && in_thread_local_set)) {
-      return InvalidArgument(
-          "computation %s has conflicting allocation requirements (global "
-          "and thread-local)",
-          computation->name().c_str());
-    }
-
-    if (is_thread_local) {
-      thread_local_set.insert(computation);
-    } else {
-      global_set.insert(computation);
-    }
-
-    for (auto* instruction : computation->instructions()) {
-      for (HloComputation* subcomputation :
-           instruction->called_computations()) {
-        switch (instruction->opcode()) {
-          case HloOpcode::kCall:
-          case HloOpcode::kConditional:
-          case HloOpcode::kWhile:
-            // Call and while must be called from a computation with global
-            // allocations as they may return references to buffers inside the
-            // called computation which cannot be thread-local.
-            if (is_thread_local) {
-              return InvalidArgument(
-                  "computation %s cannot contain call/while op because it "
-                  "requires thread-local buffer allocations",
-                  computation->name().c_str());
-            }
-            worklist.push_back(std::make_pair(subcomputation,
-                                              false));  // Not thread local.
-            break;
-          case HloOpcode::kMap:
-          case HloOpcode::kReduce:
-          case HloOpcode::kReduceWindow:
-          case HloOpcode::kSelectAndScatter:
-          case HloOpcode::kFusion:
-            // Map/reduce etc computations are always thread-local.
-            worklist.push_back(std::make_pair(subcomputation,
-                                              true));  // Thread local.
-            break;
-          default:
-            return InternalError(
-                "Unexpected calling opcode: %s",
-                HloOpcodeString(instruction->opcode()).c_str());
-        }
-      }
-    }
-  }
-
-  // Add the computations to the vectors in post order.
-  for (auto* computation : module->MakeComputationPostOrder()) {
-    if (thread_local_set.count(computation) > 0) {
-      thread_local_computations->push_back(computation);
-    } else if (global_set.count(computation) > 0) {
-      global_computations->push_back(computation);
-    }
-    // If the computation is not reachable from the entry computation, then it
-    // will not appear in either thread_local_set or global_set. We don't bother
-    // assigning buffers for these.
-  }
-  return Status::OK();
-}
-
-}  // namespace
-
 /* static */
 StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run(
     const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
@@ -1085,7 +1152,8 @@ void BufferAssigner::AddSetToColocatedBufferSets(
   if (colocated_set.empty()) {
     return;
   }
-
+  VLOG(5) << ColocatedBufferSetsToString(colocated_set,
+                                         "Adding colocated buffer set");
   // Find existing sets that overlap with at least one buffer from the
   // colocated_set. The resulting 'overlap_set_indices' will have at most
   // colocated_buffer_sets->size() entries, and will be in increasing order.
@@ -1093,6 +1161,10 @@ void BufferAssigner::AddSetToColocatedBufferSets(
   for (size_t index = 0; index < colocated_buffer_sets->size(); ++index) {
     for (const LogicalBuffer* buffer : colocated_set) {
       if ((*colocated_buffer_sets)[index].count(buffer) > 0) {
+        VLOG(5) << "Found overlap with existing set on buffer "
+                << buffer->ToString() << "\n"
+                << ColocatedBufferSetsToString((*colocated_buffer_sets)[index],
+                                               "Overlapping set");
         overlap_set_indices.push_back(index);
         break;
       }
@@ -1104,6 +1176,7 @@ void BufferAssigner::AddSetToColocatedBufferSets(
     colocated_buffer_sets->emplace_back();
     colocated_buffer_sets->back().insert(colocated_set.begin(),
                                          colocated_set.end());
+    VLOG(5) << "No overlap found, new group created";
     return;
   }
 
@@ -1115,6 +1188,8 @@ void BufferAssigner::AddSetToColocatedBufferSets(
     first->insert(overlap_set.begin(), overlap_set.end());
   }
   first->insert(colocated_set.begin(), colocated_set.end());
+  VLOG(5) << ColocatedBufferSetsToString(
+      *first, "Result of the colocated buffer set merging");
 
   // Remove overlap sets that we just merged. The offset accounts for the fact
   // that as elements are erased, the indices need to be adjusted. Keep in mind
@@ -1125,67 +1200,6 @@ void BufferAssigner::AddSetToColocatedBufferSets(
   }
 }
 
-namespace {
-
-// Checks that points-to set of 'instruction' is unambiguous and distinct
-// (ensured by CopyInsertion), then adds the buffer from the points-to set at
-// 'index' to 'colocated_set'.
-const LogicalBuffer* AddBufferToColocatedSet(
-    const HloInstruction* instruction, const ShapeIndex& index,
-    const TuplePointsToAnalysis& points_to_analysis,
-    std::vector<const LogicalBuffer*>* colocated_set) {
-  // CopyInsertion ensures root points-to set is unambiguous and distinct.
-  const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
-  DCHECK(!points_to.IsAmbiguous());
-  colocated_set->push_back(points_to.element(index)[0]);
-  return colocated_set->back();
-}
-
-// Given the interference map of a graph (the list of interfering node indices
-// for each node), perform graph coloring such that interfering nodes are
-// assigned to different colors. Returns the assigned color of the nodes, where
-// the colors are represented as integer values [0, color_count).
-std::vector<int64> ColorInterferenceGraph(
-    const std::vector<std::vector<int64>>& interference_map) {
-  const int64 node_count = interference_map.size();
-
-  // Sort the nodes such that we assign nodes with more interference first. This
-  // relies on the common heuristic of assigning the most constrained node
-  // first, but it would be good to investigate other ordering heuristics too.
-  std::vector<int64> nodes(node_count);
-  std::iota(nodes.begin(), nodes.end(), 0);
-  std::sort(nodes.begin(), nodes.end(),
-            [&interference_map](const int64 i, const int64 j) {
-              return interference_map[i].size() > interference_map[j].size();
-            });
-
-  const int64 kColorUnassigned = -1;
-  std::vector<int64> assigned_colors(node_count, kColorUnassigned);
-  for (int64 node : nodes) {
-    // Mark the colors that are already assigned to the neighbors.
-    std::vector<bool> available_colors(node_count, true);
-    for (int64 neighbor : interference_map[node]) {
-      int64 color = assigned_colors[neighbor];
-      if (color != kColorUnassigned) {
-        available_colors[color] = false;
-      }
-    }
-
-    // Find the color that is not yet assigned to the neighbors.
-    int64 color = kColorUnassigned;
-    for (color = 0; color < available_colors.size(); ++color) {
-      if (available_colors[color]) {
-        break;
-      }
-    }
-    CHECK_NE(color, kColorUnassigned);
-    assigned_colors[node] = color;
-  }
-  return assigned_colors;
-}
-
-}  // namespace
-
 std::vector<BufferAssigner::ColocatedBufferSet>
 BufferAssigner::MergeColocatedBufferSets(
     const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
index cc19587..df73c28 100644 (file)
@@ -58,6 +58,45 @@ bool ValueIsReadOnly(const HloValue& value) {
   return IsConstantValue(value) || IsEntryParameterValue(value);
 }
 
+// Data structure describing the action which should be taken on parts of a
+// computation buffers, with respect to the adding of special case copies.
+struct SpecialCaseCopyPolicy {
+  // Insert a copy if the same buffer is found at multiple indices within the
+  // output tuple.
+  bool copy_root_replicated_buffers = false;
+  // If true, insert a copy if a buffer coming from a constant or a parameter
+  // is found wihtin the output tuple.
+  bool copy_parameters_and_constants = false;
+};
+
+SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node,
+                                               HloModule* module,
+                                               HloComputation* computation) {
+  SpecialCaseCopyPolicy policy;
+  if (computation == module->entry_computation()) {
+    policy.copy_parameters_and_constants = true;
+    policy.copy_root_replicated_buffers = true;
+  }
+  for (const CallSite& site : node.caller_callsites()) {
+    // The kWhile instruction does not have an handling here, as the
+    // AddCopiesForWhile() API takes care of adding its own copies.
+    if (site.instruction()->opcode() == HloOpcode::kConditional) {
+      policy.copy_parameters_and_constants = true;
+      policy.copy_root_replicated_buffers = true;
+    }
+  }
+  return policy;
+}
+
+bool ShouldCopyRootValue(const HloValue& value,
+                         const SpecialCaseCopyPolicy& policy) {
+  if (policy.copy_parameters_and_constants) {
+    return IsConstantValue(value) ||
+           value.defining_instruction()->opcode() == HloOpcode::kParameter;
+  }
+  return false;
+}
+
 // Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given in
 // 'indices_to_copy'. Add control edges from the respective kCopy instructions
 // in deep copy of 'from' to the respective kCopy instruction in the deep copy
@@ -957,7 +996,8 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) {
     }
     TF_RET_CHECK(node.context() == CallContext::kSequential);
 
-    const bool is_entry = computation == module->entry_computation();
+    SpecialCaseCopyPolicy policy =
+        GetSpecialCaseCopyPolicy(node, module, computation);
     HloInstruction* root = computation->root_instruction();
 
     // Mark nondistinct/ambiguous indices.
@@ -970,27 +1010,26 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) {
           for (const HloBuffer* buffer : buffers_at_index) {
             buffer_seen_before |= !seen.insert(buffer).second;
           }
-          if (buffers_at_index.size() > 1 || (buffer_seen_before && is_entry)) {
-            VLOG(2) << "Index " << index << " of root of computation "
+          if (buffers_at_index.size() > 1 ||
+              (buffer_seen_before && policy.copy_root_replicated_buffers)) {
+            VLOG(2) << "Index " << index << " of computation "
                     << computation->name() << " (" << root->name()
                     << ") has ambiguous or non-distinct buffer. Copying.";
             add_index_to_copy(root, index);
           }
         });
 
-    // For entry instructions, mark any parameter or constant values.
-    if (is_entry) {
-      for (const auto& pair :
-           alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) {
-        const ShapeIndex& index = pair.first;
-        const HloValueSet& value_set = pair.second;
-        for (const HloValue* value : value_set.values()) {
-          if (ValueIsReadOnly(*value)) {
-            VLOG(2) << "Root of entry computation (" << root->name()
-                    << ") has constant or entry parameter value at index "
-                    << index << ". Copying.";
-            add_index_to_copy(root, index);
-          }
+    for (const auto& pair :
+         alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) {
+      const ShapeIndex& index = pair.first;
+      const HloValueSet& value_set = pair.second;
+      for (const HloValue* value : value_set.values()) {
+        if (ShouldCopyRootValue(*value, policy)) {
+          VLOG(2) << "Root of (" << root->name() << ") of computation("
+                  << computation->name()
+                  << ") has constant or parameter value at index " << index
+                  << ". Copying.";
+          add_index_to_copy(root, index);
         }
       }
     }
@@ -1012,7 +1051,6 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) {
       instruction->parent()->set_root_instruction(deep_copy);
     }
   }
-
   return Status::OK();
 }