self._create_definition_if_needed()
return self._extra_inputs
+ @property
+ def stateful_ops(self):
+ """Returns the list of stateful ops in function definition.
+
+ Returns:
+ A list of (op.name, op.type) pairs.
+ """
+ self._create_definition_if_needed()
+ return self._stateful_ops
+
def _create_definition_if_needed(self):
"""Creates the function definition if it's not created yet."""
with context.graph_mode():
else:
self._func_name = compat.as_str(self._op_def.name)
+ self._stateful_ops = [(op.name, op.type)
+ for op in temp_graph.get_operations()
+ if op.op_def.is_stateful]
+
def _set_c_attrs(self, attrs):
"""Sets `attrs` as attributes of self._c_func.
def APlus2B(a, b):
return a + b * 2
+ # APlus2B is stateless.
+ self.assertEqual([], APlus2B.stateful_ops)
with ops.Graph().as_default():
call = APlus2B([1.0], [2.0])
self.assertEqual("APlus2B", call.op.name)
with ops.control_dependencies([check]):
return x * 2
+ # Foo contains a stateful op (Assert).
+ self.assertEqual([("Assert", "Assert")], Foo.stateful_ops)
g = ops.Graph()
with g.as_default(), self.test_session():
self.assertAllEqual(Foo(constant_op.constant(3.0)).eval(), 6.0)