From bcec296af809947145a6ebfa1e46b1cafe21ec06 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 9 May 2018 09:05:59 -0700 Subject: [PATCH] Adds _DefinedFunction.stateful_ops. PiperOrigin-RevId: 195979035 --- tensorflow/python/framework/function.py | 14 ++++++++++++++ tensorflow/python/framework/function_test.py | 4 ++++ 2 files changed, 18 insertions(+) diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index f82e94b..b7607ce 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -313,6 +313,16 @@ class _DefinedFunction(object): self._create_definition_if_needed() return self._extra_inputs + @property + def stateful_ops(self): + """Returns the list of stateful ops in function definition. + + Returns: + A list of (op.name, op.type) pairs. + """ + self._create_definition_if_needed() + return self._stateful_ops + def _create_definition_if_needed(self): """Creates the function definition if it's not created yet.""" with context.graph_mode(): @@ -424,6 +434,10 @@ class _DefinedFunction(object): else: self._func_name = compat.as_str(self._op_def.name) + self._stateful_ops = [(op.name, op.type) + for op in temp_graph.get_operations() + if op.op_def.is_stateful] + def _set_c_attrs(self, attrs): """Sets `attrs` as attributes of self._c_func. diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index a5c19f1..caec39f 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -182,6 +182,8 @@ class FunctionTest(test.TestCase): def APlus2B(a, b): return a + b * 2 + # APlus2B is stateless. + self.assertEqual([], APlus2B.stateful_ops) with ops.Graph().as_default(): call = APlus2B([1.0], [2.0]) self.assertEqual("APlus2B", call.op.name) @@ -428,6 +430,8 @@ class FunctionTest(test.TestCase): with ops.control_dependencies([check]): return x * 2 + # Foo contains a stateful op (Assert). + self.assertEqual([("Assert", "Assert")], Foo.stateful_ops) g = ops.Graph() with g.as_default(), self.test_session(): self.assertAllEqual(Foo(constant_op.constant(3.0)).eval(), 6.0) -- 2.7.4