From e2a9276d485ab3c3b5b0ebfbc92fc105cfa7419f Mon Sep 17 00:00:00 2001
From: Francois Chollet
Date: Thu, 22 Feb 2018 12:16:09 -0800
Subject: [PATCH] Add eager support for unit tests for most Keras layers.

A few minor layers were left out:
- noise layers (apparent issue with tf.random_normal)
- bidirectional wrapper
- conv recurrent layers (impending refactor)

PiperOrigin-RevId: 186654795
---
 tensorflow/python/keras/_impl/keras/backend.py     |   3 +-
 .../keras/_impl/keras/engine/training_eager.py     |  17 +-
 .../_impl/keras/engine/training_eager_test.py      |  19 +-
 .../keras/_impl/keras/layers/convolutional_test.py |  68 +++++-
 .../python/keras/_impl/keras/layers/core_test.py   | 244 ++++++++++-----------
 .../keras/_impl/keras/layers/embeddings_test.py    |  72 +++---
 .../python/keras/_impl/keras/layers/gru_test.py    |  62 +++---
 .../python/keras/_impl/keras/layers/local_test.py  |  48 ++--
 .../python/keras/_impl/keras/layers/lstm_test.py   |  61 +++---
 .../python/keras/_impl/keras/layers/merge_test.py  | 212 +++++++++---------
 .../python/keras/_impl/keras/layers/noise_test.py  |  11 +-
 .../keras/_impl/keras/layers/pooling_test.py       | 230 +++++++++----------
 .../python/keras/_impl/keras/layers/recurrent.py   | 136 ++++--------
 .../keras/_impl/keras/layers/simplernn_test.py     |  62 +++---
 .../keras/_impl/keras/layers/wrappers_test.py      |  56 +++--
 tensorflow/python/keras/_impl/keras/optimizers.py  |   4 +-
 .../python/keras/_impl/keras/testing_utils.py      |   7 +-
 17 files changed, 660 insertions(+), 652 deletions(-)

diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py
index a238a3f..a2db05f 100644
--- a/tensorflow/python/keras/_impl/keras/backend.py
+++ b/tensorflow/python/keras/_impl/keras/backend.py
@@ -3087,7 +3087,8 @@ def rnn(step_function,
       outputs_shape[1] = inputs_shape[1]
     outputs.set_shape(outputs_shape)
 
-  last_output._uses_learning_phase = uses_learning_phase
+  if not context.in_eager_mode():
+    last_output._uses_learning_phase = uses_learning_phase
   return last_output, outputs, new_states
 
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager.py b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
index 3507f36..282dd0d 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
@@ -29,6 +29,7 @@ from tensorflow.python.keras._impl.keras import metrics as metrics_module
 from tensorflow.python.keras._impl.keras.utils.generic_utils import make_batches
 from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar
 from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays
+from tensorflow.python.platform import tf_logging as logging
 
 
 def _get_metrics_info(metric, internal_output_shapes=None, loss_func=None):
@@ -196,8 +197,7 @@ def _process_single_batch(eager_model_inputs, eager_model_outputs, model,
     output of the model, total loss and the loss associated with each output.
 
   Raises:
-    ValueError: If the model loss is 0 or if the trainable weights list is
-      empty when the trainable parameter is set to True.
+    ValueError: If the model has no loss to optimize.
   """
   K.set_learning_phase(training)
   with GradientTape() as tape:
@@ -209,12 +209,13 @@ def _process_single_batch(eager_model_inputs, eager_model_outputs, model,
                        'because it has no loss to optimize.')
     if training:
       if not model._collected_trainable_weights:
-        raise ValueError('The list of trainable weights is empty. Make sure that '
-                         'you are not setting model.trainable to False before '
-                         'compiling the model.')
-      grads = tape.gradient(loss, model._collected_trainable_weights)
-      model.optimizer.apply_gradients(zip(grads,
-                                          model._collected_trainable_weights))
+        logging.warning('The list of trainable weights is empty. Make sure that '
+                        'you are not setting model.trainable to False before '
+                        'compiling the model.')
+      else:
+        grads = tape.gradient(loss, model._collected_trainable_weights)
+        model.optimizer.apply_gradients(zip(grads,
+                                            model._collected_trainable_weights))
   return outs, loss, loss_metrics
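For context on the training_eager.py change above: in eager mode a training step records the forward pass on a tape, takes gradients, and applies them, now warning rather than raising when there is nothing to optimize. A minimal standalone sketch of that pattern (not part of the patch; model, optimizer, loss_fn and the data are hypothetical stand-ins, assuming the TF 1.x-era eager API):

    from tensorflow.python.eager.backprop import GradientTape
    from tensorflow.python.platform import tf_logging as logging

    def train_step(model, optimizer, loss_fn, x, y):
      # Record the forward pass so gradients can be computed afterwards.
      with GradientTape() as tape:
        outs = model(x)
        loss = loss_fn(y, outs)
      weights = model.trainable_weights
      if not weights:
        # Mirrors the patch: warn instead of raising when there is
        # nothing to optimize.
        logging.warning('The list of trainable weights is empty.')
      else:
        grads = tape.gradient(loss, weights)
        optimizer.apply_gradients(zip(grads, weights))
      return loss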
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
index 45601f9..3d94b75 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training.rmsprop import RMSPropOptimizer
 
 
@@ -397,17 +398,13 @@ class LossWeightingTest(test.TestCase):
         optimizer=RMSPropOptimizer(learning_rate=0.001))
 
     np.random.seed(43)
-    (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+    (x_train, y_train), _ = testing_utils.get_test_data(
         train_samples=train_samples,
         test_samples=test_samples,
         input_shape=(input_dim,),
         num_classes=num_classes)
-    int_y_test = y_test.copy()
     int_y_train = y_train.copy()
-    # convert class vectors to binary class matrices
     y_train = keras.utils.to_categorical(y_train, num_classes)
-    y_test = keras.utils.to_categorical(y_test, num_classes)
-    test_ids = np.where(int_y_test == np.array(weighted_class))[0]
 
     class_weight = dict([(i, 1.) for i in range(num_classes)])
     class_weight[weighted_class] = 2.
@@ -549,8 +546,10 @@ class TestDynamicTrainability(test.TestCase):
     model.trainable = False
     model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse')
     model.trainable = True
-    with self.assertRaises(ValueError):
+    with test.mock.patch.object(logging, 'warning') as mock_log:
       model.train_on_batch(x, y)
+      self.assertRegexpMatches(str(mock_log.call_args),
+                               'trainable weights is empty')
 
   def test_trainable_argument(self):
     x = np.random.random((5, 3))
@@ -560,8 +559,10 @@ class TestDynamicTrainability(test.TestCase):
     y = np.random.random((5, 2))
     model = keras.models.Sequential()
     model.add(keras.layers.Dense(2, input_dim=3, trainable=False))
     model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse')
     out = model.predict(x)
-    with self.assertRaises(ValueError):
+    with test.mock.patch.object(logging, 'warning') as mock_log:
       model.train_on_batch(x, y)
+      self.assertRegexpMatches(str(mock_log.call_args),
+                               'trainable weights is empty')
     out_2 = model.predict(x)
     self.assertAllClose(out, out_2)
 
@@ -571,8 +572,10 @@ class TestDynamicTrainability(test.TestCase):
     model = keras.models.Model(inputs, output)
     model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse')
     out = model.predict(x)
-    with self.assertRaises(ValueError):
+    with test.mock.patch.object(logging, 'warning') as mock_log:
       model.train_on_batch(x, y)
+      self.assertRegexpMatches(str(mock_log.call_args),
+                               'trainable weights is empty')
     out_2 = model.predict(x)
     self.assertAllClose(out, out_2)
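The warning-assertion idiom introduced in the tests above, shown standalone: patching tf_logging.warning with test.mock captures the call so its message can be checked. A sketch (the direct warning call is a stand-in for the code under test, e.g. model.train_on_batch):

    from tensorflow.python.platform import test
    from tensorflow.python.platform import tf_logging as logging

    class WarningAssertionTest(test.TestCase):

      def test_warning_message(self):
        with test.mock.patch.object(logging, 'warning') as mock_log:
          # Stand-in for the call that is expected to emit the warning.
          logging.warning('The list of trainable weights is empty.')
          self.assertRegexpMatches(str(mock_log.call_args),
                                   'trainable weights is empty')

    if __name__ == '__main__':
      test.main()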
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py
index 4a62281..c612e97 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py
@@ -22,6 +22,8 @@ import copy
 
 import numpy as np
 
+from tensorflow.python.eager import context
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
@@ -43,6 +45,7 @@ class Convolution1DTest(test.TestCase):
         kwargs=test_kwargs,
         input_shape=(num_samples, length, stack_size))
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_conv1d(self):
     kwargs = {
         'filters': 2,
@@ -114,6 +117,7 @@ class Conv2DTest(test.TestCase):
         kwargs=test_kwargs,
         input_shape=(num_samples, num_row, num_col, stack_size))
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_conv2d(self):
     kwargs = {
         'filters': 2,
@@ -188,6 +192,7 @@ class Conv2DTransposeTest(test.TestCase):
         kwargs=test_kwargs,
         input_shape=(num_samples, num_row, num_col, stack_size))
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_conv2dtranspose(self):
     kwargs = {
         'filters': 2,
@@ -253,6 +258,7 @@ class Conv3DTransposeTest(test.TestCase):
         kwargs=test_kwargs,
         input_shape=(num_samples, depth, num_row, num_col, stack_size))
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_conv3dtranspose(self):
     kwargs = {
         'filters': 2,
@@ -316,6 +322,7 @@ class SeparableConv1DTest(test.TestCase):
         kwargs=test_kwargs,
         input_shape=(num_samples, length, stack_size))
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_separable_conv1d(self):
     kwargs = {
         'filters': 2,
@@ -391,6 +398,7 @@ class SeparableConv2DTest(test.TestCase):
         kwargs=test_kwargs,
         input_shape=(num_samples, num_row, num_col, stack_size))
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_separable_conv2d(self):
     kwargs = {
         'filters': 2,
@@ -469,6 +477,7 @@ class Conv3DTest(test.TestCase):
         kwargs=test_kwargs,
         input_shape=(num_samples, depth, num_row, num_col, stack_size))
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_conv3d(self):
     kwargs = {
         'filters': 2,
@@ -520,6 +529,7 @@ class Conv3DTest(test.TestCase):
 
 class ZeroPaddingTest(test.TestCase):
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_zero_padding_1d(self):
     num_samples = 2
     input_dim = 2
@@ -543,7 +553,10 @@ class ZeroPaddingTest(test.TestCase):
     layer = keras.layers.ZeroPadding1D(padding=2)
     layer.build(shape)
     output = layer(keras.backend.variable(inputs))
-    np_output = keras.backend.eval(output)
+    if context.in_eager_mode():
+      np_output = output.numpy()
+    else:
+      np_output = keras.backend.eval(output)
     for offset in [0, 1, -1, -2]:
       np.testing.assert_allclose(np_output[:, offset, :], 0.)
     np.testing.assert_allclose(np_output[:, 2:-2, :], 1.)
@@ -551,7 +564,10 @@ class ZeroPaddingTest(test.TestCase):
     layer = keras.layers.ZeroPadding1D(padding=(1, 2))
     layer.build(shape)
     output = layer(keras.backend.variable(inputs))
-    np_output = keras.backend.eval(output)
+    if context.in_eager_mode():
+      np_output = output.numpy()
+    else:
+      np_output = keras.backend.eval(output)
     for left_offset in [0]:
       np.testing.assert_allclose(np_output[:, left_offset, :], 0.)
     for right_offset in [-1, -2]:
       np.testing.assert_allclose(np_output[:, right_offset, :], 0.)
@@ -565,6 +581,7 @@ class ZeroPaddingTest(test.TestCase):
     with self.assertRaises(ValueError):
       keras.layers.ZeroPadding1D(padding=None)
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_zero_padding_2d(self):
     num_samples = 2
     stack_size = 2
@@ -593,7 +610,10 @@ class ZeroPaddingTest(test.TestCase):
           padding=(2, 2), data_format=data_format)
       layer.build(inputs.shape)
       output = layer(keras.backend.variable(inputs))
-      np_output = keras.backend.eval(output)
+      if context.in_eager_mode():
+        np_output = output.numpy()
+      else:
+        np_output = keras.backend.eval(output)
       if data_format == 'channels_last':
         for offset in [0, 1, -1, -2]:
           np.testing.assert_allclose(np_output[:, offset, :, :], 0.)
@@ -609,7 +629,10 @@ class ZeroPaddingTest(test.TestCase):
           padding=((1, 2), (3, 4)), data_format=data_format)
       layer.build(inputs.shape)
       output = layer(keras.backend.variable(inputs))
-      np_output = keras.backend.eval(output)
+      if context.in_eager_mode():
+        np_output = output.numpy()
+      else:
+        np_output = keras.backend.eval(output)
       if data_format == 'channels_last':
         for top_offset in [0]:
           np.testing.assert_allclose(np_output[:, top_offset, :, :], 0.)
@@ -637,6 +660,7 @@ class ZeroPaddingTest(test.TestCase):
     with self.assertRaises(ValueError):
       keras.layers.ZeroPadding2D(padding=None)
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_zero_padding_3d(self):
     num_samples = 2
     stack_size = 2
@@ -659,7 +683,10 @@ class ZeroPaddingTest(test.TestCase):
     layer = keras.layers.ZeroPadding3D(padding=(2, 2, 2))
     layer.build(inputs.shape)
     output = layer(keras.backend.variable(inputs))
-    np_output = keras.backend.eval(output)
+    if context.in_eager_mode():
+      np_output = output.numpy()
+    else:
+      np_output = keras.backend.eval(output)
     for offset in [0, 1, -1, -2]:
       np.testing.assert_allclose(np_output[:, offset, :, :, :], 0.)
       np.testing.assert_allclose(np_output[:, :, offset, :, :], 0.)
@@ -675,11 +702,13 @@ class ZeroPaddingTest(test.TestCase):
 
 class UpSamplingTest(test.TestCase):
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_upsampling_1d(self):
     with self.test_session(use_gpu=True):
       testing_utils.layer_test(
           keras.layers.UpSampling1D, kwargs={'size': 2}, input_shape=(3, 5, 4))
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_upsampling_2d(self):
     num_samples = 2
     stack_size = 2
     input_num_row = 11
     input_num_col = 12
@@ -708,7 +737,10 @@ class UpSamplingTest(test.TestCase):
             size=(length_row, length_col), data_format=data_format)
         layer.build(inputs.shape)
         output = layer(keras.backend.variable(inputs))
-        np_output = keras.backend.eval(output)
+        if context.in_eager_mode():
+          np_output = output.numpy()
+        else:
+          np_output = keras.backend.eval(output)
         if data_format == 'channels_first':
           assert np_output.shape[2] == length_row * input_num_row
           assert np_output.shape[3] == length_col * input_num_col
@@ -726,6 +758,7 @@ class UpSamplingTest(test.TestCase):
 
           np.testing.assert_allclose(np_output, expected_out)
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_upsampling_3d(self):
     num_samples = 2
     stack_size = 2
     input_len_dim1 = 10
     input_len_dim2 = 11
     input_len_dim3 = 12
@@ -757,7 +790,10 @@ class UpSamplingTest(test.TestCase):
             data_format=data_format)
         layer.build(inputs.shape)
         output = layer(keras.backend.variable(inputs))
-        np_output = keras.backend.eval(output)
+        if context.in_eager_mode():
+          np_output = output.numpy()
+        else:
+          np_output = keras.backend.eval(output)
         if data_format == 'channels_first':
           assert np_output.shape[2] == length_dim1 * input_len_dim1
           assert np_output.shape[3] == length_dim2 * input_len_dim2
@@ -782,6 +818,7 @@ class UpSamplingTest(test.TestCase):
 
 class CroppingTest(test.TestCase):
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_cropping_1d(self):
     num_samples = 2
     time_length = 4
@@ -800,6 +837,7 @@ class CroppingTest(test.TestCase):
     with self.assertRaises(ValueError):
       keras.layers.Cropping1D(cropping=None)
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_cropping_2d(self):
     num_samples = 2
     stack_size = 2
@@ -827,7 +865,10 @@ class CroppingTest(test.TestCase):
           cropping=cropping, data_format=data_format)
       layer.build(inputs.shape)
       output = layer(keras.backend.variable(inputs))
-      np_output = keras.backend.eval(output)
+      if context.in_eager_mode():
+        np_output = output.numpy()
+      else:
+        np_output = keras.backend.eval(output)
       # compare with numpy
       if data_format == 'channels_first':
         expected_out = inputs[:, :, cropping[0][0]:-cropping[0][1], cropping[
@@ -851,7 +892,10 @@ class CroppingTest(test.TestCase):
           cropping=cropping, data_format=data_format)
       layer.build(inputs.shape)
       output = layer(keras.backend.variable(inputs))
-      np_output = keras.backend.eval(output)
+      if context.in_eager_mode():
+        np_output = output.numpy()
+      else:
+        np_output = keras.backend.eval(output)
       # compare with input
       np.testing.assert_allclose(np_output, inputs)
@@ -861,6 +905,7 @@ class CroppingTest(test.TestCase):
     with self.assertRaises(ValueError):
       keras.layers.Cropping2D(cropping=None)
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_cropping_3d(self):
     num_samples = 2
     stack_size = 2
@@ -892,7 +937,10 @@ class CroppingTest(test.TestCase):
           cropping=cropping, data_format=data_format)
       layer.build(inputs.shape)
       output = layer(keras.backend.variable(inputs))
-      np_output = keras.backend.eval(output)
+      if context.in_eager_mode():
+        np_output = output.numpy()
+      else:
+        np_output = keras.backend.eval(output)
       # compare with numpy
       if data_format == 'channels_first':
         expected_out = inputs[:, :,
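The in_eager_mode() branching repeated throughout the file above converts a layer output to a numpy array in either execution mode; factored out, the pattern looks like this (hypothetical helper, not part of the patch):

    from tensorflow.python.eager import context
    from tensorflow.python.keras._impl import keras

    def tensor_to_numpy(tensor):
      # Eager tensors expose .numpy() directly; graph tensors need a
      # session-backed evaluation through the Keras backend.
      if context.in_eager_mode():
        return tensor.numpy()
      return keras.backend.eval(tensor)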
diff --git a/tensorflow/python/keras/_impl/keras/layers/core_test.py b/tensorflow/python/keras/_impl/keras/layers/core_test.py
index bdb99c9..2ca816a 100644
--- a/tensorflow/python/keras/_impl/keras/layers/core_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/core_test.py
@@ -20,11 +20,9 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
-from tensorflow.python.ops import init_ops
 from tensorflow.python.platform import test
 
 
@@ -52,146 +50,134 @@ class CoreLayersTest(test.TestCase):
     dropout = keras.layers.Dropout(0.5)
     self.assertEqual(True, dropout.supports_masking)
 
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.SpatialDropout1D,
-          kwargs={'rate': 0.5},
-          input_shape=(2, 3, 4))
-
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.SpatialDropout2D,
-          kwargs={'rate': 0.5},
-          input_shape=(2, 3, 4, 5))
-
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.SpatialDropout2D,
-          kwargs={'rate': 0.5, 'data_format': 'channels_first'},
-          input_shape=(2, 3, 4, 5))
-
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.SpatialDropout3D,
-          kwargs={'rate': 0.5},
-          input_shape=(2, 3, 4, 4, 5))
-
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.SpatialDropout3D,
-          kwargs={'rate': 0.5, 'data_format': 'channels_first'},
-          input_shape=(2, 3, 4, 4, 5))
-
+  @tf_test_util.run_in_graph_and_eager_modes()
+  def test_spatial_dropout(self):
+    testing_utils.layer_test(
+        keras.layers.SpatialDropout1D,
+        kwargs={'rate': 0.5},
+        input_shape=(2, 3, 4))
+
+    testing_utils.layer_test(
+        keras.layers.SpatialDropout2D,
+        kwargs={'rate': 0.5},
+        input_shape=(2, 3, 4, 5))
+
+    testing_utils.layer_test(
+        keras.layers.SpatialDropout2D,
+        kwargs={'rate': 0.5, 'data_format': 'channels_first'},
+        input_shape=(2, 3, 4, 5))
+
+    testing_utils.layer_test(
+        keras.layers.SpatialDropout3D,
+        kwargs={'rate': 0.5},
+        input_shape=(2, 3, 4, 4, 5))
+
+    testing_utils.layer_test(
+        keras.layers.SpatialDropout3D,
+        kwargs={'rate': 0.5, 'data_format': 'channels_first'},
+        input_shape=(2, 3, 4, 4, 5))
+
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_activation(self):
     # with string argument
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.Activation,
-          kwargs={'activation': 'relu'},
-          input_shape=(3, 2))
+    testing_utils.layer_test(
+        keras.layers.Activation,
+        kwargs={'activation': 'relu'},
+        input_shape=(3, 2))
 
     # with function argument
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.Activation,
-          kwargs={'activation': keras.backend.relu},
-          input_shape=(3, 2))
+    testing_utils.layer_test(
+        keras.layers.Activation,
+        kwargs={'activation': keras.backend.relu},
+        input_shape=(3, 2))
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_reshape(self):
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.Reshape,
-          kwargs={'target_shape': (8, 1)},
-          input_shape=(3, 2, 4))
-
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.Reshape,
-          kwargs={'target_shape': (-1, 1)},
-          input_shape=(3, 2, 4))
-
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.Reshape,
-          kwargs={'target_shape': (1, -1)},
-          input_shape=(3, 2, 4))
-
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.Reshape,
-          kwargs={'target_shape': (-1, 1)},
-          input_shape=(None, None, 2))
-
+    testing_utils.layer_test(
+        keras.layers.Reshape,
+        kwargs={'target_shape': (8, 1)},
+        input_shape=(3, 2, 4))
+
+    testing_utils.layer_test(
+        keras.layers.Reshape,
+        kwargs={'target_shape': (-1, 1)},
+        input_shape=(3, 2, 4))
+
+    testing_utils.layer_test(
+        keras.layers.Reshape,
+        kwargs={'target_shape': (1, -1)},
+        input_shape=(3, 2, 4))
+
+    testing_utils.layer_test(
+        keras.layers.Reshape,
+        kwargs={'target_shape': (-1, 1)},
+        input_shape=(None, None, 2))
+
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_permute(self):
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.Permute, kwargs={'dims': (2, 1)}, input_shape=(3, 2, 4))
+    testing_utils.layer_test(
+        keras.layers.Permute, kwargs={'dims': (2, 1)}, input_shape=(3, 2, 4))
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_flatten(self):
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.Flatten, kwargs={}, input_shape=(3, 2, 4))
+    testing_utils.layer_test(
+        keras.layers.Flatten, kwargs={}, input_shape=(3, 2, 4))
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_repeat_vector(self):
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.RepeatVector, kwargs={'n': 3}, input_shape=(3, 2))
+    testing_utils.layer_test(
+        keras.layers.RepeatVector, kwargs={'n': 3}, input_shape=(3, 2))
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_lambda(self):
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.Lambda,
-          kwargs={'function': lambda x: x + 1},
-          input_shape=(3, 2))
-
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.Lambda,
-          kwargs={
-              'function': lambda x, a, b: x * a + b,
-              'arguments': {
-                  'a': 0.6,
-                  'b': 0.4
-              }
-          },
-          input_shape=(3, 2))
-
-    with self.test_session():
-      # test serialization with function
-      def f(x):
-        return x + 1
-
-      ld = keras.layers.Lambda(f)
-      config = ld.get_config()
-      ld = keras.layers.deserialize({
-          'class_name': 'Lambda',
-          'config': config
-      })
-
-      # test with lambda
-      ld = keras.layers.Lambda(
-          lambda x: keras.backend.concatenate([keras.backend.square(x), x]))
-      config = ld.get_config()
-      ld = keras.layers.Lambda.from_config(config)
-
+    testing_utils.layer_test(
+        keras.layers.Lambda,
+        kwargs={'function': lambda x: x + 1},
+        input_shape=(3, 2))
+
+    testing_utils.layer_test(
+        keras.layers.Lambda,
+        kwargs={
+            'function': lambda x, a, b: x * a + b,
+            'arguments': {
+                'a': 0.6,
+                'b': 0.4
+            }
+        },
+        input_shape=(3, 2))
+
+    # test serialization with function
+    def f(x):
+      return x + 1
+
+    ld = keras.layers.Lambda(f)
+    config = ld.get_config()
+    ld = keras.layers.deserialize({
+        'class_name': 'Lambda',
+        'config': config
+    })
+
+    # test with lambda
+    ld = keras.layers.Lambda(
+        lambda x: keras.backend.concatenate([keras.backend.square(x), x]))
+    config = ld.get_config()
+    ld = keras.layers.Lambda.from_config(config)
+
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_dense(self):
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 2))
+    testing_utils.layer_test(
+        keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 2))
 
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 4, 2))
+    testing_utils.layer_test(
+        keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 4, 2))
 
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.Dense, kwargs={'units': 3}, input_shape=(None, None, 2))
+    testing_utils.layer_test(
+        keras.layers.Dense, kwargs={'units': 3}, input_shape=(None, None, 2))
 
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 4, 5, 2))
+    testing_utils.layer_test(
+        keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 4, 5, 2))
 
-    # Test regularization
+  def test_dense_regularization(self):
     with self.test_session():
       layer = keras.layers.Dense(
           3,
@@ -202,7 +188,7 @@ class CoreLayersTest(test.TestCase):
       layer(keras.backend.variable(np.ones((2, 4))))
       self.assertEqual(3, len(layer.losses))
 
-    # Test constraints
+  def test_dense_constraints(self):
     with self.test_session():
       k_constraint = keras.constraints.max_norm(0.01)
       b_constraint = keras.constraints.max_norm(0.01)
@@ -212,12 +198,6 @@ class CoreLayersTest(test.TestCase):
       self.assertEqual(layer.kernel.constraint, k_constraint)
       self.assertEqual(layer.bias.constraint, b_constraint)
 
-  def test_eager_dense(self):
-    with context.eager_mode():
-      l = keras.layers.Dense(units=3,
-                             kernel_initializer=init_ops.zeros_initializer())
-      self.assertAllEqual(l(constant_op.constant([[1.0]])), [[0., 0., 0.]])
-
   def test_activity_regularization(self):
     with self.test_session():
       layer = keras.layers.ActivityRegularization(l1=0.1)
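For reference, the conversion applied throughout these test files follows a single pattern: run_in_graph_and_eager_modes runs the decorated method twice, once building a graph and once executing eagerly, so layer_test no longer needs an explicit test_session. A minimal sketch (hypothetical test case, not part of the patch):

    from tensorflow.python.framework import test_util as tf_test_util
    from tensorflow.python.keras._impl import keras
    from tensorflow.python.keras._impl.keras import testing_utils
    from tensorflow.python.platform import test

    class ExampleLayerTest(test.TestCase):

      @tf_test_util.run_in_graph_and_eager_modes()
      def test_dense(self):
        # Executed once in graph mode and again in eager mode.
        testing_utils.layer_test(
            keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 2))

    if __name__ == '__main__':
      test.main()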
diff --git a/tensorflow/python/keras/_impl/keras/layers/embeddings_test.py b/tensorflow/python/keras/_impl/keras/layers/embeddings_test.py
index 1712111..26fd1f1 100644
--- a/tensorflow/python/keras/_impl/keras/layers/embeddings_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/embeddings_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
@@ -25,47 +26,44 @@ from tensorflow.python.platform import test
 
 class EmbeddingTest(test.TestCase):
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_embedding(self):
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.Embedding,
-          kwargs={'output_dim': 4,
-                  'input_dim': 10,
-                  'input_length': 2},
-          input_shape=(3, 2),
-          input_dtype='int32',
-          expected_output_dtype='float32')
+    testing_utils.layer_test(
+        keras.layers.Embedding,
+        kwargs={'output_dim': 4,
+                'input_dim': 10,
+                'input_length': 2},
+        input_shape=(3, 2),
+        input_dtype='int32',
+        expected_output_dtype='float32')
 
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.Embedding,
-          kwargs={'output_dim': 4,
-                  'input_dim': 10,
-                  'mask_zero': True},
-          input_shape=(3, 2),
-          input_dtype='int32',
-          expected_output_dtype='float32')
+    testing_utils.layer_test(
+        keras.layers.Embedding,
+        kwargs={'output_dim': 4,
+                'input_dim': 10,
+                'mask_zero': True},
+        input_shape=(3, 2),
+        input_dtype='int32',
+        expected_output_dtype='float32')
 
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.Embedding,
-          kwargs={'output_dim': 4,
-                  'input_dim': 10,
-                  'mask_zero': True},
-          input_shape=(3, 4, 2),
-          input_dtype='int32',
-          expected_output_dtype='float32')
+    testing_utils.layer_test(
+        keras.layers.Embedding,
+        kwargs={'output_dim': 4,
+                'input_dim': 10,
+                'mask_zero': True},
+        input_shape=(3, 4, 2),
+        input_dtype='int32',
+        expected_output_dtype='float32')
 
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.Embedding,
-          kwargs={'output_dim': 4,
-                  'input_dim': 10,
-                  'mask_zero': True,
-                  'input_length': (None, 2)},
-          input_shape=(3, 4, 2),
-          input_dtype='int32',
-          expected_output_dtype='float32')
+    testing_utils.layer_test(
+        keras.layers.Embedding,
+        kwargs={'output_dim': 4,
+                'input_dim': 10,
+                'mask_zero': True,
+                'input_length': (None, 2)},
+        input_shape=(3, 4, 2),
+        input_dtype='int32',
+        expected_output_dtype='float32')
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/keras/_impl/keras/layers/gru_test.py b/tensorflow/python/keras/_impl/keras/layers/gru_test.py
index c57fbac..48e7e14 100644
--- a/tensorflow/python/keras/_impl/keras/layers/gru_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/gru_test.py
@@ -20,64 +20,66 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
+from tensorflow.python.training.rmsprop import RMSPropOptimizer
 
 
 class GRULayerTest(test.TestCase):
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_return_sequences_GRU(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.GRU,
-          kwargs={'units': units,
-                  'return_sequences': True},
-          input_shape=(num_samples, timesteps, embedding_dim))
+    testing_utils.layer_test(
+        keras.layers.GRU,
+        kwargs={'units': units,
+                'return_sequences': True},
+        input_shape=(num_samples, timesteps, embedding_dim))
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_dynamic_behavior_GRU(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
-      layer = keras.layers.GRU(units, input_shape=(None, embedding_dim))
-      model = keras.models.Sequential()
-      model.add(layer)
-      model.compile('sgd', 'mse')
-      x = np.random.random((num_samples, timesteps, embedding_dim))
-      y = np.random.random((num_samples, units))
-      model.train_on_batch(x, y)
-
+    layer = keras.layers.GRU(units, input_shape=(None, embedding_dim))
+    model = keras.models.Sequential()
+    model.add(layer)
+    model.compile(RMSPropOptimizer(0.01), 'mse')
+    x = np.random.random((num_samples, timesteps, embedding_dim))
+    y = np.random.random((num_samples, units))
+    model.train_on_batch(x, y)
+
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_dropout_GRU(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.GRU,
-          kwargs={'units': units,
-                  'dropout': 0.1,
-                  'recurrent_dropout': 0.1},
-          input_shape=(num_samples, timesteps, embedding_dim))
-
+    testing_utils.layer_test(
+        keras.layers.GRU,
+        kwargs={'units': units,
+                'dropout': 0.1,
+                'recurrent_dropout': 0.1},
+        input_shape=(num_samples, timesteps, embedding_dim))
+
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_implementation_mode_GRU(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
-      for mode in [0, 1, 2]:
-        testing_utils.layer_test(
-            keras.layers.GRU,
-            kwargs={'units': units,
-                    'implementation': mode},
-            input_shape=(num_samples, timesteps, embedding_dim))
+    for mode in [0, 1, 2]:
+      testing_utils.layer_test(
+          keras.layers.GRU,
+          kwargs={'units': units,
+                  'implementation': mode},
+          input_shape=(num_samples, timesteps, embedding_dim))
 
   def test_statefulness_GRU(self):
     num_samples = 2
diff --git a/tensorflow/python/keras/_impl/keras/layers/local_test.py b/tensorflow/python/keras/_impl/keras/layers/local_test.py
index a815a0f..93741d2 100644
--- a/tensorflow/python/keras/_impl/keras/layers/local_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/local_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
@@ -27,6 +28,7 @@ from tensorflow.python.platform import test
 
 class LocallyConnectedLayersTest(test.TestCase):
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_locallyconnected_1d(self):
     num_samples = 2
     num_steps = 8
@@ -39,16 +41,15 @@ class LocallyConnectedLayersTest(test.TestCase):
         if padding == 'same' and strides != 1:
           continue
 
-        with self.test_session():
-          testing_utils.layer_test(
-              keras.layers.LocallyConnected1D,
-              kwargs={
-                  'filters': filters,
-                  'kernel_size': filter_length,
-                  'padding': padding,
-                  'strides': strides
-              },
-              input_shape=(num_samples, num_steps, input_dim))
+        testing_utils.layer_test(
+            keras.layers.LocallyConnected1D,
+            kwargs={
+                'filters': filters,
+                'kernel_size': filter_length,
+                'padding': padding,
+                'strides': strides
+            },
+            input_shape=(num_samples, num_steps, input_dim))
 
   def test_locallyconnected_1d_regularization(self):
     num_samples = 2
@@ -86,6 +87,7 @@ class LocallyConnectedLayersTest(test.TestCase):
       self.assertEqual(layer.kernel.constraint, k_constraint)
       self.assertEqual(layer.bias.constraint, b_constraint)
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_locallyconnected_2d(self):
     num_samples = 8
     filters = 3
@@ -98,20 +100,18 @@ class LocallyConnectedLayersTest(test.TestCase):
       if padding == 'same' and strides != (1, 1):
         continue
 
-      with self.test_session():
-        testing_utils.layer_test(
-            keras.layers.LocallyConnected2D,
-            kwargs={
-                'filters': filters,
-                'kernel_size': 3,
-                'padding': padding,
-                'kernel_regularizer': 'l2',
-                'bias_regularizer': 'l2',
-                'activity_regularizer': 'l2',
-                'strides': strides,
-                'data_format': 'channels_last'
-            },
-            input_shape=(num_samples, num_row, num_col, stack_size))
+      testing_utils.layer_test(
+          keras.layers.LocallyConnected2D,
+          kwargs={
+              'filters': filters,
+              'kernel_size': 3,
+              'padding': padding,
+              'kernel_regularizer': 'l2',
+              'bias_regularizer': 'l2',
+              'strides': strides,
+              'data_format': 'channels_last'
+          },
+          input_shape=(num_samples, num_row, num_col, stack_size))
 
   def test_locallyconnected_2d_channels_first(self):
     num_samples = 8
diff --git a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py
index 1de5485..74548d0 100644
--- a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py
@@ -20,28 +20,29 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
+from tensorflow.python.training.rmsprop import RMSPropOptimizer
 
 
 class LSTMLayerTest(test.TestCase):
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_return_sequences_LSTM(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.LSTM,
-          kwargs={'units': units,
-                  'return_sequences': True},
-          input_shape=(num_samples, timesteps, embedding_dim))
+    testing_utils.layer_test(
+        keras.layers.LSTM,
+        kwargs={'units': units,
+                'return_sequences': True},
+        input_shape=(num_samples, timesteps, embedding_dim))
 
   def test_static_shape_inference_LSTM(self):
     # Github issue: 15165
-    num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
@@ -55,45 +56,45 @@ class LSTMLayerTest(test.TestCase):
     outputs = model.layers[-1].output
     self.assertEquals(outputs.get_shape().as_list(), [None, timesteps, units])
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_dynamic_behavior_LSTM(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
-      layer = keras.layers.LSTM(units, input_shape=(None, embedding_dim))
-      model = keras.models.Sequential()
-      model.add(layer)
-      model.compile('sgd', 'mse')
-      x = np.random.random((num_samples, timesteps, embedding_dim))
-      y = np.random.random((num_samples, units))
-      model.train_on_batch(x, y)
+    layer = keras.layers.LSTM(units, input_shape=(None, embedding_dim))
+    model = keras.models.Sequential()
+    model.add(layer)
+    model.compile(RMSPropOptimizer(0.001), 'mse')
+    x = np.random.random((num_samples, timesteps, embedding_dim))
+    y = np.random.random((num_samples, units))
+    model.train_on_batch(x, y)
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_dropout_LSTM(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.LSTM,
-          kwargs={'units': units,
-                  'dropout': 0.1,
-                  'recurrent_dropout': 0.1},
-          input_shape=(num_samples, timesteps, embedding_dim))
-
+    testing_utils.layer_test(
+        keras.layers.LSTM,
+        kwargs={'units': units,
+                'dropout': 0.1,
+                'recurrent_dropout': 0.1},
+        input_shape=(num_samples, timesteps, embedding_dim))
+
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_implementation_mode_LSTM(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
-      for mode in [0, 1, 2]:
-        testing_utils.layer_test(
-            keras.layers.LSTM,
-            kwargs={'units': units,
-                    'implementation': mode},
-            input_shape=(num_samples, timesteps, embedding_dim))
+    for mode in [0, 1, 2]:
+      testing_utils.layer_test(
+          keras.layers.LSTM,
+          kwargs={'units': units,
+                  'implementation': mode},
+          input_shape=(num_samples, timesteps, embedding_dim))
 
   def test_statefulness_LSTM(self):
     num_samples = 2
diff --git a/tensorflow/python/keras/_impl/keras/layers/merge_test.py b/tensorflow/python/keras/_impl/keras/layers/merge_test.py
index bb03dda..b2fe06f 100644
--- a/tensorflow/python/keras/_impl/keras/layers/merge_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/merge_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
@@ -27,24 +28,25 @@ from tensorflow.python.platform import test
 
 class MergeLayersTest(test.TestCase):
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_merge_add(self):
-    with self.test_session():
-      i1 = keras.layers.Input(shape=(4, 5))
-      i2 = keras.layers.Input(shape=(4, 5))
-      i3 = keras.layers.Input(shape=(4, 5))
+    i1 = keras.layers.Input(shape=(4, 5))
+    i2 = keras.layers.Input(shape=(4, 5))
+    i3 = keras.layers.Input(shape=(4, 5))
 
-      o = keras.layers.add([i1, i2, i3])
-      self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
-      model = keras.models.Model([i1, i2, i3], o)
+    o = keras.layers.add([i1, i2, i3])
+    self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
+    model = keras.models.Model([i1, i2, i3], o)
 
-      x1 = np.random.random((2, 4, 5))
-      x2 = np.random.random((2, 4, 5))
-      x3 = np.random.random((2, 4, 5))
-      out = model.predict([x1, x2, x3])
-      self.assertEqual(out.shape, (2, 4, 5))
-      self.assertAllClose(out, x1 + x2 + x3, atol=1e-4)
+    x1 = np.random.random((2, 4, 5))
+    x2 = np.random.random((2, 4, 5))
+    x3 = np.random.random((2, 4, 5))
+    out = model.predict([x1, x2, x3])
+    self.assertEqual(out.shape, (2, 4, 5))
+    self.assertAllClose(out, x1 + x2 + x3, atol=1e-4)
 
-      # test masking
+  def test_merge_add_masking(self):
+    with self.test_session():
       i1 = keras.layers.Input(shape=(4, 5))
       i2 = keras.layers.Input(shape=(4, 5))
       m1 = keras.layers.Masking()(i1)
@@ -54,11 +56,13 @@ class MergeLayersTest(test.TestCase):
       mask = layer.output_mask
       self.assertListEqual(mask.get_shape().as_list(), [None, 4])
 
-      # test missing shape
+  def test_merge_add_dynamic_shape(self):
+    with self.test_session():
       i1 = array_ops.placeholder(shape=(4, None), dtype='float32')
       i2 = array_ops.placeholder(shape=(4, 5), dtype='float32')
       layer = keras.layers.Add()
       o = layer([i1, i2])
+      self.assertListEqual(o.get_shape().as_list(), [4, 5])
 
   def test_merge_elementwise_errors(self):
     i1 = keras.layers.Input(shape=(4, 5))
@@ -72,79 +76,82 @@ class MergeLayersTest(test.TestCase):
     with self.assertRaises(ValueError):
       keras.layers.add([i1])
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_merge_multiply(self):
-    with self.test_session():
-      i1 = keras.layers.Input(shape=(4, 5))
-      i2 = keras.layers.Input(shape=(4, 5))
-      i3 = keras.layers.Input(shape=(4, 5))
-      o = keras.layers.multiply([i1, i2, i3])
-      self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
-      model = keras.models.Model([i1, i2, i3], o)
-
-      x1 = np.random.random((2, 4, 5))
-      x2 = np.random.random((2, 4, 5))
-      x3 = np.random.random((2, 4, 5))
-      out = model.predict([x1, x2, x3])
-      self.assertEqual(out.shape, (2, 4, 5))
-      self.assertAllClose(out, x1 * x2 * x3, atol=1e-4)
-
+    i1 = keras.layers.Input(shape=(4, 5))
+    i2 = keras.layers.Input(shape=(4, 5))
+    i3 = keras.layers.Input(shape=(4, 5))
+    o = keras.layers.multiply([i1, i2, i3])
+    self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
+    model = keras.models.Model([i1, i2, i3], o)
+
+    x1 = np.random.random((2, 4, 5))
+    x2 = np.random.random((2, 4, 5))
+    x3 = np.random.random((2, 4, 5))
+    out = model.predict([x1, x2, x3])
+    self.assertEqual(out.shape, (2, 4, 5))
+    self.assertAllClose(out, x1 * x2 * x3, atol=1e-4)
+
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_merge_average(self):
-    with self.test_session():
-      i1 = keras.layers.Input(shape=(4, 5))
-      i2 = keras.layers.Input(shape=(4, 5))
-      o = keras.layers.average([i1, i2])
-      self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
-      model = keras.models.Model([i1, i2], o)
+    i1 = keras.layers.Input(shape=(4, 5))
+    i2 = keras.layers.Input(shape=(4, 5))
+    o = keras.layers.average([i1, i2])
+    self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
+    model = keras.models.Model([i1, i2], o)
 
-      x1 = np.random.random((2, 4, 5))
-      x2 = np.random.random((2, 4, 5))
-      out = model.predict([x1, x2])
-      self.assertEqual(out.shape, (2, 4, 5))
-      self.assertAllClose(out, 0.5 * (x1 + x2), atol=1e-4)
+    x1 = np.random.random((2, 4, 5))
+    x2 = np.random.random((2, 4, 5))
+    out = model.predict([x1, x2])
+    self.assertEqual(out.shape, (2, 4, 5))
+    self.assertAllClose(out, 0.5 * (x1 + x2), atol=1e-4)
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_merge_maximum(self):
-    with self.test_session():
-      i1 = keras.layers.Input(shape=(4, 5))
-      i2 = keras.layers.Input(shape=(4, 5))
-      o = keras.layers.maximum([i1, i2])
-      self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
-      model = keras.models.Model([i1, i2], o)
+    i1 = keras.layers.Input(shape=(4, 5))
+    i2 = keras.layers.Input(shape=(4, 5))
+    o = keras.layers.maximum([i1, i2])
+    self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
+    model = keras.models.Model([i1, i2], o)
 
-      x1 = np.random.random((2, 4, 5))
-      x2 = np.random.random((2, 4, 5))
-      out = model.predict([x1, x2])
-      self.assertEqual(out.shape, (2, 4, 5))
-      self.assertAllClose(out, np.maximum(x1, x2), atol=1e-4)
+    x1 = np.random.random((2, 4, 5))
+    x2 = np.random.random((2, 4, 5))
+    out = model.predict([x1, x2])
+    self.assertEqual(out.shape, (2, 4, 5))
+    self.assertAllClose(out, np.maximum(x1, x2), atol=1e-4)
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_merge_minimum(self):
-    with self.test_session():
-      i1 = keras.layers.Input(shape=(4, 5))
-      i2 = keras.layers.Input(shape=(4, 5))
-      o = keras.layers.minimum([i1, i2])
-      self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
-      model = keras.models.Model([i1, i2], o)
+    i1 = keras.layers.Input(shape=(4, 5))
+    i2 = keras.layers.Input(shape=(4, 5))
+    o = keras.layers.minimum([i1, i2])
+    self.assertListEqual(o.get_shape().as_list(), [None, 4, 5])
+    model = keras.models.Model([i1, i2], o)
 
-      x1 = np.random.random((2, 4, 5))
-      x2 = np.random.random((2, 4, 5))
-      out = model.predict([x1, x2])
-      self.assertEqual(out.shape, (2, 4, 5))
-      self.assertAllClose(out, np.minimum(x1, x2), atol=1e-4)
+    x1 = np.random.random((2, 4, 5))
+    x2 = np.random.random((2, 4, 5))
+    out = model.predict([x1, x2])
+    self.assertEqual(out.shape, (2, 4, 5))
+    self.assertAllClose(out, np.minimum(x1, x2), atol=1e-4)
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_merge_concatenate(self):
+    i1 = keras.layers.Input(shape=(4, 5))
+    i2 = keras.layers.Input(shape=(4, 5))
+    o = keras.layers.concatenate([i1, i2], axis=1)
+    self.assertListEqual(o.get_shape().as_list(), [None, 8, 5])
+    model = keras.models.Model([i1, i2], o)
+
+    x1 = np.random.random((2, 4, 5))
+    x2 = np.random.random((2, 4, 5))
+    out = model.predict([x1, x2])
+    self.assertEqual(out.shape, (2, 8, 5))
+    self.assertAllClose(out, np.concatenate([x1, x2], axis=1), atol=1e-4)
+
+  def test_merge_concatenate_masking(self):
     with self.test_session():
       i1 = keras.layers.Input(shape=(4, 5))
       i2 = keras.layers.Input(shape=(4, 5))
-      o = keras.layers.concatenate([i1, i2], axis=1)
-      self.assertListEqual(o.get_shape().as_list(), [None, 8, 5])
-      model = keras.models.Model([i1, i2], o)
-
-      x1 = np.random.random((2, 4, 5))
-      x2 = np.random.random((2, 4, 5))
-      out = model.predict([x1, x2])
-      self.assertEqual(out.shape, (2, 8, 5))
-      self.assertAllClose(out, np.concatenate([x1, x2], axis=1), atol=1e-4)
-
-      # test masking
       m1 = keras.layers.Masking()(i1)
       layer = keras.layers.Concatenate()
       o = layer([m1, i2])
@@ -162,35 +169,35 @@ class MergeLayersTest(test.TestCase):
     with self.assertRaisesRegexp(ValueError, 'called on a list'):
       keras.layers.concatenate([i1], axis=-1)
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_merge_dot(self):
-    with self.test_session():
-      i1 = keras.layers.Input(shape=(4,))
-      i2 = keras.layers.Input(shape=(4,))
-      o = keras.layers.dot([i1, i2], axes=1)
-      self.assertListEqual(o.get_shape().as_list(), [None, 1])
-      model = keras.models.Model([i1, i2], o)
-      _ = keras.layers.Dot(axes=1).get_config()
-
-      x1 = np.random.random((2, 4))
-      x2 = np.random.random((2, 4))
-      out = model.predict([x1, x2])
-      self.assertEqual(out.shape, (2, 1))
-      expected = np.zeros((2, 1))
-      expected[0, 0] = np.dot(x1[0], x2[0])
-      expected[1, 0] = np.dot(x1[1], x2[1])
-      self.assertAllClose(out, expected, atol=1e-4)
-
-      # Test with negative tuple of axes.
-      o = keras.layers.dot([i1, i2], axes=(-1, -1))
-      self.assertListEqual(o.get_shape().as_list(), [None, 1])
-      model = keras.models.Model([i1, i2], o)
-      out = model.predict([x1, x2])
-      self.assertEqual(out.shape, (2, 1))
-      self.assertAllClose(out, expected, atol=1e-4)
-
-      # test compute_output_shape
-      layer = keras.layers.Dot(axes=-1)
-      self.assertEqual(layer.compute_output_shape([(4, 5), (4, 5)]), (4, 1))
+    i1 = keras.layers.Input(shape=(4,))
+    i2 = keras.layers.Input(shape=(4,))
+    o = keras.layers.dot([i1, i2], axes=1)
+    self.assertListEqual(o.get_shape().as_list(), [None, 1])
+    model = keras.models.Model([i1, i2], o)
+    _ = keras.layers.Dot(axes=1).get_config()
+
+    x1 = np.random.random((2, 4))
+    x2 = np.random.random((2, 4))
+    out = model.predict([x1, x2])
+    self.assertEqual(out.shape, (2, 1))
+    expected = np.zeros((2, 1))
+    expected[0, 0] = np.dot(x1[0], x2[0])
+    expected[1, 0] = np.dot(x1[1], x2[1])
+    self.assertAllClose(out, expected, atol=1e-4)
+
+    # Test with negative tuple of axes.
+    o = keras.layers.dot([i1, i2], axes=(-1, -1))
+    self.assertListEqual(o.get_shape().as_list(), [None, 1])
+    model = keras.models.Model([i1, i2], o)
+    out = model.predict([x1, x2])
+    self.assertEqual(out.shape, (2, 1))
+    self.assertAllClose(out, expected, atol=1e-4)
+
+    # test compute_output_shape
+    layer = keras.layers.Dot(axes=-1)
+    self.assertEqual(layer.compute_output_shape([(4, 5), (4, 5)]), (4, 1))
 
   def test_dot_errors(self):
     i1 = keras.layers.Input(shape=(4, 5))
@@ -208,6 +215,7 @@ class MergeLayersTest(test.TestCase):
       dot = keras.layers.Dot(1)
       dot.compute_output_shape(1)
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_merge_subtract(self):
     i1 = keras.layers.Input(shape=(4, 5))
     i2 = keras.layers.Input(shape=(4, 5))
diff --git a/tensorflow/python/keras/_impl/keras/layers/noise_test.py b/tensorflow/python/keras/_impl/keras/layers/noise_test.py
index f9b4d9c..af4f031 100644
--- a/tensorflow/python/keras/_impl/keras/layers/noise_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/noise_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
@@ -39,12 +40,12 @@ class NoiseLayersTest(test.TestCase):
           kwargs={'rate': 0.5},
           input_shape=(3, 2, 3))
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_AlphaDropout(self):
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.AlphaDropout,
-          kwargs={'rate': 0.2},
-          input_shape=(3, 2, 3))
+    testing_utils.layer_test(
+        keras.layers.AlphaDropout,
+        kwargs={'rate': 0.2},
+        input_shape=(3, 2, 3))
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/keras/_impl/keras/layers/pooling_test.py b/tensorflow/python/keras/_impl/keras/layers/pooling_test.py
index ec0a5ae..70049f0 100644
--- a/tensorflow/python/keras/_impl/keras/layers/pooling_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/pooling_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.eager import context
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
@@ -25,81 +27,85 @@ from tensorflow.python.platform import test
 
 class GlobalPoolingTest(test.TestCase):
 
+  @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def test_globalpooling_1d(self):
-    with self.test_session(use_gpu=True):
-      testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D,
-                               input_shape=(3, 4, 5))
-      testing_utils.layer_test(
-          keras.layers.pooling.GlobalAveragePooling1D, input_shape=(3, 4, 5))
+    testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D,
+                             input_shape=(3, 4, 5))
+    testing_utils.layer_test(
+        keras.layers.pooling.GlobalAveragePooling1D, input_shape=(3, 4, 5))
 
+  @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def test_globalpooling_2d(self):
-    with self.test_session(use_gpu=True):
-      testing_utils.layer_test(
-          keras.layers.pooling.GlobalMaxPooling2D,
-          kwargs={'data_format': 'channels_first'},
-          input_shape=(3, 4, 5, 6))
-      testing_utils.layer_test(
-          keras.layers.pooling.GlobalMaxPooling2D,
-          kwargs={'data_format': 'channels_last'},
-          input_shape=(3, 5, 6, 4))
-      testing_utils.layer_test(
-          keras.layers.pooling.GlobalAveragePooling2D,
-          kwargs={'data_format': 'channels_first'},
-          input_shape=(3, 4, 5, 6))
-      testing_utils.layer_test(
-          keras.layers.pooling.GlobalAveragePooling2D,
-          kwargs={'data_format': 'channels_last'},
-          input_shape=(3, 5, 6, 4))
-
+    testing_utils.layer_test(
+        keras.layers.pooling.GlobalMaxPooling2D,
+        kwargs={'data_format': 'channels_first'},
+        input_shape=(3, 4, 5, 6))
+    testing_utils.layer_test(
+        keras.layers.pooling.GlobalMaxPooling2D,
+        kwargs={'data_format': 'channels_last'},
+        input_shape=(3, 5, 6, 4))
+    testing_utils.layer_test(
+        keras.layers.pooling.GlobalAveragePooling2D,
+        kwargs={'data_format': 'channels_first'},
+        input_shape=(3, 4, 5, 6))
+    testing_utils.layer_test(
+        keras.layers.pooling.GlobalAveragePooling2D,
+        kwargs={'data_format': 'channels_last'},
+        input_shape=(3, 5, 6, 4))
+
+  @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def test_globalpooling_3d(self):
-    with self.test_session(use_gpu=True):
-      testing_utils.layer_test(
-          keras.layers.pooling.GlobalMaxPooling3D,
-          kwargs={'data_format': 'channels_first'},
-          input_shape=(3, 4, 3, 4, 3))
-      testing_utils.layer_test(
-          keras.layers.pooling.GlobalMaxPooling3D,
-          kwargs={'data_format': 'channels_last'},
-          input_shape=(3, 4, 3, 4, 3))
-      testing_utils.layer_test(
-          keras.layers.pooling.GlobalAveragePooling3D,
-          kwargs={'data_format': 'channels_first'},
-          input_shape=(3, 4, 3, 4, 3))
-      testing_utils.layer_test(
-          keras.layers.pooling.GlobalAveragePooling3D,
-          kwargs={'data_format': 'channels_last'},
-          input_shape=(3, 4, 3, 4, 3))
+    testing_utils.layer_test(
+        keras.layers.pooling.GlobalMaxPooling3D,
+        kwargs={'data_format': 'channels_first'},
+        input_shape=(3, 4, 3, 4, 3))
+    testing_utils.layer_test(
+        keras.layers.pooling.GlobalMaxPooling3D,
+        kwargs={'data_format': 'channels_last'},
+        input_shape=(3, 4, 3, 4, 3))
+    testing_utils.layer_test(
+        keras.layers.pooling.GlobalAveragePooling3D,
+        kwargs={'data_format': 'channels_first'},
+        input_shape=(3, 4, 3, 4, 3))
+    testing_utils.layer_test(
+        keras.layers.pooling.GlobalAveragePooling3D,
+        kwargs={'data_format': 'channels_last'},
+        input_shape=(3, 4, 3, 4, 3))
 
 
 class Pooling2DTest(test.TestCase):
 
+  @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def test_maxpooling_2d(self):
     pool_size = (3, 3)
-    with self.test_session(use_gpu=True):
-      for strides in [(1, 1), (2, 2)]:
-        testing_utils.layer_test(
-            keras.layers.MaxPooling2D,
-            kwargs={
-                'strides': strides,
-                'padding': 'valid',
-                'pool_size': pool_size
-            },
-            input_shape=(3, 5, 6, 4))
-
-  def test_averagepooling_2d(self):
-    with self.test_session(use_gpu=True):
+    for strides in [(1, 1), (2, 2)]:
       testing_utils.layer_test(
-          keras.layers.AveragePooling2D,
-          kwargs={'strides': (2, 2),
-                  'padding': 'same',
-                  'pool_size': (2, 2)},
-          input_shape=(3, 5, 6, 4))
-      testing_utils.layer_test(
-          keras.layers.AveragePooling2D,
-          kwargs={'strides': (2, 2),
-                  'padding': 'valid',
-                  'pool_size': (3, 3)},
+          keras.layers.MaxPooling2D,
+          kwargs={
+              'strides': strides,
+              'padding': 'valid',
+              'pool_size': pool_size
+          },
          input_shape=(3, 5, 6, 4))
+
+  @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  def test_averagepooling_2d(self):
+    testing_utils.layer_test(
+        keras.layers.AveragePooling2D,
+        kwargs={'strides': (2, 2),
+                'padding': 'same',
+                'pool_size': (2, 2)},
+        input_shape=(3, 5, 6, 4))
+    testing_utils.layer_test(
+        keras.layers.AveragePooling2D,
+        kwargs={'strides': (2, 2),
+                'padding': 'valid',
+                'pool_size': (3, 3)},
+        input_shape=(3, 5, 6, 4))
+
+    # This part of the test can only run on GPU but doesn't appear
+    # to be properly assigned to a GPU when running in eager mode.
+    if not context.in_eager_mode():
      # Only runs on GPU with CUDA, channels_first is not supported on CPU.
      # TODO(b/62340061): Support channels_first on CPU.
      if test.is_gpu_available(cuda_only=True):
@@ -116,66 +122,66 @@ class Pooling2DTest(test.TestCase):
 
 class Pooling3DTest(test.TestCase):
 
+  @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def test_maxpooling_3d(self):
     pool_size = (3, 3, 3)
-    with self.test_session(use_gpu=True):
-      testing_utils.layer_test(
-          keras.layers.MaxPooling3D,
-          kwargs={'strides': 2,
-                  'padding': 'valid',
-                  'pool_size': pool_size},
-          input_shape=(3, 11, 12, 10, 4))
-      testing_utils.layer_test(
-          keras.layers.MaxPooling3D,
-          kwargs={
-              'strides': 3,
-              'padding': 'valid',
-              'data_format': 'channels_first',
-              'pool_size': pool_size
-          },
-          input_shape=(3, 4, 11, 12, 10))
-
+    testing_utils.layer_test(
+        keras.layers.MaxPooling3D,
+        kwargs={'strides': 2,
+                'padding': 'valid',
+                'pool_size': pool_size},
+        input_shape=(3, 11, 12, 10, 4))
+    testing_utils.layer_test(
+        keras.layers.MaxPooling3D,
+        kwargs={
+            'strides': 3,
+            'padding': 'valid',
+            'data_format': 'channels_first',
+            'pool_size': pool_size
+        },
+        input_shape=(3, 4, 11, 12, 10))
+
+  @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def test_averagepooling_3d(self):
     pool_size = (3, 3, 3)
-    with self.test_session(use_gpu=True):
-      testing_utils.layer_test(
-          keras.layers.AveragePooling3D,
-          kwargs={'strides': 2,
-                  'padding': 'valid',
-                  'pool_size': pool_size},
-          input_shape=(3, 11, 12, 10, 4))
-      testing_utils.layer_test(
-          keras.layers.AveragePooling3D,
-          kwargs={
-              'strides': 3,
-              'padding': 'valid',
-              'data_format': 'channels_first',
-              'pool_size': pool_size
-          },
-          input_shape=(3, 4, 11, 12, 10))
+    testing_utils.layer_test(
+        keras.layers.AveragePooling3D,
+        kwargs={'strides': 2,
+                'padding': 'valid',
+                'pool_size': pool_size},
+        input_shape=(3, 11, 12, 10, 4))
+    testing_utils.layer_test(
+        keras.layers.AveragePooling3D,
+        kwargs={
+            'strides': 3,
+            'padding': 'valid',
+            'data_format': 'channels_first',
+            'pool_size': pool_size
+        },
+        input_shape=(3, 4, 11, 12, 10))
 
 
 class Pooling1DTest(test.TestCase):
 
+  @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def test_maxpooling_1d(self):
-    with self.test_session(use_gpu=True):
-      for padding in ['valid', 'same']:
-        for stride in [1, 2]:
-          testing_utils.layer_test(
-              keras.layers.MaxPooling1D,
-              kwargs={'strides': stride,
-                      'padding': padding},
-              input_shape=(3, 5, 4))
+    for padding in ['valid', 'same']:
+      for stride in [1, 2]:
+        testing_utils.layer_test(
+            keras.layers.MaxPooling1D,
+            kwargs={'strides': stride,
+                    'padding': padding},
+            input_shape=(3, 5, 4))
 
+  @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def test_averagepooling_1d(self):
-    with self.test_session(use_gpu=True):
-      for padding in ['valid', 'same']:
-        for stride in [1, 2]:
-          testing_utils.layer_test(
-              keras.layers.AveragePooling1D,
-              kwargs={'strides': stride,
-                      'padding': padding},
-              input_shape=(3, 5, 4))
+    for padding in ['valid', 'same']:
+      for stride in [1, 2]:
+        testing_utils.layer_test(
+            keras.layers.AveragePooling1D,
+            kwargs={'strides': stride,
+                    'padding': padding},
+            input_shape=(3, 5, 4))
 
 
 if __name__ == '__main__':
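As in the averagepooling_2d test above, sections that only work in one execution mode are guarded with context.in_eager_mode() inside a test that is otherwise shared between modes. A sketch of the combined idiom (hypothetical test body, not part of the patch):

    from tensorflow.python.eager import context
    from tensorflow.python.framework import test_util as tf_test_util
    from tensorflow.python.platform import test

    class ModeAwareTest(test.TestCase):

      @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
      def test_with_graph_only_section(self):
        pass  # assertions shared by both modes go here
        if not context.in_eager_mode():
          # Graph-only checks, e.g. device placements that are not yet
          # handled correctly under eager execution.
          pass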
class Pooling1DTest(test.TestCase): + @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True) def test_maxpooling_1d(self): - with self.test_session(use_gpu=True): - for padding in ['valid', 'same']: - for stride in [1, 2]: - testing_utils.layer_test( - keras.layers.MaxPooling1D, - kwargs={'strides': stride, - 'padding': padding}, - input_shape=(3, 5, 4)) + for padding in ['valid', 'same']: + for stride in [1, 2]: + testing_utils.layer_test( + keras.layers.MaxPooling1D, + kwargs={'strides': stride, + 'padding': padding}, + input_shape=(3, 5, 4)) + @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True) def test_averagepooling_1d(self): - with self.test_session(use_gpu=True): - for padding in ['valid', 'same']: - for stride in [1, 2]: - testing_utils.layer_test( - keras.layers.AveragePooling1D, - kwargs={'strides': stride, - 'padding': padding}, - input_shape=(3, 5, 4)) + for padding in ['valid', 'same']: + for stride in [1, 2]: + testing_utils.layer_test( + keras.layers.AveragePooling1D, + kwargs={'strides': stride, + 'padding': padding}, + input_shape=(3, 5, 4)) if __name__ == '__main__': diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py index 2e9003f..a81971d 100644 --- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py +++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py @@ -22,6 +22,7 @@ from __future__ import print_function import numbers import numpy as np +from tensorflow.python.eager import context from tensorflow.python.framework import tensor_shape from tensorflow.python.keras._impl.keras import activations from tensorflow.python.keras._impl.keras import backend as K @@ -935,7 +936,9 @@ class SimpleRNNCell(Layer): # Properly set learning phase on output tensor. if 0 < self.dropout + self.recurrent_dropout: - if training is None: + if training is None and not context.in_eager_mode(): + # This would be harmless to set in eager mode, but eager tensors + # disallow setting arbitrary attributes. output._uses_learning_phase = True return output, [output] @@ -1299,23 +1302,6 @@ class GRUCell(Layer): constraint=self.bias_constraint) else: self.bias = None - - self.kernel_z = self.kernel[:, :self.units] - self.recurrent_kernel_z = self.recurrent_kernel[:, :self.units] - self.kernel_r = self.kernel[:, self.units:self.units * 2] - self.recurrent_kernel_r = self.recurrent_kernel[:, self.units: - self.units * 2] - self.kernel_h = self.kernel[:, self.units * 2:] - self.recurrent_kernel_h = self.recurrent_kernel[:, self.units * 2:] - - if self.use_bias: - self.bias_z = self.bias[:self.units] - self.bias_r = self.bias[self.units:self.units * 2] - self.bias_h = self.bias[self.units * 2:] - else: - self.bias_z = None - self.bias_r = None - self.bias_h = None self.built = True def call(self, inputs, states, training=None): @@ -1350,13 +1336,13 @@ class GRUCell(Layer): inputs_z = inputs inputs_r = inputs inputs_h = inputs - x_z = K.dot(inputs_z, self.kernel_z) - x_r = K.dot(inputs_r, self.kernel_r) - x_h = K.dot(inputs_h, self.kernel_h) + x_z = K.dot(inputs_z, self.kernel[:, :self.units]) + x_r = K.dot(inputs_r, self.kernel[:, self.units:self.units * 2]) + x_h = K.dot(inputs_h, self.kernel[:, self.units * 2:]) if self.use_bias: - x_z = K.bias_add(x_z, self.bias_z) - x_r = K.bias_add(x_r, self.bias_r) - x_h = K.bias_add(x_h, self.bias_h) + x_z = K.bias_add(x_z, self.bias[:self.units]) + x_r = K.bias_add(x_r, self.bias[self.units:self.units * 2]) + x_h = K.bias_add(x_h, self.bias[self.units * 2:]) if 0. 
@@ -1367,11 +1353,14 @@ class GRUCell(Layer):
         h_tm1_r = h_tm1
         h_tm1_h = h_tm1
       z = self.recurrent_activation(
-          x_z + K.dot(h_tm1_z, self.recurrent_kernel_z))
+          x_z + K.dot(h_tm1_z, self.recurrent_kernel[:, :self.units]))
       r = self.recurrent_activation(
-          x_r + K.dot(h_tm1_r, self.recurrent_kernel_r))
+          x_r + K.dot(h_tm1_r, self.recurrent_kernel[:, self.units:
+                                                     self.units * 2]))
-      hh = self.activation(x_h + K.dot(r * h_tm1_h, self.recurrent_kernel_h))
+      hh = self.activation(x_h + K.dot(r * h_tm1_h,
+                                       self.recurrent_kernel[:,
+                                                             self.units * 2:]))
     else:
       if 0. < self.dropout < 1.:
         inputs *= dp_mask[0]
@@ -1395,44 +1384,34 @@ class GRUCell(Layer):
       hh = self.activation(x_h + recurrent_h)
     h = z * h_tm1 + (1 - z) * hh
     if 0 < self.dropout + self.recurrent_dropout:
-      if training is None:
+      if training is None and not context.in_eager_mode():
+        # This would be harmless to set in eager mode, but eager tensors
+        # disallow setting arbitrary attributes.
         h._uses_learning_phase = True
     return h, [h]
 
   def get_config(self):
     config = {
-        'units':
-            self.units,
-        'activation':
-            activations.serialize(self.activation),
+        'units': self.units,
+        'activation': activations.serialize(self.activation),
         'recurrent_activation':
             activations.serialize(self.recurrent_activation),
-        'use_bias':
-            self.use_bias,
-        'kernel_initializer':
-            initializers.serialize(self.kernel_initializer),
+        'use_bias': self.use_bias,
+        'kernel_initializer': initializers.serialize(self.kernel_initializer),
         'recurrent_initializer':
             initializers.serialize(self.recurrent_initializer),
-        'bias_initializer':
-            initializers.serialize(self.bias_initializer),
-        'kernel_regularizer':
-            regularizers.serialize(self.kernel_regularizer),
+        'bias_initializer': initializers.serialize(self.bias_initializer),
+        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
         'recurrent_regularizer':
             regularizers.serialize(self.recurrent_regularizer),
-        'bias_regularizer':
-            regularizers.serialize(self.bias_regularizer),
-        'kernel_constraint':
-            constraints.serialize(self.kernel_constraint),
+        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+        'kernel_constraint': constraints.serialize(self.kernel_constraint),
         'recurrent_constraint':
             constraints.serialize(self.recurrent_constraint),
-        'bias_constraint':
-            constraints.serialize(self.bias_constraint),
-        'dropout':
-            self.dropout,
-        'recurrent_dropout':
-            self.recurrent_dropout,
-        'implementation':
-            self.implementation
+        'bias_constraint': constraints.serialize(self.bias_constraint),
+        'dropout': self.dropout,
+        'recurrent_dropout': self.recurrent_dropout,
+        'implementation': self.implementation
     }
     base_config = super(GRUCell, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
@@ -1809,29 +1788,6 @@ class LSTMCell(Layer):
           constraint=self.bias_constraint)
     else:
       self.bias = None
-
-    self.kernel_i = self.kernel[:, :self.units]
-    self.kernel_f = self.kernel[:, self.units:self.units * 2]
-    self.kernel_c = self.kernel[:, self.units * 2:self.units * 3]
-    self.kernel_o = self.kernel[:, self.units * 3:]
-
-    self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units]
-    self.recurrent_kernel_f = self.recurrent_kernel[:, self.units:
-                                                    self.units * 2]
-    self.recurrent_kernel_c = self.recurrent_kernel[:, self.units * 2:
-                                                    self.units * 3]
-    self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:]
-
-    if self.use_bias:
-      self.bias_i = self.bias[:self.units]
-      self.bias_f = self.bias[self.units:self.units * 2]
-      self.bias_c = self.bias[self.units * 2:self.units * 3]
-      self.bias_o = self.bias[self.units * 3:]
-    else:
-      self.bias_i = None
-      self.bias_f = None
-      self.bias_c = None
-      self.bias_o = None
     self.built = True
 
   def call(self, inputs, states, training=None):
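Removing the cached per-gate slices from build() matters under eager execution: there, an expression like self.kernel[:, :self.units] evaluates immediately to a concrete tensor holding the initial weights, so a slice cached at build time would never see subsequent optimizer updates. Slicing inside call() re-reads the variable on every step, in both modes. A small numpy analogy (copy() stands in for the value semantics of an eager slice; names are illustrative):

    import numpy as np

    kernel = np.arange(6.0).reshape(2, 3)  # stands in for a weight variable
    cached = kernel[:, :1].copy()          # like slicing once in build()
    kernel += 1.0                          # like an optimizer update
    fresh = kernel[:, :1]                  # like slicing again in call()
    assert not np.array_equal(cached, fresh)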
@@ -1869,15 +1825,15 @@ class LSTMCell(Layer):
         inputs_f = inputs
         inputs_c = inputs
         inputs_o = inputs
-      x_i = K.dot(inputs_i, self.kernel_i)
-      x_f = K.dot(inputs_f, self.kernel_f)
-      x_c = K.dot(inputs_c, self.kernel_c)
-      x_o = K.dot(inputs_o, self.kernel_o)
+      x_i = K.dot(inputs_i, self.kernel[:, :self.units])
+      x_f = K.dot(inputs_f, self.kernel[:, self.units:self.units * 2])
+      x_c = K.dot(inputs_c, self.kernel[:, self.units * 2:self.units * 3])
+      x_o = K.dot(inputs_o, self.kernel[:, self.units * 3:])
       if self.use_bias:
-        x_i = K.bias_add(x_i, self.bias_i)
-        x_f = K.bias_add(x_f, self.bias_f)
-        x_c = K.bias_add(x_c, self.bias_c)
-        x_o = K.bias_add(x_o, self.bias_o)
+        x_i = K.bias_add(x_i, self.bias[:self.units])
+        x_f = K.bias_add(x_f, self.bias[self.units:self.units * 2])
+        x_c = K.bias_add(x_c, self.bias[self.units * 2:self.units * 3])
+        x_o = K.bias_add(x_o, self.bias[self.units * 3:])
 
       if 0 < self.recurrent_dropout < 1.:
         h_tm1_i = h_tm1 * rec_dp_mask[0]
@@ -1890,13 +1846,15 @@ class LSTMCell(Layer):
         h_tm1_c = h_tm1
         h_tm1_o = h_tm1
       i = self.recurrent_activation(
-          x_i + K.dot(h_tm1_i, self.recurrent_kernel_i))
+          x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]))
       f = self.recurrent_activation(
-          x_f + K.dot(h_tm1_f, self.recurrent_kernel_f))
+          x_f + K.dot(h_tm1_f,
+                      self.recurrent_kernel[:, self.units: self.units * 2]))
       c = f * c_tm1 + i * self.activation(
-          x_c + K.dot(h_tm1_c, self.recurrent_kernel_c))
+          x_c + K.dot(h_tm1_c,
+                      self.recurrent_kernel[:, self.units * 2: self.units * 3]))
       o = self.recurrent_activation(
-          x_o + K.dot(h_tm1_o, self.recurrent_kernel_o))
+          x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]))
     else:
       if 0. < self.dropout < 1.:
         inputs *= dp_mask[0]
@@ -1919,7 +1877,9 @@
 
     h = o * self.activation(c)
     if 0 < self.dropout + self.recurrent_dropout:
-      if training is None:
+      if training is None and not context.in_eager_mode():
+        # This would be harmless to set in eager mode, but eager tensors
+        # disallow setting arbitrary attributes.
         h._uses_learning_phase = True
     return h, [h, c]
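Both cells keep their per-gate weights concatenated along the last axis, which is what the slice arithmetic above encodes: GRU packs [z | r | h] into a kernel of width 3 * units, and LSTM packs [i | f | c | o] into one of width 4 * units. A runnable check of that layout (dimensions chosen arbitrarily):

    import numpy as np

    input_dim, units = 5, 3
    gru_kernel = np.zeros((input_dim, units * 3))   # columns: [z | r | h]
    lstm_kernel = np.zeros((input_dim, units * 4))  # columns: [i | f | c | o]
    z = gru_kernel[:, :units]
    r = gru_kernel[:, units:units * 2]
    h = gru_kernel[:, units * 2:]
    o = lstm_kernel[:, units * 3:]
    assert z.shape == r.shape == h.shape == o.shape == (input_dim, units)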
diff --git a/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py b/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py
index 7edebda..8c7189c 100644
--- a/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py
@@ -20,64 +20,66 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
+from tensorflow.python.training.rmsprop import RMSPropOptimizer
 
 
 class SimpleRNNLayerTest(test.TestCase):
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_return_sequences_SimpleRNN(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.SimpleRNN,
-          kwargs={'units': units,
-                  'return_sequences': True},
-          input_shape=(num_samples, timesteps, embedding_dim))
+    testing_utils.layer_test(
+        keras.layers.SimpleRNN,
+        kwargs={'units': units,
+                'return_sequences': True},
+        input_shape=(num_samples, timesteps, embedding_dim))
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_dynamic_behavior_SimpleRNN(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
-      layer = keras.layers.SimpleRNN(units, input_shape=(None, embedding_dim))
-      model = keras.models.Sequential()
-      model.add(layer)
-      model.compile('sgd', 'mse')
-      x = np.random.random((num_samples, timesteps, embedding_dim))
-      y = np.random.random((num_samples, units))
-      model.train_on_batch(x, y)
-
+    layer = keras.layers.SimpleRNN(units, input_shape=(None, embedding_dim))
+    model = keras.models.Sequential()
+    model.add(layer)
+    model.compile(RMSPropOptimizer(0.01), 'mse')
+    x = np.random.random((num_samples, timesteps, embedding_dim))
+    y = np.random.random((num_samples, units))
+    model.train_on_batch(x, y)
+
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_dropout_SimpleRNN(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
-      testing_utils.layer_test(
-          keras.layers.SimpleRNN,
-          kwargs={'units': units,
-                  'dropout': 0.1,
-                  'recurrent_dropout': 0.1},
-          input_shape=(num_samples, timesteps, embedding_dim))
-
+    testing_utils.layer_test(
+        keras.layers.SimpleRNN,
+        kwargs={'units': units,
+                'dropout': 0.1,
+                'recurrent_dropout': 0.1},
+        input_shape=(num_samples, timesteps, embedding_dim))
+
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_implementation_mode_SimpleRNN(self):
     num_samples = 2
     timesteps = 3
     embedding_dim = 4
     units = 2
-    with self.test_session():
-      for mode in [0, 1, 2]:
-        testing_utils.layer_test(
-            keras.layers.SimpleRNN,
-            kwargs={'units': units,
-                    'implementation': mode},
-            input_shape=(num_samples, timesteps, embedding_dim))
+    for mode in [0, 1, 2]:
+      testing_utils.layer_test(
+          keras.layers.SimpleRNN,
+          kwargs={'units': units,
+                  'implementation': mode},
+          input_shape=(num_samples, timesteps, embedding_dim))
 
   def test_statefulness_SimpleRNN(self):
     num_samples = 2
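The switch from string identifiers such as 'sgd' to RMSPropOptimizer(0.01) is what makes these tests eager-compatible: a string resolves to a Keras-native optimizer, which at this revision does not yet support eager execution, whereas an optimizer from tensorflow.python.training works in both modes. A minimal sketch of the compile-and-train pattern (model shape arbitrary):

    import numpy as np

    from tensorflow.python.keras._impl import keras
    from tensorflow.python.training.rmsprop import RMSPropOptimizer

    model = keras.models.Sequential()
    model.add(keras.layers.Dense(2, input_shape=(4,)))
    # A TF-native optimizer trains the model with or without eager execution.
    model.compile(RMSPropOptimizer(0.01), 'mse')
    model.train_on_batch(np.random.random((8, 4)), np.random.random((8, 2)))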
diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py b/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py
index c81d6b8..8fcf66e 100644
--- a/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py
@@ -20,44 +20,43 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
+from tensorflow.python.training.rmsprop import RMSPropOptimizer
 
 
 class TimeDistributedTest(test.TestCase):
 
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_timedistributed_dense(self):
-    # first, test with Dense layer
-    with self.test_session():
-      model = keras.models.Sequential()
-      model.add(
-          keras.layers.TimeDistributed(
-              keras.layers.Dense(2), input_shape=(3, 4)))
-      model.compile(optimizer='rmsprop', loss='mse')
-      model.fit(
-          np.random.random((10, 3, 4)),
-          np.random.random((10, 3, 2)),
-          epochs=1,
-          batch_size=10)
-
-      # test config
-      model.get_config()
+    model = keras.models.Sequential()
+    model.add(
+        keras.layers.TimeDistributed(
+            keras.layers.Dense(2), input_shape=(3, 4)))
+    model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
+    model.fit(
+        np.random.random((10, 3, 4)),
+        np.random.random((10, 3, 2)),
+        epochs=1,
+        batch_size=10)
+
+    # test config
+    model.get_config()
 
   def test_timedistributed_static_batch_size(self):
-    with self.test_session():
-      model = keras.models.Sequential()
-      model.add(
-          keras.layers.TimeDistributed(
-              keras.layers.Dense(2), input_shape=(3, 4), batch_size=10))
-      model.compile(optimizer='rmsprop', loss='mse')
-      model.fit(
-          np.random.random((10, 3, 4)),
-          np.random.random((10, 3, 2)),
-          epochs=1,
-          batch_size=10)
+    model = keras.models.Sequential()
+    model.add(
+        keras.layers.TimeDistributed(
+            keras.layers.Dense(2), input_shape=(3, 4), batch_size=10))
+    model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
+    model.fit(
+        np.random.random((10, 3, 4)),
+        np.random.random((10, 3, 2)),
+        epochs=1,
+        batch_size=10)
 
   def test_timedistributed_conv2d(self):
-    # test with Conv2D
     with self.test_session():
       model = keras.models.Sequential()
       model.add(
@@ -73,7 +72,6 @@ class TimeDistributedTest(test.TestCase):
       model.summary()
 
   def test_timedistributed_stacked(self):
-    # test stacked layers
     with self.test_session():
       model = keras.models.Sequential()
       model.add(
@@ -167,7 +165,7 @@ class BidirectionalTest(test.TestCase):
       model.add(
           keras.layers.Bidirectional(
              rnn(output_dim), merge_mode=mode, input_shape=(timesteps, dim)))
-      model.compile(loss='mse', optimizer='sgd')
+      model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
       model.fit(x, y, epochs=1, batch_size=1)
 
       # test compute output shape
diff --git a/tensorflow/python/keras/_impl/keras/optimizers.py b/tensorflow/python/keras/_impl/keras/optimizers.py
index 76a9715..6520128 100644
--- a/tensorflow/python/keras/_impl/keras/optimizers.py
+++ b/tensorflow/python/keras/_impl/keras/optimizers.py
@@ -704,8 +704,10 @@ class TFOptimizer(Optimizer):
     return self.optimizer.compute_gradients(loss, params)
 
   def get_updates(self, loss, params):
-    grads = self.optimizer.compute_gradients(loss, params)
     self.updates = [K.update_add(self.iterations, 1)]
+    if not params:
+      return self.updates
+    grads = self.optimizer.compute_gradients(loss, params)
     opt_update = self.optimizer.apply_gradients(
         grads, global_step=self.iterations)
     self.updates.append(opt_update)
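The early return above keeps get_updates() from handing an empty variable list to compute_gradients() when a model has nothing trainable, while still recording the iteration count. A standalone rendering of that control flow (every callable here is a stand-in passed as an argument, not the real optimizer API):

    def get_updates_sketch(increment_step, compute_gradients, apply_gradients,
                           loss, params):
      updates = [increment_step()]  # always count the step
      if not params:
        # Nothing trainable (e.g. model.trainable = False): skip gradients.
        return updates
      grads = compute_gradients(loss, params)
      updates.append(apply_gradients(grads))
      return updates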
diff --git a/tensorflow/python/keras/_impl/keras/testing_utils.py b/tensorflow/python/keras/_impl/keras/testing_utils.py
index fa1ee2f..60799ee 100644
--- a/tensorflow/python/keras/_impl/keras/testing_utils.py
+++ b/tensorflow/python/keras/_impl/keras/testing_utils.py
@@ -22,6 +22,7 @@ import numpy as np
 
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.keras._impl import keras
+from tensorflow.python.training.rmsprop import RMSPropOptimizer
 from tensorflow.python.util import tf_inspect
 
 
@@ -145,7 +146,7 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
   np.testing.assert_allclose(output, actual_output, rtol=1e-3)
 
   # test training mode (e.g. useful for dropout tests)
-  model.compile('rmsprop', 'mse')
+  model.compile(RMSPropOptimizer(0.01), 'mse')
   model.train_on_batch(input_data, actual_output)
 
   # test as first layer in Sequential API
@@ -181,9 +182,5 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
   output = recovered_model.predict(input_data)
   np.testing.assert_allclose(output, actual_output, rtol=1e-3)
 
-  # test training mode (e.g. useful for dropout tests)
-  model.compile('rmsprop', 'mse')
-  model.train_on_batch(input_data, actual_output)
-
   # for further checks in the caller function
   return actual_output
-- 
2.7.4