Don't evaluate control flow in EvaluateConstantTensor.
authorSkye Wanderman-Milne <skyewm@google.com>
Tue, 13 Mar 2018 00:12:32 +0000 (17:12 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 13 Mar 2018 00:16:39 +0000 (17:16 -0700)
ExtractConstantSubgraph doesn't copy control edges, which are sometimes necessary
to correctly evaluate conds (at the very least). Avoid evaluating conds at all to
address this.

PiperOrigin-RevId: 188803649

tensorflow/core/common_runtime/eval_const_tensor.cc
tensorflow/python/framework/smart_cond_test.py

index 6370bb5..c1542f1 100644 (file)
@@ -128,12 +128,16 @@ Status ExtractConstantSubgraph(
     return Status::OK();
   }
 
+  if (IsMerge(&target_node)) {
+    return Status::OK();
+  }
+
   if (target_node.type_string() == "PlaceholderWithDefault") {
     return Status::OK();
   }
 
-  // TODO(skyewm): more of the filtering applied in input nodes below should be
-  // applied to target_node here
+  // TODO(skyewm): should more of the filtering applied in input nodes below be
+  // applied to target_node here?
 
   // Identify the possibly constant subgraph by recursively iterating backwards
   // through the inputs to 'target_node' until we either 1) find an already
@@ -153,11 +157,8 @@ Status ExtractConstantSubgraph(
   // Add the target node's inputs to seed the recursion.
   std::deque<const Edge*> edges_to_visit;
   for (const Edge* e : target_node.in_edges()) {
-    // TODO(vrv): What do we do about control edges?  Based on our
-    // definition of a constant graph, we should be free to ignore
-    // control edges since the order in which a constant graph is
-    // executed should be the same regardless of when nodes run: we
-    // should only need to recurse down data edges.
+    // TODO(skyewm): control edges will be meaningful if/when we handle control
+    // flow (e.g. constants in cond branches are triggered via control edges).
     if (e->IsControlEdge()) continue;
     edges_to_visit.push_back(e);
   }
@@ -177,7 +178,9 @@ Status ExtractConstantSubgraph(
     }
 
     // During construction or import from GraphConstructor, back edges may not
-    // be filled in.  Don't constant fold through merges at all for now.
+    // be filled in. In addition, control flow constructs may depend on control
+    // edges which aren't handled by this method. Don't constant fold through
+    // merges at all for now.
     if (IsMerge(current_node)) {
       *is_constant_graph = false;
       return Status::OK();
index 582ce81..1170a41 100644 (file)
@@ -24,6 +24,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import smart_cond
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import googletest
 
@@ -144,5 +145,22 @@ class SmartCaseTest(test_util.TensorFlowTestCase):
       self.assertEqual(sess.run(z, feed_dict={x: 0}), 3)
 
 
+@test_util.with_c_api
+class SmartConstantValueTest(test_util.TensorFlowTestCase):
+
+  # TODO(skyewm): this is essentially a regression test for
+  # TF_TryEvaluateConstant, and is not really a valid smart_constant_value test
+  # (smart_constant_value is only supposed to return bools). Move the
+  # TF_TryEvaluateConstant call to tensor_util.constant_value and make this a
+  # constant_value test instead.
+  def testCond(self):
+    with ops.Graph().as_default():
+      pred = array_ops.placeholder_with_default(True, shape=())
+      x = control_flow_ops.cond(pred,
+                                lambda: constant_op.constant(1),
+                                lambda: constant_op.constant(2))
+      self.assertIsNone(smart_cond.smart_constant_value(x))
+
+
 if __name__ == "__main__":
   googletest.main()