From 8fe70dc41f269b266b2adce7bed96b3833dda60e Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Mon, 5 Feb 2018 15:49:57 -0800 Subject: [PATCH] "frame_name" attr must be altered when importing/exporting MetaGraphDefs. The frame_name attr of Enter operations must be the name of the associated WhileLoopContext, otherwise taking the gradient of the loop will result in frame errors. I'm not happy that the export logic is in meta_graph.py (all other control flow de/serialization is in control_flow_ops.py). However, I can't think of how else to do it, since only export_scoped_meta_graph has access to the NodeDefs being exported. PiperOrigin-RevId: 184599323 --- tensorflow/python/framework/meta_graph.py | 6 ++- tensorflow/python/framework/meta_graph_test.py | 52 ++++++++++++++++++++++++++ tensorflow/python/ops/control_flow_ops.py | 13 +++++++ 3 files changed, 69 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/framework/meta_graph.py b/tensorflow/python/framework/meta_graph.py index fc1a823..8c03a5f 100644 --- a/tensorflow/python/framework/meta_graph.py +++ b/tensorflow/python/framework/meta_graph.py @@ -87,6 +87,10 @@ def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False): compat.as_str(s).split("@")[1].startswith(export_scope)] node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue( list=attr_value_pb2.AttrValue.ListValue(s=new_s))) + elif node_def.op in ("Enter", "RefEnter") and k == "frame_name": + if not export_scope or compat.as_str(v.s).startswith(export_scope): + new_s = compat.as_bytes(ops.strip_name_scope(v.s, export_scope)) + node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=new_s)) else: node_def.attr[k].CopyFrom(v) @@ -959,5 +963,3 @@ def copy_scoped_meta_graph(from_scope, to_scope, graph=to_graph, import_scope=to_scope) return var_list - - diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py index b5ed135..f2f1e83 100644 --- a/tensorflow/python/framework/meta_graph_test.py +++ b/tensorflow/python/framework/meta_graph_test.py @@ -25,6 +25,7 @@ import shutil from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -34,6 +35,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics from tensorflow.python.ops import nn_ops @@ -447,6 +449,56 @@ class ScopedMetaGraphTest(test.TestCase): del b.collection_def["unbound_inputs"] test_util.assert_meta_graph_protos_equal(self, a, b) + def testWhileLoopGradients(self): + # Create a simple while loop. + with ops.Graph().as_default(): + with ops.name_scope("export"): + var = variables.Variable(0) + var_name = var.name + _, output = control_flow_ops.while_loop(lambda i, x: i < 5, + lambda i, x: (i + 1, x + i), + [0, var]) + output_name = output.name + + # Generate a MetaGraphDef containing the while loop with an export scope. + meta_graph_def, _ = meta_graph.export_scoped_meta_graph( + export_scope="export") + + # Build and run the gradients of the while loop. We use this below to + # verify that the gradients are correct with the imported MetaGraphDef. + init_op = variables.global_variables_initializer() + grad = gradients_impl.gradients([output], [var]) + with session.Session() as sess: + sess.run(init_op) + expected_grad_value = sess.run(grad) + + # Restore the MetaGraphDef into a new Graph with an import scope. + with ops.Graph().as_default(): + meta_graph.import_scoped_meta_graph(meta_graph_def, import_scope="import") + + # Re-export and make sure we get the same MetaGraphDef. + new_meta_graph_def, _ = meta_graph.export_scoped_meta_graph( + export_scope="import") + test_util.assert_meta_graph_protos_equal( + self, meta_graph_def, new_meta_graph_def) + + # Make sure we can still build gradients and get the same result. + + def new_name(tensor_name): + base_tensor_name = tensor_name.replace("export/", "") + return "import/" + base_tensor_name + + var = ops.get_default_graph().get_tensor_by_name(new_name(var_name)) + output = ops.get_default_graph().get_tensor_by_name(new_name(output_name)) + grad = gradients_impl.gradients([output], [var]) + + init_op = variables.global_variables_initializer() + + with session.Session() as sess: + sess.run(init_op) + actual_grad_value = sess.run(grad) + self.assertEqual(expected_grad_value, actual_grad_value) + def testScopedImportUnderNameScope(self): graph = ops.Graph() with graph.as_default(): diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index bcd187d..e75eb08 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -57,6 +57,7 @@ import functools import six from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.protobuf import control_flow_pb2 from tensorflow.python.eager import context from tensorflow.python.framework import constant_op @@ -79,6 +80,7 @@ from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops.gen_control_flow_ops import * # pylint: enable=wildcard-import from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import compat from tensorflow.python.util import deprecation from tensorflow.python.util import nest from tensorflow.python.util import tf_should_use @@ -2242,6 +2244,17 @@ class WhileContext(ControlFlowContext): super(WhileContext, self).__init__( values_def=context_def.values_def, import_scope=import_scope) + # import_scope causes self.name to be different from the original serialized + # context's name. Rewrite "frame_name" attrs with the new name. + if import_scope: + for tensor_name in self._values: + op = g.as_graph_element(tensor_name).op + if util.IsLoopEnter(op): + # pylint: disable=protected-access + op._set_attr("frame_name", + attr_value_pb2.AttrValue(s=compat.as_bytes(self.name))) + # pylint: enable=protected-access + @property def maximum_iterations(self): """The maximum number of iterations that will be executed.""" -- 2.7.4