Temporarily disable constant folding past Enter, since a few breakages have been...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 21 Mar 2018 19:53:53 +0000 (12:53 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 21 Mar 2018 19:58:02 +0000 (12:58 -0700)
PiperOrigin-RevId: 189952252

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h
tensorflow/core/grappler/optimizers/constant_folding_test.cc

index 2639835..bdec73e 100644 (file)
@@ -1707,7 +1707,9 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
     }
 
     // Move constants past Enter.
-    if (IsEnter(*node) && node->input_size() > 0) {
+    // TODO(rmlarsen): Reenable when we fix the root cause of b/76008022
+    if (opt_level_ == RewriterConfig::AGGRESSIVE && IsEnter(*node) &&
+        node->input_size() > 0) {
       const string& node_name = node->name();
       const NodeDef* input = node_map_->GetNode(node->input(0));
       if (input != nullptr && IsReallyConstant(*input) &&
@@ -1729,6 +1731,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
           NodeDef* new_node = optimized_graph->add_node();
           *new_node = *input;
           new_node->set_name(OptimizedNodeName(*input, "_enter"));
+          new_node->set_device(node->device());
           new_node->clear_input();
           new_node->add_input(AsControlDependency(node_name));
           node_map_->AddNode(new_node->name(), new_node);
index 13ecfcd..b6645d3 100644 (file)
@@ -38,7 +38,7 @@ class ConstantFolding : public GraphOptimizer {
   static string AddControlDependency(const string& input_name, GraphDef* graph,
                                      NodeMap* node_map);
 
-  ConstantFolding(DeviceBase* cpu_device);
+  explicit ConstantFolding(DeviceBase* cpu_device);
   ConstantFolding(RewriterConfig::Toggle opt_level, DeviceBase* cpu_device);
 
   ~ConstantFolding() override {}
index aeb430b..914a925 100644 (file)
@@ -2103,7 +2103,8 @@ TEST_F(ConstantFoldingTest, Enter) {
   item.fetch.push_back("id2");
   item.fetch.push_back("id3");
 
-  ConstantFolding optimizer(nullptr /* cpu_device */);
+  ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+                            nullptr /* cpu_device */);
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);