Make the CSE ("node deduping") pass in ArithmeticOptimizer more robust in the presenc...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 26 Mar 2018 18:44:19 +0000 (11:44 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 26 Mar 2018 18:46:47 +0000 (11:46 -0700)
PiperOrigin-RevId: 190499037

tensorflow/core/grappler/op_types.cc
tensorflow/core/grappler/op_types.h
tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc

index 259168b..1a6751b 100644 (file)
@@ -396,12 +396,17 @@ bool IsFreeOfSideEffect(const NodeDef& node) {
       return false;
     }
   }
+  return !ModifiesInputsInPlace(node);
+}
+
+bool ModifiesInputsInPlace(const NodeDef& node) {
   // Some nodes do in-place updates on regular tensor inputs.
-  if (GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace") ||
-      StringPiece(op_name).starts_with("Inplace")) {
-    return false;
+  string op_name = node.op();
+  std::transform(op_name.begin(), op_name.end(), op_name.begin(), ::tolower);
+  if (StringPiece(op_name).contains("inplace")) {
+    return true;
   }
-  return true;
+  return GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace");
 }
 
 bool ModifiesFrameInfo(const NodeDef& node) {
index 49e01f6..1ec1cd4 100644 (file)
@@ -154,8 +154,12 @@ bool IsCommutative(const NodeDef& node);
 bool IsPersistent(const NodeDef& node);
 
 bool IsFreeOfSideEffect(const NodeDef& node);
+
 bool ModifiesFrameInfo(const NodeDef& node);
 
+// Returns true if the op is known to write to one or more of its inputs.
+bool ModifiesInputsInPlace(const NodeDef& node);
+
 // Returns true if the op is an element-wise involution, i.e. if it is its
 // own inverse such that f(f(x)) == x.
 bool IsInvolution(const NodeDef& node);
index bc004df..23e2185 100644 (file)
@@ -1085,6 +1085,24 @@ bool ArithmeticOptimizer::OptimizedNodeExists(const NodeDef& node,
   return node_map_->NodeExists(OptimizedNodeName(node, suffix));
 }
 
+namespace {
+
+bool FeedsInPlaceOp(const SimpleGraphView& graph_view, const NodeDef& node) {
+  const std::unordered_set<string> op_types_to_traverse = {
+      node.op(), "Identity", "IdentityN", "Reshape"};
+  int node_idx = graph_view.index(node.name());
+  std::set<int> node_fanout;
+  graph_view.DepthFirstSearch(op_types_to_traverse, node_idx, &node_fanout);
+  for (int fanout : node_fanout) {
+    if (ModifiesInputsInPlace(graph_view.graph()->node(fanout))) {
+      return true;
+    }
+  }
+  return false;
+}
+
+}  // namespace
+
 bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const {
   if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
     return false;
@@ -1104,6 +1122,11 @@ bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const {
 
 void ArithmeticOptimizer::DedupComputations() {
   bool stop = true;
+  SimpleGraphView graph_view;
+  if (!graph_view.Initialize(*optimized_graph_).ok()) {
+    LOG(WARNING) << "Failed to build SimpleGraphView.";
+    return;
+  }
   std::set<int> duplicates;
   do {
     stop = true;
@@ -1120,19 +1143,28 @@ void ArithmeticOptimizer::DedupComputations() {
       if (rep == node) {
         continue;
       }
+      // If either node feeds an inplace op, deduping them may cause data races.
+      // For example: If we dedup nodes initializing two independent inplace
+      // accumulations, they will write to the same buffer, clobbering each
+      // other's results.
+      if (FeedsInPlaceOp(graph_view, *rep) ||
+          FeedsInPlaceOp(graph_view, *node)) {
+        continue;
+      }
       const std::set<NodeDef*>& fanouts = node_map_->GetOutputs(node->name());
       for (NodeDef* fanout : fanouts) {
-        for (string& name : *fanout->mutable_input()) {
+        for (int i = 0; i < fanout->input_size(); ++i) {
+          string* name = fanout->mutable_input(i);
           int position;
-          const string nodename = ParseNodeName(name, &position);
+          const string nodename = ParseNodeName(*name, &position);
           if (nodename == node->name()) {
             // Update name in-place.
             if (position > 0) {
-              name = StrCat(rep->name(), ":", position);
+              *name = StrCat(rep->name(), ":", position);
             } else if (position == 0) {
-              name = rep->name();
+              *name = rep->name();
             } else {
-              name = StrCat("^", rep->name());
+              *name = StrCat("^", rep->name());
             }
             node_map_->AddOutput(rep->name(), fanout->name());
           }