From 68430112b2ca5c160db6dd412d43f572ec69e72f Mon Sep 17 00:00:00 2001
From: Alexandre Passos
Date: Fri, 25 May 2018 13:20:13 -0700
Subject: [PATCH] Public API to switch between eager execution and graph building.

Now, after tf.enable_eager_execution() has been executed, entering the
context manager of a tf.Graph will enable graph mode. So, for example

```
tf.enable_eager_execution()
with tf.Graph().as_default():
  c = tf.constant(1.0)  # this is a graph tensor
c2 = tf.constant(1.0)  # this is an eager tensor
```

The main use-case of this is allowing documentation writers to make a
single notebook which starts with eager execution and seamlessly
transitions to building graphs.

This also makes many explicit enablings of graph mode in the code
redundant (a cleanup cl will follow).

PiperOrigin-RevId: 198092991
---
 .../contrib/distribute/python/mirrored_strategy.py | 13 ++++++-
 .../contrib/distribute/python/monitor_test.py      |  3 +-
 tensorflow/contrib/eager/python/saver_test.py      | 45 +++++++++-------------
 .../contrib/opt/python/training/adamax_test.py     |  8 ++--
 tensorflow/contrib/optimizer_v2/momentum_test.py   | 11 +-----
 tensorflow/python/framework/ops.py                 | 32 ++++-----------
 tensorflow/python/framework/ops_test.py            | 26 +++++++++++--
 tensorflow/python/framework/test_util.py           | 20 ++++++----
 tensorflow/python/framework/test_util_test.py      |  5 ++-
 .../python/kernel_tests/accumulate_n_eager_test.py |  7 ++--
 tensorflow/python/kernel_tests/py_func_test.py     |  5 +--
 tensorflow/python/ops/variables.py                 |  2 +-
 tensorflow/python/training/adam_test.py            |  7 ++--
 tensorflow/python/training/momentum_test.py        | 11 +----
 tensorflow/python/training/training_util.py        | 20 +++++-----
 15 files changed, 105 insertions(+), 110 deletions(-)

diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 89f2c43..14dbbd6 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import contextlib
 import threading
 
 import six
@@ -39,6 +40,16 @@ from tensorflow.python.training import distribute as distribute_lib
 # TODO(josh11b): Replace asserts in this file with if ...: raise ...
 
 
+@contextlib.contextmanager
+def _enter_graph(g):
+  if context.executing_eagerly():
+    with g.as_default(), context.eager_mode():
+      yield
+  else:
+    with g.as_default():
+      yield
+
+
 def _cpu_device(device):
   cpu_device = tf_device.DeviceSpec.from_string(device)
   cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0))
@@ -458,7 +469,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
       with self.coord.stop_on_exception(), \
           context.context()._mode(self.context_mode), \
          context.context().device_policy(self.context_device_policy), \
-          self.graph.as_default(), \
+          _enter_graph(self.graph), \
          MirroredTowerContext(self.distribution, self.tower_id), \
          ops.device(self.device), \
          ops.name_scope(self._captured_name_scope), \
diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py
index 8277e1e..4fdb9bf 100644
--- a/tensorflow/contrib/distribute/python/monitor_test.py
+++ b/tensorflow/contrib/distribute/python/monitor_test.py
@@ -25,6 +25,7 @@ from tensorflow.contrib.distribute.python import combinations
 from tensorflow.contrib.distribute.python import monitor as monitor_lib
 from tensorflow.contrib.distribute.python import one_device_strategy
 from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example
+from tensorflow.python.client import session
 from tensorflow.python.eager import context
 from tensorflow.python.eager import test
 from tensorflow.python.framework import ops
@@ -65,7 +66,7 @@ class MonitorTest(test.TestCase, parameterized.TestCase):
     step_function, _ = single_loss_example(
         lambda: gradient_descent.GradientDescentOptimizer(0.2), distribution)
 
-    with self.test_session() as sess:
+    with session.Session() as sess, context.eager_mode():
       with self.assertRaisesRegexp(ValueError, "Should not provide"):
         _ = monitor_lib.Monitor(step_function, sess)
diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py
index 4032e75..90a3711 100644
--- a/tensorflow/contrib/eager/python/saver_test.py
+++ b/tensorflow/contrib/eager/python/saver_test.py
@@ -60,15 +60,9 @@ class SaverTest(test.TestCase):
   def testSameNameNoClobbering(self):
     with ops.device(self._dev()):
-      # Note that this test purposefully uses Graphs rather than
-      # IsolateTest. Users are more likely to accidentally create the same
-      # variable name this way.
-      first_graph = ops.Graph()
-      with first_graph.as_default():
-        v1_first_graph = resource_variable_ops.ResourceVariable(1.0, name='v1')
-        with ops.Graph().as_default():
-          v1_second_graph = resource_variable_ops.ResourceVariable(2.0, name='v1')
-          saver = _saver.Saver([v1_first_graph, v1_second_graph])
+      v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
+      v2 = resource_variable_ops.ResourceVariable(2.0, name='v1')
+      saver = _saver.Saver([v1, v2])
       ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
       with self.assertRaisesRegexp(ValueError, 'v1'):
         saver.save(ckpt_prefix)
@@ -126,12 +120,11 @@
       saver = _saver.Saver([v1])
       saver.save(ckpt_prefix)
 
-      with ops.Graph().as_default():
-        saver = _saver.Saver([v1])
-        with _saver.restore_variables_on_create(ckpt_prefix):
-          # Value is from checkpoint, but not from argument.
-          ret, _ = model(2.0)
-          self.assertEqual(ret.numpy(), 1.0)
+      saver = _saver.Saver([v1])
+      with _saver.restore_variables_on_create(ckpt_prefix):
+        # Value is from checkpoint, but not from argument.
+        ret, _ = model(2.0)
+        self.assertEqual(ret.numpy(), 1.0)
 
   def testRestoreNotFound(self):
     with ops.device(self._dev()):
@@ -184,17 +177,17 @@ class SaverTest(test.TestCase):
           4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
 
       # reset the graph and reload on create, so that 1 + 2 = 3
-      with ops.Graph().as_default():
-        with _saver.restore_variables_on_create(ckpt_prefix):
-          @graph_callable.graph_callable(
-              [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
-          def model2(x):
-            v = variable_scope.get_variable(
-                'v', initializer=init_ops.zeros_initializer(), shape=())
-            return v + x
-
-          self.assertEqual(
-              3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy())
+      ops.reset_default_graph()
+      with _saver.restore_variables_on_create(ckpt_prefix):
+        @graph_callable.graph_callable(
+            [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
+        def model2(x):
+          v = variable_scope.get_variable(
+              'v', initializer=init_ops.zeros_initializer(), shape=())
+          return v + x
+
+        self.assertEqual(
+            3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy())
 
 
 class GetOptimizerTests(test.TestCase):
diff --git a/tensorflow/contrib/opt/python/training/adamax_test.py b/tensorflow/contrib/opt/python/training/adamax_test.py
index bc92a70..21bf3f5 100644
--- a/tensorflow/contrib/opt/python/training/adamax_test.py
+++ b/tensorflow/contrib/opt/python/training/adamax_test.py
@@ -198,11 +198,11 @@ class AdaMaxOptimizerTest(test.TestCase):
         self.assertTrue(beta1_power is not None)
         self.assertIn(beta1_power, opt_variables)
 
-        with ops.Graph().as_default():
-          # Shouldn't return non-slot variables from other graphs.
-          self.assertEqual(0, len(opt.variables()))
-
         if not context.executing_eagerly():
+          with ops.Graph().as_default():
+            # Shouldn't return non-slot variables from other graphs.
+            self.assertEqual(0, len(opt.variables()))
+
           self.evaluate(variables.global_variables_initializer())
           # Fetch params to validate initial values
           self.assertAllClose([1.0, 2.0], self.evaluate(var0))
diff --git a/tensorflow/contrib/optimizer_v2/momentum_test.py b/tensorflow/contrib/optimizer_v2/momentum_test.py
index 26724f6..24cdab4 100644
--- a/tensorflow/contrib/optimizer_v2/momentum_test.py
+++ b/tensorflow/contrib/optimizer_v2/momentum_test.py
@@ -134,7 +134,6 @@ class MomentumOptimizerTest(test.TestCase):
     with context.eager_mode():
       self.doTestBasic(use_resource=True, use_callable_params=True)
 
-  @test_util.run_in_graph_and_eager_modes(reset_test=True)
   def testVariablesAcrossGraphs(self):
     optimizer = momentum_lib.MomentumOptimizer(0.01, 0.5)
     with ops.Graph().as_default():
@@ -142,10 +141,7 @@ class MomentumOptimizerTest(test.TestCase):
           [1.0, 2.0], dtype=dtypes.float32, name="var0")
       var1 = resource_variable_ops.ResourceVariable(
           [3.0, 4.0], dtype=dtypes.float32, name="var1")
-      if context.executing_eagerly():
-        loss = lambda: math_ops.reduce_sum(var0 + var1)
-      else:
-        loss = math_ops.reduce_sum(var0 + var1)
+      loss = math_ops.reduce_sum(var0 + var1)
       optimizer.minimize(loss)
       optimizer_variables = optimizer.variables()
       self.assertStartsWith(optimizer_variables[0].name, "var0")
@@ -157,10 +153,7 @@ class MomentumOptimizerTest(test.TestCase):
           [1.0, 2.0], dtype=dtypes.float32, name="var2")
       var3 = resource_variable_ops.ResourceVariable(
           [3.0, 4.0], dtype=dtypes.float32, name="var3")
-      if context.executing_eagerly():
-        loss = lambda: math_ops.reduce_sum(var2 + var3)
-      else:
-        loss = math_ops.reduce_sum(var2 + var3)
+      loss = math_ops.reduce_sum(var2 + var3)
       optimizer.minimize(loss)
       optimizer_variables = optimizer.variables()
       self.assertStartsWith(optimizer_variables[0].name, "var2")
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 9fc8136..3af0cc4 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -20,7 +20,6 @@ from __future__ import print_function
 
 import collections
 import copy
-import functools
 import linecache
 import os
 import re
@@ -3861,6 +3860,9 @@ class Graph(object):
       assert c.graph is g
     ```
 
+    If eager execution is enabled, ops created under this context manager
+    will be added to the graph instead of executed eagerly.
+
     Returns:
       A context manager for using this graph as the default graph.
     """
@@ -5270,35 +5272,15 @@ class _DefaultGraphStack(_DefaultStack):  # pylint: disable=protected-access
   @tf_contextlib.contextmanager
   def get_controller(self, default):
     try:
-      if context.executing_eagerly():
-        # A Graph alone on the context stack would keep init_scope-wrapped
-        # operations graph building when entered (assuming init_scope is called
-        # in a graph building context). Instead, we push a context which first
-        # enables eager execution and then re-enters the Graph.
-        context.context().context_switches.push(
-            default.building_function,
-            functools.partial(
-                _enter_context_and_graph,
-                context.eager_mode,
-                default.as_default))
-      else:
-        # This Graph is being used from a graph building context. A lack of
-        # context switch implies that the context is graph building.
-        context.context().context_switches.push(default.building_function,
-                                                default.as_default)
-      with super(_DefaultGraphStack, self).get_controller(default) as g:
+      context.context().context_switches.push(
+          default.building_function, default.as_default)
+      with super(_DefaultGraphStack, self).get_controller(
+          default) as g, context.graph_mode():
         yield g
     finally:
       context.context().context_switches.pop()
 
 
-@tf_contextlib.contextmanager
-def _enter_context_and_graph(context_fn, graph_fn):
-  """Combines two context managers."""
-  with context_fn(), graph_fn():
-    yield
-
-
 _default_graph_stack = _DefaultGraphStack()
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 87317db..e773263 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -2215,12 +2215,25 @@ class InitScopeTest(test_util.TensorFlowTestCase):
         self.assertEqual(ops.get_name_scope(), "inner")
       self.assertEqual(ops.get_name_scope(), "")
 
-  def testEagerGraphContextsExecuteEagerly(self):
+  def testEnteringGraphFromEagerIsSticky(self):
     with context.eager_mode():
+      g = ops.Graph()
+      with g.as_default():
+        with ops.init_scope():
+          self.assertFalse(context.executing_eagerly())
+          self.assertEqual(g, ops.get_default_graph())
+
+  def testMixGraphEager(self):
+    with context.eager_mode():
+      c = constant_op.constant(1.0)
       with ops.Graph().as_default():
-        with context.graph_mode():
-          with ops.init_scope():
-            self.assertTrue(context.executing_eagerly())
+        with self.assertRaisesRegexp(
+            RuntimeError, "Attempting to capture an EagerTensor"):
+          math_ops.add(c, c)
+        c2 = constant_op.constant(2.0)
+      with self.assertRaisesRegexp(
+          TypeError, "contains objects other than 'EagerTensor'"):
+        math_ops.add(c2, c2)
 
   def testPreservesNameScopeInEagerExecution(self):
     with context.eager_mode():
@@ -2254,6 +2267,11 @@ class GraphTest(test_util.TensorFlowTestCase):
       with g0.as_default():
         ops.reset_default_graph()
 
+  def testGraphContextManagerCancelsEager(self):
+    with context.eager_mode():
+      with ops.Graph().as_default():
+        self.assertFalse(context.executing_eagerly())
+
   def testGraphContextManager(self):
     g0 = ops.Graph()
     with g0.as_default() as g1:
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 5b01df4..b56483f 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -556,12 +556,16 @@ def assert_no_new_tensors(f):
     tensors_before = set(
         id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj))
-    outside_graph_key = ops.get_default_graph()._graph_key
-    with ops.Graph().as_default():
+    if context.executing_eagerly():
+      f(self, **kwargs)
+      ops.reset_default_graph()
+    else:
       # Run the test in a new graph so that collections get cleared when it's
       # done, but inherit the graph key so optimizers behave.
-      ops.get_default_graph()._graph_key = outside_graph_key
-      f(self, **kwargs)
+      outside_graph_key = ops.get_default_graph()._graph_key
+      with ops.Graph().as_default():
+        ops.get_default_graph()._graph_key = outside_graph_key
+        f(self, **kwargs)
 
     # Make an effort to clear caches, which would otherwise look like leaked
     # Tensors.
     backprop._zeros_cache.flush()
@@ -727,12 +731,12 @@ def run_in_graph_and_eager_modes(__unused__=None,
         f(self, **kwargs)
 
       if assert_no_eager_garbage:
+        ops.reset_default_graph()
         run_eagerly = assert_no_new_tensors(
             assert_no_garbage_created(run_eagerly))
 
       with context.eager_mode():
-        with ops.Graph().as_default():
-          run_eagerly(self, **kwargs)
+        run_eagerly(self, **kwargs)
 
     return decorated
@@ -1027,7 +1031,9 @@ class TensorFlowTestCase(googletest.TestCase):
             rewriter_config_pb2.RewriterConfig.OFF)
       return config
 
-    if graph is None:
+    if context.executing_eagerly():
+      yield None
+    elif graph is None:
       if self._cached_session is None:
         self._cached_session = session.Session(
             graph=None, config=prepare_config(config))
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index 0f53762..0178908 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -619,6 +619,7 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase):
 
     ReferenceCycleTest().test_has_no_cycle()
 
+  @test_util.run_in_graph_and_eager_modes()
   def test_no_leaked_tensor_decorator(self):
 
     class LeakedTensorTest(object):
@@ -628,11 +629,11 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase):
 
       @test_util.assert_no_new_tensors
       def test_has_leak(self):
-        self.a = constant_op.constant([3.])
+        self.a = constant_op.constant([3.], name="leak")
 
       @test_util.assert_no_new_tensors
       def test_has_no_leak(self):
-        constant_op.constant([3.])
+        constant_op.constant([3.], name="no-leak")
 
     with self.assertRaisesRegexp(AssertionError, "Tensors not deallocated"):
       LeakedTensorTest().test_has_leak()
diff --git a/tensorflow/python/kernel_tests/accumulate_n_eager_test.py b/tensorflow/python/kernel_tests/accumulate_n_eager_test.py
index dc11b7d..5f516f2 100644
--- a/tensorflow/python/kernel_tests/accumulate_n_eager_test.py
+++ b/tensorflow/python/kernel_tests/accumulate_n_eager_test.py
@@ -43,10 +43,9 @@ class AccumulateNV2EagerTest(test_util.TensorFlowTestCase):
     np.random.seed(12345)
     x = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(5)]
     tf_x = ops.convert_n_to_tensor(x)
-    with self.test_session(use_gpu=True):
-      self.assertAllClose(sum(x), math_ops.accumulate_n(tf_x).numpy())
-      self.assertAllClose(x[0] * 5,
-                          math_ops.accumulate_n([tf_x[0]] * 5).numpy())
+    self.assertAllClose(sum(x), math_ops.accumulate_n(tf_x))
+    self.assertAllClose(x[0] * 5,
+                        math_ops.accumulate_n([tf_x[0]] * 5))
 
   def testGrad(self):
     np.random.seed(42)
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
index b9f44d7..dc7399f 100644
--- a/tensorflow/python/kernel_tests/py_func_test.py
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -446,9 +446,8 @@ class PyFuncTest(test.TestCase):
     a = array_ops.ones((3, 3), dtype=dtypes.int32)
     x = array_ops.ones((3, 1), dtype=dtypes.int32)
     output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32)
-    with self.test_session():
-      ret = self.evaluate(output)
-      self.assertAllEqual(ret, [[3], [3], [3]])
+    ret = self.evaluate(output)
+    self.assertAllEqual(ret, [[3], [3], [3]])
 
   @test_util.run_in_graph_and_eager_modes()
   def testEagerSingleOutputFloat32(self):
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 959ae08..d88fd83 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -259,7 +259,7 @@ class Variable(checkpointable.CheckpointableBase):
         constraint=constraint)
 
   def __repr__(self):
-    if context.executing_eagerly():
+    if context.executing_eagerly() and not self._in_graph_mode:
       return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % (
           self.name, self.get_shape(), self.dtype.name,
           ops.numpy_text(self.read_value(), is_repr=True))
diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py
index 9be8b6a..bc68f24 100644
--- a/tensorflow/python/training/adam_test.py
+++ b/tensorflow/python/training/adam_test.py
@@ -180,11 +180,10 @@ class AdamOptimizerTest(test.TestCase):
         self.assertIn(beta1_power, opt_variables)
         self.assertIn(beta2_power, opt_variables)
 
-        with ops.Graph().as_default():
-          # Shouldn't return non-slot variables from other graphs.
-          self.assertEqual(0, len(opt.variables()))
-
         if not context.executing_eagerly():
+          with ops.Graph().as_default():
+            # Shouldn't return non-slot variables from other graphs.
+            self.assertEqual(0, len(opt.variables()))
           self.evaluate(variables.global_variables_initializer())
           # Fetch params to validate initial values
           self.assertAllClose([1.0, 2.0], self.evaluate(var0))
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py
index 7bd57ad..f7e7807 100644
--- a/tensorflow/python/training/momentum_test.py
+++ b/tensorflow/python/training/momentum_test.py
@@ -134,7 +134,6 @@ class MomentumOptimizerTest(test.TestCase):
     with context.eager_mode():
       self.doTestBasic(use_resource=True, use_callable_params=True)
 
-  @test_util.run_in_graph_and_eager_modes(reset_test=True)
   def testVariablesAcrossGraphs(self):
     optimizer = momentum_lib.MomentumOptimizer(0.01, 0.5)
     with ops.Graph().as_default():
@@ -142,10 +141,7 @@ class MomentumOptimizerTest(test.TestCase):
           [1.0, 2.0], dtype=dtypes.float32, name="var0")
       var1 = resource_variable_ops.ResourceVariable(
           [3.0, 4.0], dtype=dtypes.float32, name="var1")
-      if context.executing_eagerly():
-        loss = lambda: math_ops.reduce_sum(var0 + var1)
-      else:
-        loss = math_ops.reduce_sum(var0 + var1)
+      loss = math_ops.reduce_sum(var0 + var1)
       optimizer.minimize(loss)
       optimizer_variables = optimizer.variables()
       self.assertStartsWith(optimizer_variables[0].name, "var0")
@@ -157,10 +153,7 @@ class MomentumOptimizerTest(test.TestCase):
           [1.0, 2.0], dtype=dtypes.float32, name="var2")
       var3 = resource_variable_ops.ResourceVariable(
           [3.0, 4.0], dtype=dtypes.float32, name="var3")
-      if context.executing_eagerly():
-        loss = lambda: math_ops.reduce_sum(var2 + var3)
-      else:
-        loss = math_ops.reduce_sum(var2 + var3)
+      loss = math_ops.reduce_sum(var2 + var3)
       optimizer.minimize(loss)
       optimizer_variables = optimizer.variables()
       self.assertStartsWith(optimizer_variables[0].name, "var2")
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py
index d05e1d2..0877b2a 100644
--- a/tensorflow/python/training/training_util.py
+++ b/tensorflow/python/training/training_util.py
@@ -119,18 +119,18 @@ def create_global_step(graph=None):
   graph = graph or ops.get_default_graph()
   if get_global_step(graph) is not None:
     raise ValueError('"global_step" already exists.')
+  if context.executing_eagerly():
+    with ops.device('cpu:0'):
+      return variable_scope.get_variable(
+          ops.GraphKeys.GLOBAL_STEP,
+          shape=[],
+          dtype=dtypes.int64,
+          initializer=init_ops.zeros_initializer(),
+          trainable=False,
+          collections=[ops.GraphKeys.GLOBAL_VARIABLES,
+                       ops.GraphKeys.GLOBAL_STEP])
   # Create in proper graph and base name_scope.
   with graph.as_default() as g, g.name_scope(None):
-    if context.executing_eagerly():
-      with ops.device('cpu:0'):
-        return variable_scope.get_variable(
-            ops.GraphKeys.GLOBAL_STEP,
-            shape=[],
-            dtype=dtypes.int64,
-            initializer=init_ops.zeros_initializer(),
-            trainable=False,
-            collections=[ops.GraphKeys.GLOBAL_VARIABLES,
-                         ops.GraphKeys.GLOBAL_STEP])
     return variable_scope.get_variable(
         ops.GraphKeys.GLOBAL_STEP,
         shape=[],
--
2.7.4
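
A quick demo of the behavior this patch enables (not part of the patch itself; a minimal sketch assuming a TensorFlow build that contains this change, i.e. the TF 1.9-era pre-2.0 API). Eager execution is enabled once up front; entering a tf.Graph's context manager switches to graph building, and leaving it restores eager execution. Graph tensors still need a tf.Session to produce values:

```
# Sketch only: assumes a TF build containing this change (circa TF 1.9).
import tensorflow as tf

tf.enable_eager_execution()

# Eager mode: ops run immediately and yield concrete values.
a = tf.constant(3.0)
print(a.numpy())  # 3.0

g = tf.Graph()
with g.as_default():
  # Graph mode is active here: this builds a symbolic tensor.
  b = tf.constant(4.0) * 2.0
  with tf.Session() as sess:  # Sessions work here because eager is off.
    print(sess.run(b))        # 8.0

# Outside the graph's context manager, execution is eager again.
c = tf.constant(5.0)
print(c.numpy())  # 5.0
```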
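The new testMixGraphEager test above also pins down what happens at the boundary: tensors do not cross between the two modes. A sketch of both failure cases (the error messages are the ones the test asserts; exact wording may differ across versions):

```
import tensorflow as tf

tf.enable_eager_execution()

c = tf.constant(1.0)  # eager tensor

with tf.Graph().as_default():
  try:
    tf.add(c, c)  # graph op fed an eager tensor
  except RuntimeError as e:
    print(e)  # "Attempting to capture an EagerTensor ..."
  c2 = tf.constant(2.0)  # graph tensor

try:
  tf.add(c2, c2)  # eager op fed a graph tensor
except TypeError as e:
  print(e)  # "... contains objects other than 'EagerTensor' ..."
```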