Adds _DefinedFunction.stateful_ops.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 9 May 2018 16:05:59 +0000 (09:05 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 9 May 2018 17:50:14 +0000 (10:50 -0700)
PiperOrigin-RevId: 195979035

tensorflow/python/framework/function.py
tensorflow/python/framework/function_test.py

index f82e94b..b7607ce 100644 (file)
@@ -313,6 +313,16 @@ class _DefinedFunction(object):
     self._create_definition_if_needed()
     return self._extra_inputs
 
+  @property
+  def stateful_ops(self):
+    """Returns the list of stateful ops in function definition.
+
+    Returns:
+      A list of (op.name, op.type) pairs.
+    """
+    self._create_definition_if_needed()
+    return self._stateful_ops
+
   def _create_definition_if_needed(self):
     """Creates the function definition if it's not created yet."""
     with context.graph_mode():
@@ -424,6 +434,10 @@ class _DefinedFunction(object):
       else:
         self._func_name = compat.as_str(self._op_def.name)
 
+    self._stateful_ops = [(op.name, op.type)
+                          for op in temp_graph.get_operations()
+                          if op.op_def.is_stateful]
+
   def _set_c_attrs(self, attrs):
     """Sets `attrs` as attributes of self._c_func.
 
index a5c19f1..caec39f 100644 (file)
@@ -182,6 +182,8 @@ class FunctionTest(test.TestCase):
     def APlus2B(a, b):
       return a + b * 2
 
+    # APlus2B is stateless.
+    self.assertEqual([], APlus2B.stateful_ops)
     with ops.Graph().as_default():
       call = APlus2B([1.0], [2.0])
       self.assertEqual("APlus2B", call.op.name)
@@ -428,6 +430,8 @@ class FunctionTest(test.TestCase):
       with ops.control_dependencies([check]):
         return x * 2
 
+    # Foo contains a stateful op (Assert).
+    self.assertEqual([("Assert", "Assert")], Foo.stateful_ops)
     g = ops.Graph()
     with g.as_default(), self.test_session():
       self.assertAllEqual(Foo(constant_op.constant(3.0)).eval(), 6.0)