return Status::OK();
}
+ if (IsMerge(&target_node)) {
+ return Status::OK();
+ }
+
if (target_node.type_string() == "PlaceholderWithDefault") {
return Status::OK();
}
- // TODO(skyewm): more of the filtering applied in input nodes below should be
- // applied to target_node here
+ // TODO(skyewm): should more of the filtering applied in input nodes below be
+ // applied to target_node here?
// Identify the possibly constant subgraph by recursively iterating backwards
// through the inputs to 'target_node' until we either 1) find an already
// Add the target node's inputs to seed the recursion.
std::deque<const Edge*> edges_to_visit;
for (const Edge* e : target_node.in_edges()) {
- // TODO(vrv): What do we do about control edges? Based on our
- // definition of a constant graph, we should be free to ignore
- // control edges since the order in which a constant graph is
- // executed should be the same regardless of when nodes run: we
- // should only need to recurse down data edges.
+ // TODO(skyewm): control edges will be meaningful if/when we handle control
+ // flow (e.g. constants in cond branches are triggered via control edges).
if (e->IsControlEdge()) continue;
edges_to_visit.push_back(e);
}
}
// During construction or import from GraphConstructor, back edges may not
- // be filled in. Don't constant fold through merges at all for now.
+ // be filled in. In addition, control flow constructs may depend on control
+ // edges which aren't handled by this method. Don't constant fold through
+ // merges at all for now.
if (IsMerge(current_node)) {
*is_constant_graph = false;
return Status::OK();
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
self.assertEqual(sess.run(z, feed_dict={x: 0}), 3)
+@test_util.with_c_api
+class SmartConstantValueTest(test_util.TensorFlowTestCase):
+
+ # TODO(skyewm): this is essentially a regression test for
+ # TF_TryEvaluateConstant, and is not really a valid smart_constant_value test
+ # (smart_constant_value is only supposed to return bools). Move the
+ # TF_TryEvaluateConstant call to tensor_util.constant_value and make this a
+ # constant_value test instead.
+ def testCond(self):
+ with ops.Graph().as_default():
+ pred = array_ops.placeholder_with_default(True, shape=())
+ x = control_flow_ops.cond(pred,
+ lambda: constant_op.constant(1),
+ lambda: constant_op.constant(2))
+ self.assertIsNone(smart_cond.smart_constant_value(x))
+
+
if __name__ == "__main__":
googletest.main()