[TF:XLA] Making constant folding deterministic.
author:    Yunxing Dai <yunxing@google.com>
           Mon, 5 Feb 2018 20:15:10 +0000 (12:15 -0800)
committer: TensorFlower Gardener <gardener@tensorflow.org>
           Mon, 5 Feb 2018 20:19:04 +0000 (12:19 -0800)
Making constant folding deterministic by performing the DFS in a deterministic order and inserting a serialization point based on node names.

This removes the last remaining source of non-determinism in the TF:XLA stack.

RELNOTES: Constant folding pass is now deterministic.
PiperOrigin-RevId: 184566644
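
For illustration only (not part of this change): callers that need reproducible
folded-constant names can supply their own generator through the new
ConstantFoldingOptions::generate_new_name field, as the test added below does.
A minimal sketch, assuming a hypothetical helper FoldDeterministically, an
existing tensorflow::Graph*, and the headers this change touches:

    #include "tensorflow/core/common_runtime/constant_folding.h"

    Status FoldDeterministically(Graph* graph) {
      // Caller-owned counter: repeated runs over the same graph produce the
      // same "__cf__<N>" suffixes instead of values drawn from a process-wide
      // atomic counter.
      int64 unique_id = 0;
      ConstantFoldingOptions opts;
      opts.generate_new_name = [&unique_id](Graph* g, string old_name) {
        return strings::StrCat(g->NewName(old_name), "__cf__", unique_id++);
      };
      bool was_mutated = false;
      return ConstantFold(opts, /*function_library=*/nullptr, Env::Default(),
                          /*partition_device=*/nullptr, graph, &was_mutated);
    }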

tensorflow/core/common_runtime/constant_folding.cc
tensorflow/core/common_runtime/constant_folding.h
tensorflow/core/common_runtime/constant_folding_test.cc
tensorflow/core/common_runtime/function_test.cc

tensorflow/core/common_runtime/constant_folding.cc
index 0398c2a..b5a51d2 100644
@@ -328,7 +328,8 @@ void FindConstantFoldableNodes(
                ConsiderConstantFoldableNode(
                    n, opts, nodes, constant_control_deps, shape_replacement_map,
                    &internal_node_inserted);
-             });
+             },
+             NodeComparatorName());
   // If we have inserted just leaf level nodes, then there is nothing to fold.
   if (!internal_node_inserted) {
     nodes->clear();
@@ -339,8 +340,8 @@ void FindConstantFoldableNodes(
 typedef std::pair<Node*, int> NodeAndOutput;
 
 int64 UniqueConstantId() {
-  static std::atomic_int_fast64_t id;
-  return id.fetch_add(1);
+  static std::atomic_int_fast64_t unique_constant_id;
+  return unique_constant_id.fetch_add(1);
 }
 
 // Adds n to constant_graph which is being built up for subsequent evaluation of
@@ -386,14 +387,12 @@ void AddShapeNodeToConstantGraph(
     const std::unordered_map<const Node*, std::vector<Tensor>>&
         shape_replacement_map,
     std::unordered_map<Node*, std::vector<Node*>>* node_map,
-    Graph* constant_graph) {
+    const ConstantFoldNameGenerator& generate_new_name, Graph* constant_graph) {
   std::vector<Node*>& added = (*node_map)[n];
   const string& node_name = n->name();
   for (const Tensor& t : shape_replacement_map.at(n)) {
     auto builder =
-        NodeDefBuilder(strings::StrCat(constant_graph->NewName(node_name),
-                                       "__cf__", UniqueConstantId()),
-                       "Const")
+        NodeDefBuilder(generate_new_name(constant_graph, node_name), "Const")
             .Attr("dtype", t.dtype())
             .Attr("value", t);
     NodeDef def;
@@ -414,7 +413,8 @@ Graph* GetConstantGraph(
     const Graph* orig_graph, const std::vector<Node*>& nodes,
     const std::unordered_map<const Node*, std::vector<Tensor>>&
         shape_replacement_map,
-    std::map<NodeAndOutput, Node*>* tensors_to_fetch) {
+    std::map<NodeAndOutput, Node*>* tensors_to_fetch,
+    const ConstantFoldNameGenerator& generate_new_name) {
   Graph* constant_graph = new Graph(orig_graph->op_registry());
   std::unordered_map<Node*, std::vector<Node*>> node_map;
   node_map[orig_graph->source_node()] = {constant_graph->source_node()};
@@ -424,7 +424,7 @@ Graph* GetConstantGraph(
       AddNodeToConstantGraph(n, &node_map, constant_graph);
     } else {
       AddShapeNodeToConstantGraph(n, shape_replacement_map, &node_map,
-                                  constant_graph);
+                                  generate_new_name, constant_graph);
     }
   }
 
@@ -458,10 +458,11 @@ Graph* GetConstantGraph(
 // replacement was successful, false otherwise.
 // 'control_deps' is the set of nodes that should be control predecessors of the
 // new constant node.
-bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device,
-                               NodeAndOutput tensor, const Tensor& constant,
-                               const gtl::FlatSet<Node*>& control_deps,
-                               int64 max_constant_size_in_bytes) {
+bool ReplaceTensorWithConstant(
+    Graph* graph, Device* partition_device, NodeAndOutput tensor,
+    const Tensor& constant, const gtl::FlatSet<Node*>& control_deps,
+    int64 max_constant_size_in_bytes,
+    const ConstantFoldNameGenerator& generate_new_name) {
   // Be conservative when replacing a tensor with a constant, when not
   // running on CPU.
   // 1) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
@@ -509,9 +510,7 @@ bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device,
   }
   const string& node_name = n->name();
   Node* constant_node;
-  auto builder = NodeDefBuilder(strings::StrCat(graph->NewName(node_name),
-                                                "__cf__", UniqueConstantId()),
-                                "Const")
+  auto builder = NodeDefBuilder(generate_new_name(graph, node_name), "Const")
                      .Attr("dtype", constant.dtype())
                      .Attr("value", constant);
   if (partition_device) {
@@ -555,6 +554,13 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
                     FunctionLibraryRuntime* function_library, Env* env,
                     Device* partition_device, Graph* graph, bool* was_mutated) {
   DumpGraph("Before", graph);
+  ConstantFoldNameGenerator generate_new_name = opts.generate_new_name;
+  if (generate_new_name == nullptr) {
+    generate_new_name = [](Graph* graph, string old_name) {
+      return strings::StrCat(graph->NewName(old_name), "__cf__",
+                             UniqueConstantId());
+    };
+  }
 
   std::vector<Node*> constant_foldable_nodes;
   std::unordered_map<const Node*, gtl::FlatSet<Node*>> constant_control_deps;
@@ -571,7 +577,7 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
   std::map<NodeAndOutput, Node*> tensors_to_fetch;
   std::unique_ptr<Graph> constant_graph(
       GetConstantGraph(graph, constant_foldable_nodes, shape_replacement_map,
-                       &tensors_to_fetch));
+                       &tensors_to_fetch, generate_new_name));
   DumpGraph("Constant graph", constant_graph.get());
 
   if (tensors_to_fetch.empty()) {
@@ -585,7 +591,16 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
 
   std::vector<string> tensors_to_fetch_names;
   std::vector<NodeAndOutput> tensors_to_replace;
-  for (auto n : tensors_to_fetch) {
+  // Sorting the nodes based on the name gives us a stable ordering between runs
+  // for the same graph.
+  std::vector<std::pair<NodeAndOutput, Node*>> tensors_to_fetch_sorted(
+      tensors_to_fetch.begin(), tensors_to_fetch.end());
+  std::sort(tensors_to_fetch_sorted.begin(), tensors_to_fetch_sorted.end(),
+            [](const std::pair<NodeAndOutput, Node*>& n1,
+               const std::pair<NodeAndOutput, Node*>& n2) {
+              return n1.first.first->name() < n2.first.first->name();
+            });
+  for (auto n : tensors_to_fetch_sorted) {
     tensors_to_fetch_names.push_back(
         strings::StrCat(n.first.first->name(), ":", n.first.second));
     tensors_to_replace.push_back({n.second, n.first.second});
@@ -617,7 +632,7 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
         constant_control_deps[tensors_to_replace[c].first];
     if (ReplaceTensorWithConstant(
             graph, partition_device, tensors_to_replace[c], outputs[c],
-            control_deps, opts.max_constant_size_in_bytes)) {
+            control_deps, opts.max_constant_size_in_bytes, generate_new_name)) {
       ++num_nodes_replaced;
     }
   }
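
(For reference only, not part of the patch.) The NodeComparatorName() argument
added in the first hunk is what makes the DFS traversal in
FindConstantFoldableNodes stable: nodes are visited in name order rather than
in pointer or insertion order. A sketch of a comparator of that shape, assuming
it simply orders nodes lexicographically by name, consistent with the sort
lambda added above:

    struct NodeComparatorNameSketch {
      bool operator()(const Node* n1, const Node* n2) const {
        // Node names are unique within a Graph, so ordering by name yields a
        // total order that is identical from run to run.
        return n1->name() < n2->name();
      }
    };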
tensorflow/core/common_runtime/constant_folding.h
index e4d724c..b1e1fb8 100644
@@ -24,6 +24,11 @@ limitations under the License.
 
 namespace tensorflow {
 
+// This generator type is used to generate a name for the newly folded node
+// based on the node's old name.
+using ConstantFoldNameGenerator =
+    std::function<string(Graph* graph, string old_name)>;
+
 // Options specific to constant folding optimizations.
 struct ConstantFoldingOptions {
   // If "consider" is not a nullptr, then only constant fold a node "n" if
@@ -37,6 +42,11 @@ struct ConstantFoldingOptions {
   // The maximum size of each constant created during constant folding
   // optimization.
   int64 max_constant_size_in_bytes = 10 * 1024 * 1024;
+
+  // A generator for the name suffix of constant folded nodes. A
+  // default id generator that monotonically increases is used if nullptr is
+  // passed.
+  ConstantFoldNameGenerator generate_new_name = nullptr;
 };
 
 // Perform constant folding optimization on "graph".
tensorflow/core/common_runtime/constant_folding_test.cc
index 923a4d9..6ac9319 100644
@@ -121,6 +121,58 @@ TEST_F(ConstantFoldingTest, Basic) {
                          {2, 2});
 }
 
+// Tests that different node creation ordering creates same graph after constant
+// folding.
+TEST_F(ConstantFoldingTest, DeterministicFolding) {
+  auto build_graph_and_constant_folding = [](Graph& g, bool swap) -> Status {
+    Scope s = Scope::NewRootScope();
+    auto a = ops::Const<float>(s, {1.0}, {});
+    auto b = ops::Const<float>(s, {2.0}, {});
+
+    if (swap) {
+      auto add1 = ops::Add(s.WithOpName("add1"), a, b);
+      auto add2 = ops::Add(s.WithOpName("add2"), a, b);
+      auto s1 =
+          ops::_Send(s.WithOpName("s1"), add1, "add1", "sender", 0, "receiver");
+      auto s2 =
+          ops::_Send(s.WithOpName("s2"), add2, "add2", "sender", 0, "receiver");
+    } else {
+      // Swap the order of node creation.
+      auto add2 = ops::Add(s.WithOpName("add2"), a, b);
+      auto add1 = ops::Add(s.WithOpName("add1"), a, b);
+      auto s1 =
+          ops::_Send(s.WithOpName("s1"), add1, "add1", "sender", 0, "receiver");
+      auto s2 =
+          ops::_Send(s.WithOpName("s2"), add2, "add2", "sender", 0, "receiver");
+    }
+
+    TF_CHECK_OK(s.ToGraph(&g));
+    bool was_mutated;
+    int64 unique_id = 0;
+    auto generate_new_name = [&unique_id](Graph* graph, string old_name) {
+      return strings::StrCat(graph->NewName(old_name), "__cf__", unique_id++);
+    };
+    ConstantFoldingOptions opt{};
+    opt.generate_new_name = generate_new_name;
+    TF_CHECK_OK(
+        ConstantFold(opt, nullptr, Env::Default(), nullptr, &g, &was_mutated));
+    return Status::OK();
+  };
+
+  Graph g1(OpRegistry::Global());
+  TF_ASSERT_OK(build_graph_and_constant_folding(g1, false));
+  Graph g2(OpRegistry::Global());
+  TF_ASSERT_OK(build_graph_and_constant_folding(g2, true));
+  EXPECT_EQ(g1.num_nodes(), g2.num_nodes());
+  auto index = NodeNameIndex(g2);
+
+  // All the nodes in g1 are expected to be present in g2.
+  for (int64 i = 0; i < g1.num_nodes(); ++i) {
+    Node* n1 = g1.FindNodeId(i);
+    EXPECT_GT(index.count(n1->name()), 0);
+  }
+}
+
 TEST_F(ConstantFoldingTest, ConsiderFunction) {
   Scope s = Scope::NewRootScope();
   BuildSimpleGraph(&s);
tensorflow/core/common_runtime/function_test.cc
index cad3b38..8b05146 100644
@@ -787,7 +787,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
     Scope s = Scope::NewRootScope();
     auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
     auto x4_x2_scale = ops::Const<float>(
-        s.WithOpName("x4/x2/scale/_15__cf__9")
+        s.WithOpName("x4/x2/scale/_12__cf__6")
             .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
         2.0f);
     auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
@@ -993,13 +993,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
     auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
     auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
     auto scale = ops::Const(
-        s.WithOpName("scale/_5__cf__10")
+        s.WithOpName("scale/_6__cf__11")
             .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
         2.0f);
     auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
     auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
     auto const0 = ops::Const(
-        s.WithOpName("Func/_1/sy/_6__cf__11")
+        s.WithOpName("Func/_1/sy/_5__cf__10")
             .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
         0, {0});
     auto func1_rx = ops::internal::BroadcastGradientArgs(