From ccefd0a1307ac5dd39d0a254c49ce71f8c2b93e2 Mon Sep 17 00:00:00 2001
From: Francois Chollet
Date: Mon, 26 Feb 2018 19:57:42 -0800
Subject: [PATCH] Fixes and simplification in the Keras training engine.

- Explicitly disallow sample/class weighting in eager (it was never supported)
- Remove tests for it (which were actually ignoring sample/class weights)
- Make sample weight placeholders placeholder_with_default, and do not create
  all-ones numpy arrays to feed them when no sample weights are provided
  (this might lead to better performance)

PiperOrigin-RevId: 187121215
---
 tensorflow/python/keras/_impl/keras/backend.py     |  11 +-
 tensorflow/python/keras/_impl/keras/callbacks.py   |  20 +-
 .../python/keras/_impl/keras/engine/training.py    | 151 +++----
 .../keras/_impl/keras/engine/training_eager.py     |  17 +-
 .../_impl/keras/engine/training_eager_test.py      | 436 ---------------------
 .../keras/_impl/keras/engine/training_test.py      |   8 -
 6 files changed, 110 insertions(+), 533 deletions(-)

diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py
index a2db05f..2b75666 100644
--- a/tensorflow/python/keras/_impl/keras/backend.py
+++ b/tensorflow/python/keras/_impl/keras/backend.py
@@ -2749,7 +2749,7 @@ class Function(object):
     self.updates_op = control_flow_ops.group(*updates_ops)
     self.name = name
     # additional tensor substitutions
-    self.feed_dict = session_kwargs.pop('feed_dict', {})
+    self.feed_dict = session_kwargs.pop('feed_dict', None)
     # additional operations
     self.fetches = session_kwargs.pop('fetches', [])
     if not isinstance(self.fetches, list):
@@ -2759,8 +2759,15 @@ class Function(object):
   def __call__(self, inputs):
     if not isinstance(inputs, (list, tuple)):
       raise TypeError('`inputs` should be a list or tuple.')
-    feed_dict = self.feed_dict.copy()
+
+    if self.feed_dict:
+      feed_dict = self.feed_dict.copy()
+    else:
+      feed_dict = {}
+
     for tensor, value in zip(self.inputs, inputs):
+      if value is None:
+        continue
       if is_sparse(tensor):
         sparse_coo = value.tocoo()
         indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
diff --git a/tensorflow/python/keras/_impl/keras/callbacks.py b/tensorflow/python/keras/_impl/keras/callbacks.py
index f6c4661..deb1e88 100644
--- a/tensorflow/python/keras/_impl/keras/callbacks.py
+++ b/tensorflow/python/keras/_impl/keras/callbacks.py
@@ -778,16 +778,24 @@ class TensorBoard(Callback):
         while i < val_size:
           step = min(self.batch_size, val_size - i)
           batch_val = []
-          batch_val.append(val_data[0][i:i + step])
-          batch_val.append(val_data[1][i:i + step])
-          batch_val.append(val_data[2][i:i + step])
+          batch_val.append(val_data[0][i:i + step]
+                           if val_data[0] is not None else None)
+          batch_val.append(val_data[1][i:i + step]
+                           if val_data[1] is not None else None)
+          batch_val.append(val_data[2][i:i + step]
+                           if val_data[2] is not None else None)
           if self.model.uses_learning_phase:
             # do not slice the learning phase
-            batch_val = [x[i:i + step] for x in val_data[:-1]]
+            batch_val = [x[i:i + step] if x is not None else None
+                         for x in val_data[:-1]]
             batch_val.append(val_data[-1])
           else:
-            batch_val = [x[i:i + step] for x in val_data]
-          feed_dict = dict(zip(tensors, batch_val))
+            batch_val = [x[i:i + step] if x is not None else None
+                         for x in val_data]
+          feed_dict = {}
+          for key, val in zip(tensors, batch_val):
+            if val is not None:
+              feed_dict[key] = val
           result = self.sess.run([self.merged], feed_dict=feed_dict)
           summary_str = result[0]
           self.writer.add_summary(summary_str, epoch)
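The two hunks above share one convention: a feed value of None means "do not
feed this tensor", so a placeholder_with_default in the graph silently falls
back to its default instead of receiving an explicit all-ones array. A minimal
standalone sketch of that convention (the helper name build_feed_dict and the
string stand-ins for tensors are illustrative, not from this patch):

import numpy as np

def build_feed_dict(tensors, values, base_feed_dict=None):
  """Pairs tensors with values, skipping entries whose value is None."""
  feed_dict = dict(base_feed_dict) if base_feed_dict else {}
  for tensor, value in zip(tensors, values):
    if value is None:
      continue  # leave the tensor unfed; its graph-level default applies
    feed_dict[tensor] = value
  return feed_dict

# The second input has no value, so only the first tensor gets fed.
print(build_feed_dict(['x:0', 'sample_weights:0'], [np.ones((2, 3)), None]))

Since the all-ones weight arrays are no longer materialized and shipped through
feed_dict on every batch, the session only transfers weights that the user
actually supplied.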
diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py
index 57451ad..63bea08 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training.py
@@ -40,6 +40,7 @@ from tensorflow.python.keras._impl.keras.utils.generic_utils import make_batches
 from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar
 from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays
 from tensorflow.python.layers.base import _DeferredTensor
+from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import optimizer as tf_optimizer_module
 from tensorflow.python.util.tf_export import tf_export
@@ -225,9 +226,9 @@ def _check_array_lengths(inputs, targets, weights=None):
     # return a set with the variation between
     # different shapes, with None => 0
     if x is None:
-      return {0}
+      return {}
     else:
-      return set([0 if y is None else y.shape[0] for y in x])
+      return set([y.shape[0] for y in x if y is not None])
 
   set_x = set_of_lengths(inputs)
   set_y = set_of_lengths(targets)
@@ -259,7 +260,8 @@ def _check_array_lengths(inputs, targets, weights=None):
 def _check_loss_and_target_compatibility(targets, loss_fns, output_shapes):
   """Does validation on the compatibility of targets and loss functions.
 
-  This helps prevent users from using loss functions incorrectly.
+  This helps prevent users from using loss functions incorrectly. This check
+  is purely for UX purposes.
 
   Arguments:
     targets: list of Numpy arrays of targets.
@@ -275,7 +277,7 @@ def _check_loss_and_target_compatibility(targets, loss_fns, output_shapes):
       losses.categorical_crossentropy
   }
   for y, loss, shape in zip(targets, loss_fns, output_shapes):
-    if y is None or loss is None:
+    if y is None or loss is None or tensor_util.is_tensor(y):
       continue
     if loss is losses.categorical_crossentropy:
       if y.shape[-1] == 1:
@@ -507,10 +509,7 @@ def _standardize_weights(y,
                          (existing_classes - existing_class_weight))
     return weights
   else:
-    if sample_weight_mode is None:
-      return np.ones((y.shape[0],), dtype=K.floatx())
-    else:
-      return np.ones((y.shape[0], y.shape[1]), dtype=K.floatx())
+    return None
 
 
 @tf_export('keras.models.Model', 'keras.Model')
@@ -862,12 +861,12 @@ class Model(Network):
           sample_weights.append(None)
         else:
           if sample_weight_mode == 'temporal':
-            sample_weights.append(
-                K.placeholder(ndim=2, name=name + '_sample_weights'))
+            sample_weights.append(array_ops.placeholder_with_default(
+                [[1.]], shape=[None, None], name=name + '_sample_weights'))
             sample_weight_modes.append('temporal')
           else:
-            sample_weights.append(
-                K.placeholder(ndim=1, name=name + '_sample_weights'))
+            sample_weights.append(array_ops.placeholder_with_default(
+                [1.], shape=[None], name=name + '_sample_weights'))
             sample_weight_modes.append(None)
     self.sample_weight_modes = sample_weight_modes
     self._feed_sample_weight_modes = []
@@ -1314,7 +1313,7 @@ class Model(Network):
     for batch_index, (batch_start, batch_end) in enumerate(batches):
       batch_ids = index_array[batch_start:batch_end]
       try:
-        if isinstance(ins[-1], float):
+        if isinstance(ins[-1], int):
           # Do not slice the training phase flag.
           ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
         else:
@@ -1424,7 +1423,7 @@ class Model(Network):
     index_array = np.arange(num_samples)
     for batch_index, (batch_start, batch_end) in enumerate(batches):
       batch_ids = index_array[batch_start:batch_end]
-      if ins and isinstance(ins[-1], float):
+      if ins and isinstance(ins[-1], int):
         # Do not slice the training phase flag.
         ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
       else:
@@ -1518,7 +1517,7 @@ class Model(Network):
     index_array = np.arange(num_samples)
     for batch_index, (batch_start, batch_end) in enumerate(batches):
       batch_ids = index_array[batch_start:batch_end]
-      if isinstance(ins[-1], float):
+      if isinstance(ins[-1], int):
         # Do not slice the training phase flag.
         ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
       else:
@@ -2070,10 +2069,6 @@ class Model(Network):
           val_y,
           sample_weight=val_sample_weight,
           batch_size=batch_size)
-      if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
-        val_ins = val_x + val_y + val_sample_weights + [0.]
-      else:
-        val_ins = val_x + val_y + val_sample_weights
 
     elif validation_split and 0. < validation_split < 1.:
       do_validation = True
@@ -2085,36 +2080,34 @@ class Model(Network):
       y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at))
       sample_weights, val_sample_weights = (slice_arrays(
           sample_weights, 0, split_at), slice_arrays(sample_weights, split_at))
-      if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
-        val_ins = val_x + val_y + val_sample_weights + [0.]
-      else:
-        val_ins = val_x + val_y + val_sample_weights
-
     elif validation_steps:
+      val_x = []
+      val_y = []
+      val_sample_weights = []
       do_validation = True
-      if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
-        val_ins = [0.]
-
-    # Prepare input arrays and training function.
-    if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
-      ins = x + y + sample_weights + [1.]
-    else:
-      ins = x + y + sample_weights
 
     # Prepare display labels.
     out_labels = self.metrics_names
 
     if context.in_eager_mode():
+      if any([w is not None for w in sample_weights]):
+        raise ValueError('`sample_weight` and `class_weight` are not supported '
+                         'when eager execution is enabled, for now.')
       if do_validation:
+        if any([w is not None for w in val_sample_weights]):
+          raise ValueError('`sample_weight` and `class_weight` are not '
+                           'supported when eager execution is enabled, '
+                           'for now.')
         callback_metrics = copy.copy(out_labels) + [
            'val_' + n for n in out_labels
         ]
+        val_ins = val_x + val_y
       else:
         callback_metrics = copy.copy(out_labels)
 
      return training_eager.fit_loop(
          self,
-          ins,
+          x + y,
          out_labels=out_labels,
          batch_size=batch_size,
          epochs=epochs,
@@ -2127,18 +2120,25 @@ class Model(Network):
          steps_per_epoch=steps_per_epoch,
          validation_steps=validation_steps)
     else:
+      # Prepare input arrays and training function.
+      if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
+        ins = x + y + sample_weights + [1]
+      else:
+        ins = x + y + sample_weights
+
       self._make_train_function()
       f = self.train_function
 
       if do_validation:
-        if context.in_graph_mode():
-          self._make_test_function()
-          val_f = self.test_function
-        else:
-          val_f = None
+        self._make_test_function()
+        val_f = self.test_function
         callback_metrics = copy.copy(out_labels) + [
             'val_' + n for n in out_labels
         ]
+        if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
+          val_ins = val_x + val_y + val_sample_weights + [0]
+        else:
+          val_ins = val_x + val_y + val_sample_weights
       else:
         val_f = None
         callback_metrics = copy.copy(out_labels)
@@ -2229,16 +2229,20 @@ class Model(Network):
         y,
         sample_weight=sample_weight,
         batch_size=batch_size)
-    # Prepare inputs, delegate logic to `_test_loop`.
-    if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
-      ins = x + y + sample_weights + [0.]
-    else:
-      ins = x + y + sample_weights
 
     if context.in_eager_mode():
+      if any([w is not None for w in sample_weights]):
+        raise ValueError('`sample_weight` and `class_weight` are not supported '
+                         'when eager execution is enabled, for now.')
       return training_eager.test_loop(
-          self, ins, batch_size=batch_size, verbose=verbose, steps=steps)
+          self, x + y, batch_size=batch_size, verbose=verbose, steps=steps)
     else:
+      # Prepare inputs, delegate logic to `_test_loop`.
+      if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
+        ins = x + y + sample_weights + [0]
+      else:
+        ins = x + y + sample_weights
+
       self._make_test_function()
       f = self.test_function
       return self._test_loop(
@@ -2276,16 +2280,16 @@ class Model(Network):
           'argument.')
     x, _, _ = self._standardize_user_data(x)
 
-    # Prepare inputs, delegate logic to `_predict_loop`.
-    if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
-      ins = x + [0.]
-    else:
-      ins = x
-
     if context.in_eager_mode():
       return training_eager.predict_loop(
-          self, ins, batch_size=batch_size, verbose=verbose, steps=steps)
+          self, x, batch_size=batch_size, verbose=verbose, steps=steps)
     else:
+      # Prepare inputs, delegate logic to `_predict_loop`.
+      if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
+        ins = x + [0]
+      else:
+        ins = x
+
       self._make_predict_function()
       f = self.predict_function
@@ -2327,20 +2331,26 @@ class Model(Network):
         and/or metrics). The attribute `model.metrics_names` will give you
         the display labels for the scalar outputs.
 
+    Raises:
+      ValueError: In case of invalid user-provided arguments.
     """
     x, y, sample_weights = self._standardize_user_data(
         x, y, sample_weight=sample_weight, class_weight=class_weight)
-    if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
-      ins = x + y + sample_weights + [1.]
-    else:
-      ins = x + y + sample_weights
 
     if context.in_eager_mode():
-      outputs = training_eager.train_on_batch(self, ins)
+      if any([w is not None for w in sample_weights]):
+        raise ValueError('`sample_weight` and `class_weight` are not supported '
+                         'when eager execution is enabled, for now.')
+      outputs = training_eager.train_on_batch(self, x + y)
     else:
+      if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
+        ins = x + y + sample_weights + [1]
+      else:
+        ins = x + y + sample_weights
+
       self._make_train_function()
       outputs = self.train_function(ins)
@@ -2377,18 +2387,21 @@ class Model(Network):
         the display labels for the scalar outputs.
 
     Raises:
-      ValueError: in case of invalid arguments.
+      ValueError: In case of invalid user-provided arguments.
     """
     x, y, sample_weights = self._standardize_user_data(
         x, y, sample_weight=sample_weight)
-    if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
-      ins = x + y + sample_weights + [0.]
-    else:
-      ins = x + y + sample_weights
 
     if context.in_eager_mode():
-      outputs = training_eager.test_on_batch(self, ins)
+      if any([w is not None for w in sample_weights]):
+        raise ValueError('`sample_weight` and `class_weight` are not supported '
+                         'when eager execution is enabled, for now.')
+      outputs = training_eager.test_on_batch(self, x + y)
     else:
+      if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
+        ins = x + y + sample_weights + [0]
+      else:
+        ins = x + y + sample_weights
       self._make_test_function()
       outputs = self.test_function(ins)
@@ -2408,14 +2421,9 @@ class Model(Network):
     """
     x, _, _ = self._standardize_user_data(x)
 
-    if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
-      ins = x + [0.]
-    else:
-      ins = x
-
     if context.in_eager_mode():
       ins_batch_converted = []
-      for ib in ins:
+      for ib in x:
         ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx()))
 
       eager_model_inputs = []
@@ -2426,6 +2434,11 @@ class Model(Network):
       return outs
 
     if context.in_graph_mode():
+      if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
+        ins = x + [0]
+      else:
+        ins = x
+
       self._make_predict_function()
       outputs = self.predict_function(ins)
       if len(outputs) == 1:
@@ -2643,7 +2656,7 @@ class Model(Network):
             val_data = val_x + val_y + val_sample_weights
             if self.uses_learning_phase and not isinstance(
                 K.learning_phase(), int):
-              val_data += [0.]
+              val_data += [0]
             for cbk in callbacks:
               cbk.validation_data = val_data
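The hunks above replace the plain K.placeholder sample-weight inputs with
array_ops.placeholder_with_default, which is why fit/evaluate/train_on_batch no
longer need to build and feed all-ones arrays. A minimal TF 1.x graph-mode
sketch of the mechanism, using the equivalent public tf.placeholder_with_default
(a standalone illustration under those assumptions, not code from the patch):

import numpy as np
import tensorflow as tf

sw = tf.placeholder_with_default([1.], shape=[None], name='sample_weights')
per_sample_loss = tf.constant([0.5, 1.0, 2.0])
weighted = per_sample_loss * sw  # the length-1 default broadcasts over the batch

with tf.Session() as sess:
  # Nothing fed: the [1.] default kicks in, i.e. uniform weighting.
  print(sess.run(weighted))  # [0.5 1.  2. ]
  # Explicit weights fed: behaves like the old K.placeholder path.
  print(sess.run(weighted, {sw: np.array([1., 0., 1.], dtype=np.float32)}))

Because the default broadcasts across the batch dimension, leaving the
placeholder unfed is arithmetically identical to feeding np.ones((batch_size,)),
minus the cost of creating and transferring that array on every batch.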
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager.py b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
index 282dd0d..cdf189a 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
@@ -139,6 +139,8 @@ def _model_loss(model, inputs, targets, training=False):
                                        model.output_names[i])
       loss_metrics.append(K.mean(output_loss))
 
+      # TODO(fchollet): support masking; in practice `_keras_mask` is never
+      # set in this context currently.
       mask = outs[i]._keras_mask
       # adapted from weighted_loss_fn
       if mask is not None:
@@ -148,17 +150,7 @@ def _model_loss(model, inputs, targets, training=False):
         # to the number of unmasked samples.
         output_loss /= K.mean(mask)
 
-      # adapted from weighted_loss_fn
-      # apply sample weighting
-      if model.sample_weights:
-        # reduce score_array to same ndim as weight array
-        ndim = K.ndim(output_loss)
-        weight_ndim = K.ndim(model.sample_weights)
-        output_loss = K.mean(output_loss, axis=list(range(weight_ndim, ndim)))
-        output_loss *= model.sample_weights
-        output_loss /= K.mean(K.cast(K.not_equal(model.sample_weights, 0),
-                                     K.floatx()))
-        output_loss = K.mean(output_loss)
+      # TODO(fchollet): support sample weighting
 
       loss_weight = model.loss_weights_list[i]
       if total_loss is None:
@@ -231,7 +223,8 @@ def train_on_batch(model, ins):
   """
   ins_batch_converted = []
   for ib in ins:
-    ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx()))
+    if ib is not None:
+      ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx()))
   eager_model_inputs = []
   eager_model_outputs = []
   for i in range(len(model.inputs)):
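For reference, the block removed above implemented per-sample weighting with
Keras backend ops. The same arithmetic in plain NumPy (an illustrative sketch;
the function and variable names are not from the codebase):

import numpy as np

def weighted_mean_loss(score_array, weights):
  # Reduce score_array to the same ndim as the weight array.
  extra_axes = tuple(range(weights.ndim, score_array.ndim))
  if extra_axes:
    score_array = score_array.mean(axis=extra_axes)
  # Scale by the weights, then renormalize by the fraction of nonzero
  # weights so that zero-weighted samples do not dilute the mean.
  score_array = score_array * weights
  score_array /= (weights != 0).mean()
  return score_array.mean()

loss = np.array([[0.2, 0.4], [0.6, 0.8]])  # shape (batch, timesteps)
w = np.array([1., 0.])                     # second sample weighted out
print(weighted_mean_loss(loss, w))         # 0.3, the mean of sample 0 only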
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
index 3d94b75..550b86a 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
@@ -24,9 +24,7 @@
 import numpy as np
 
 from tensorflow.python.framework import ops
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.keras._impl.keras import testing_utils
-from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays
 from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training.rmsprop import RMSPropOptimizer
@@ -311,440 +309,6 @@ class TrainingTest(test.TestCase):
           optimizer='rms')
 
 
-class LossWeightingTest(test.TestCase):
-
-  def test_class_weights(self):
-    num_classes = 5
-    batch_size = 5
-    epochs = 5
-    weighted_class = 3
-    train_samples = 3000
-    test_samples = 3000
-    input_dim = 5
-
-    model = keras.models.Sequential()
-    model.add(keras.layers.Dense(10, input_shape=(input_dim,)))
-    model.add(keras.layers.Activation('relu'))
-    model.add(keras.layers.Dense(num_classes))
-    model.add(keras.layers.Activation('softmax'))
-    model.compile(loss='categorical_crossentropy',
-                  optimizer=RMSPropOptimizer(learning_rate=0.001))
-
-    np.random.seed(1337)
-    (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
-        train_samples=train_samples,
-        test_samples=test_samples,
-        input_shape=(input_dim,),
-        num_classes=num_classes)
-    int_y_test = y_test.copy()
-    int_y_train = y_train.copy()
-    # convert class vectors to binary class matrices
-    y_train = keras.utils.to_categorical(y_train, num_classes)
-    y_test = keras.utils.to_categorical(y_test, num_classes)
-    test_ids = np.where(int_y_test == np.array(weighted_class))[0]
-
-    class_weight = dict([(i, 1.) for i in range(num_classes)])
-    class_weight[weighted_class] = 2.
-
-    sample_weight = np.ones((y_train.shape[0]))
-    sample_weight[int_y_train == weighted_class] = 2.
-
-    model.fit(
-        x_train,
-        y_train,
-        batch_size=batch_size,
-        epochs=epochs // 3,
-        verbose=0,
-        class_weight=class_weight,
-        validation_data=(x_train, y_train, sample_weight))
-    model.fit(
-        x_train,
-        y_train,
-        batch_size=batch_size,
-        epochs=epochs // 2,
-        verbose=0,
-        class_weight=class_weight)
-    model.fit(
-        x_train,
-        y_train,
-        batch_size=batch_size,
-        epochs=epochs // 2,
-        verbose=0,
-        class_weight=class_weight,
-        validation_split=0.1)
-
-    model.train_on_batch(
-        x_train[:batch_size], y_train[:batch_size], class_weight=class_weight)
-    ref_score = model.evaluate(x_test, y_test, verbose=0)
-    score = model.evaluate(
-        x_test[test_ids, :], y_test[test_ids, :], verbose=0)
-    self.assertLess(score, ref_score)
-
-  def test_sample_weights(self):
-    num_classes = 5
-    batch_size = 5
-    epochs = 5
-    weighted_class = 3
-    train_samples = 3000
-    test_samples = 3000
-    input_dim = 5
-
-    model = keras.models.Sequential()
-    model.add(keras.layers.Dense(10, input_shape=(input_dim,)))
-    model.add(keras.layers.Activation('relu'))
-    model.add(keras.layers.Dense(num_classes))
-    model.add(keras.layers.Activation('softmax'))
-    model.compile(loss='categorical_crossentropy',
-                  optimizer=RMSPropOptimizer(learning_rate=0.001))
-
-    np.random.seed(43)
-    (x_train, y_train), _ = testing_utils.get_test_data(
-        train_samples=train_samples,
-        test_samples=test_samples,
-        input_shape=(input_dim,),
-        num_classes=num_classes)
-    int_y_train = y_train.copy()
-    y_train = keras.utils.to_categorical(y_train, num_classes)
-
-    class_weight = dict([(i, 1.) for i in range(num_classes)])
-    class_weight[weighted_class] = 2.
-
-    sample_weight = np.ones((y_train.shape[0]))
-    sample_weight[int_y_train == weighted_class] = 2.
-
-    model.fit(
-        x_train,
-        y_train,
-        batch_size=batch_size,
-        epochs=epochs // 3,
-        verbose=0,
-        sample_weight=sample_weight)
-    model.fit(
-        x_train,
-        y_train,
-        batch_size=batch_size,
-        epochs=epochs // 3,
-        verbose=0,
-        sample_weight=sample_weight,
-        validation_split=0.1)
-    model.train_on_batch(
-        x_train[:batch_size],
-        y_train[:batch_size],
-        sample_weight=sample_weight[:batch_size])
-    model.test_on_batch(
-        x_train[:batch_size],
-        y_train[:batch_size],
-        sample_weight=sample_weight[:batch_size])
-
-  def test_temporal_sample_weights(self):
-    num_classes = 5
-    weighted_class = 3
-    train_samples = 1000
-    test_samples = 1000
-    input_dim = 5
-    timesteps = 3
-
-    model = keras.models.Sequential()
-    model.add(
-        keras.layers.TimeDistributed(
-            keras.layers.Dense(num_classes),
-            input_shape=(timesteps, input_dim)))
-    model.add(keras.layers.Activation('softmax'))
-
-    np.random.seed(1337)
-    (_, y_train), _ = testing_utils.get_test_data(
-        train_samples=train_samples,
-        test_samples=test_samples,
-        input_shape=(input_dim,),
-        num_classes=num_classes)
-    int_y_train = y_train.copy()
-    # convert class vectors to binary class matrices
-    y_train = keras.utils.to_categorical(y_train, num_classes)
-
-    class_weight = dict([(i, 1.) for i in range(num_classes)])
-    class_weight[weighted_class] = 2.
-
-    sample_weight = np.ones((y_train.shape[0]))
-    sample_weight[int_y_train == weighted_class] = 2.
-    with self.assertRaises(ValueError):
-      model.compile(
-          loss='binary_crossentropy',
-          optimizer=RMSPropOptimizer(learning_rate=0.001),
-          sample_weight_mode='temporal')
-
-  def test_class_weight_invalid_use_case(self):
-    num_classes = 5
-    train_samples = 1000
-    test_samples = 1000
-    input_dim = 5
-    timesteps = 3
-
-    model = keras.models.Sequential()
-    model.add(
-        keras.layers.TimeDistributed(
-            keras.layers.Dense(num_classes),
-            input_shape=(timesteps, input_dim)))
-    model.add(keras.layers.Activation('softmax'))
-    model.compile(
-        loss='binary_crossentropy',
-        optimizer=RMSPropOptimizer(learning_rate=0.001))
-
-    (x_train, y_train), _ = testing_utils.get_test_data(
-        train_samples=train_samples,
-        test_samples=test_samples,
-        input_shape=(input_dim,),
-        num_classes=num_classes)
-    # convert class vectors to binary class matrices
-    y_train = keras.utils.to_categorical(y_train, num_classes)
-    class_weight = dict([(i, 1.) for i in range(num_classes)])
-
-    del class_weight[1]
-    with self.assertRaises(ValueError):
-      model.fit(x_train, y_train,
-                epochs=0, verbose=0, class_weight=class_weight)
-
-    with self.assertRaises(ValueError):
-      model.compile(
-          loss='binary_crossentropy',
-          optimizer=RMSPropOptimizer(learning_rate=0.001),
-          sample_weight_mode=[])
-
-    # Build multi-output model
-    x = keras.Input((3,))
-    y1 = keras.layers.Dense(4, name='1')(x)
-    y2 = keras.layers.Dense(4, name='2')(x)
-    model = keras.models.Model(x, [y1, y2])
-    model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse')
-    x_np = np.random.random((10, 3))
-    y_np = np.random.random((10, 4))
-    w_np = np.random.random((10,))
-    # This will work
-    model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': w_np})
-    # These will not
-    with self.assertRaises(ValueError):
-      model.fit(x_np, [y_np, y_np], epochs=1, sample_weight=[w_np])
-    with self.assertRaises(TypeError):
-      model.fit(x_np, [y_np, y_np], epochs=1, sample_weight=w_np)
-    with self.assertRaises(ValueError):
-      bad_w_np = np.random.random((11,))
-      model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np})
-    with self.assertRaises(ValueError):
-      bad_w_np = np.random.random((10, 2))
-      model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np})
-    with self.assertRaises(ValueError):
-      bad_w_np = np.random.random((10, 2, 2))
-      model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np})
-
-
-class TestDynamicTrainability(test.TestCase):
-
-  def test_trainable_warning(self):
-    x = np.random.random((5, 3))
-    y = np.random.random((5, 2))
-    model = keras.models.Sequential()
-    model.add(keras.layers.Dense(2, input_dim=3))
-    model.trainable = False
-    model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse')
-    model.trainable = True
-    with test.mock.patch.object(logging, 'warning') as mock_log:
-      model.train_on_batch(x, y)
-      self.assertRegexpMatches(str(mock_log.call_args),
-                               'trainable weights is empty')
-
-  def test_trainable_argument(self):
-    x = np.random.random((5, 3))
-    y = np.random.random((5, 2))
-
-    model = keras.models.Sequential()
-    model.add(keras.layers.Dense(2, input_dim=3, trainable=False))
-    model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse')
-    out = model.predict(x)
-    with test.mock.patch.object(logging, 'warning') as mock_log:
-      model.train_on_batch(x, y)
-      self.assertRegexpMatches(str(mock_log.call_args),
-                               'trainable weights is empty')
-    out_2 = model.predict(x)
-    self.assertAllClose(out, out_2)
-
-    # test with nesting
-    inputs = keras.layers.Input(shape=(3,))
-    output = model(inputs)
-    model = keras.models.Model(inputs, output)
-    model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse')
-    out = model.predict(x)
-    with test.mock.patch.object(logging, 'warning') as mock_log:
-      model.train_on_batch(x, y)
-      self.assertRegexpMatches(str(mock_log.call_args),
-                               'trainable weights is empty')
-    out_2 = model.predict(x)
-    self.assertAllClose(out, out_2)
-
-  def test_layer_trainability_switch(self):
-    # with constructor argument, in Sequential
-    model = keras.models.Sequential()
-    model.add(keras.layers.Dense(2, trainable=False, input_dim=1))
-    self.assertListEqual(model.trainable_weights, [])
-
-    # by setting the `trainable` argument, in Sequential
-    model = keras.models.Sequential()
-    layer = keras.layers.Dense(2, input_dim=1)
-    model.add(layer)
-    self.assertListEqual(model.trainable_weights, layer.trainable_weights)
-    layer.trainable = False
-    self.assertListEqual(model.trainable_weights, [])
-
-    # with constructor argument, in Model
-    x = keras.layers.Input(shape=(1,))
-    y = keras.layers.Dense(2, trainable=False)(x)
-    model = keras.models.Model(x, y)
-    self.assertListEqual(model.trainable_weights, [])
-
-    # by setting the `trainable` argument, in Model
-    x = keras.layers.Input(shape=(1,))
-    layer = keras.layers.Dense(2)
-    y = layer(x)
-    model = keras.models.Model(x, y)
-    self.assertListEqual(model.trainable_weights, layer.trainable_weights)
-    layer.trainable = False
-    self.assertListEqual(model.trainable_weights, [])
-
-  def test_model_trainability_switch(self):
-    # a non-trainable model has no trainable weights
-    x = keras.layers.Input(shape=(1,))
-    y = keras.layers.Dense(2)(x)
-    model = keras.models.Model(x, y)
-    model.trainable = False
-    self.assertListEqual(model.trainable_weights, [])
-
-    # same for Sequential
-    model = keras.models.Sequential()
-    model.add(keras.layers.Dense(2, input_dim=1))
-    model.trainable = False
-    self.assertListEqual(model.trainable_weights, [])
-
-  def test_nested_model_trainability(self):
-
-    # a Sequential inside a Model
-    inner_model = keras.models.Sequential()
-    inner_model.add(keras.layers.Dense(2, input_dim=1))
-
-    x = keras.layers.Input(shape=(1,))
-    y = inner_model(x)
-    outer_model = keras.models.Model(x, y)
-    self.assertListEqual(outer_model.trainable_weights,
-                         inner_model.trainable_weights)
-    inner_model.trainable = False
-    self.assertListEqual(outer_model.trainable_weights, [])
-    inner_model.trainable = True
-    inner_model.layers[-1].trainable = False
-    self.assertListEqual(outer_model.trainable_weights, [])
-
-    # a Sequential inside a Sequential
-    inner_model = keras.models.Sequential()
-    inner_model.add(keras.layers.Dense(2, input_dim=1))
-    outer_model = keras.models.Sequential()
-    outer_model.add(inner_model)
-    self.assertListEqual(outer_model.trainable_weights,
-                         inner_model.trainable_weights)
-    inner_model.trainable = False
-    self.assertListEqual(outer_model.trainable_weights, [])
-    inner_model.trainable = True
-    inner_model.layers[-1].trainable = False
-    self.assertListEqual(outer_model.trainable_weights, [])
-
-    # a Model inside a Model
-    x = keras.layers.Input(shape=(1,))
-    y = keras.layers.Dense(2)(x)
-    inner_model = keras.models.Model(x, y)
-    x = keras.layers.Input(shape=(1,))
-    y = inner_model(x)
-    outer_model = keras.models.Model(x, y)
-    self.assertListEqual(outer_model.trainable_weights,
-                         inner_model.trainable_weights)
-    inner_model.trainable = False
-    self.assertListEqual(outer_model.trainable_weights, [])
-    inner_model.trainable = True
-    inner_model.layers[-1].trainable = False
-    self.assertListEqual(outer_model.trainable_weights, [])
-
-    # a Model inside a Sequential
-    x = keras.layers.Input(shape=(1,))
-    y = keras.layers.Dense(2)(x)
-    inner_model = keras.models.Model(x, y)
-    outer_model = keras.models.Sequential()
-    outer_model.add(inner_model)
-    self.assertListEqual(outer_model.trainable_weights,
-                         inner_model.trainable_weights)
-    inner_model.trainable = False
-    self.assertListEqual(outer_model.trainable_weights, [])
-    inner_model.trainable = True
-    inner_model.layers[-1].trainable = False
-    self.assertListEqual(outer_model.trainable_weights, [])
-
-
-class TestTrainingUtils(test.TestCase):
-
-  def test_check_array_lengths(self):
-    keras.engine.training._check_array_lengths(None, None, None)
-    a_np = np.random.random((4, 3, 3))
-    keras.engine.training._check_array_lengths(a_np, a_np, a_np)
-    keras.engine.training._check_array_lengths(
-        [a_np, a_np], [a_np, a_np], [a_np, a_np])
-    keras.engine.training._check_array_lengths([None], [None], [None])
-
-    b_np = np.random.random((3, 4))
-    with self.assertRaises(ValueError):
-      keras.engine.training._check_array_lengths(a_np, None, None)
-    with self.assertRaises(ValueError):
-      keras.engine.training._check_array_lengths(a_np, a_np, None)
-    with self.assertRaises(ValueError):
-      keras.engine.training._check_array_lengths([a_np], [None], None)
-    with self.assertRaises(ValueError):
-      keras.engine.training._check_array_lengths([a_np], [b_np], None)
-    with self.assertRaises(ValueError):
-      keras.engine.training._check_array_lengths([a_np], None, [b_np])
-
-  def test_slice_arrays(self):
-    input_a = np.random.random((10, 3))
-    slice_arrays(None)
-    slice_arrays(input_a, 0)
-    slice_arrays(input_a, 0, 1)
-    slice_arrays(input_a, stop=2)
-    input_a = [None, [1, 1], None, [1, 1]]
-    slice_arrays(input_a, 0)
-    slice_arrays(input_a, 0, 1)
-    slice_arrays(input_a, stop=2)
-    input_a = [None]
-    slice_arrays(input_a, 0)
-    slice_arrays(input_a, 0, 1)
-    slice_arrays(input_a, stop=2)
-    input_a = None
-    slice_arrays(input_a, 0)
-    slice_arrays(input_a, 0, 1)
-    slice_arrays(input_a, stop=2)
-
-  def test_fit_with_BatchNorm(self):
-    model = keras.models.Sequential()
-    model.add(keras.layers.Dense(10, input_dim=4))
-    model.add(keras.layers.BatchNormalization())
-    model.add(keras.layers.Activation('tanh'))
-    model.add(keras.layers.Dropout(0.2))
-
-    input_a_np = np.random.random((10, 4))
-    output_b_np = np.random.random((10, 10))
-
-    model.compile(loss='binary_crossentropy', optimizer=RMSPropOptimizer(0.001))
-    model.fit(input_a_np, output_b_np, epochs=1, batch_size=5, verbose=0)
-
-  def test_fit_with_regularization(self):
-    model = keras.models.Sequential()
-    with self.assertRaises(ValueError):
-      model.add(
-          keras.layers.Dense(4, input_dim=3,
-                             kernel_regularizer=keras.regularizers.l2(0.01),
-                             activity_regularizer=keras.regularizers.l1(0.01)))
-
-
 if __name__ == '__main__':
   # Bazel sets these environment variables to very long paths.
   # Tempfile uses them to create long paths, and in turn multiprocessing
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py
index 9651eb9..6ca5941 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py
@@ -1046,15 +1046,7 @@ class TestTrainingUtils(test.TestCase):
 
     b_np = np.random.random((3, 4))
     with self.assertRaises(ValueError):
-      keras.engine.training._check_array_lengths(a_np, None, None)
-    with self.assertRaises(ValueError):
-      keras.engine.training._check_array_lengths(a_np, a_np, None)
-    with self.assertRaises(ValueError):
-      keras.engine.training._check_array_lengths([a_np], [None], None)
-    with self.assertRaises(ValueError):
       keras.engine.training._check_array_lengths([a_np], [b_np], None)
-    with self.assertRaises(ValueError):
-      keras.engine.training._check_array_lengths([a_np], None, [b_np])
 
   def test_slice_arrays(self):
     input_a = np.random.random((10, 3))
-- 
2.7.4
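Taken together, the patch makes weighting arguments fail fast under eager
execution instead of being silently ignored, which is what the removed tests
were inadvertently exercising. A usage-level sketch of the new behavior; it
assumes a build that contains this patch, and the exact location of the
enable-eager call varies across contemporary TF versions:

import numpy as np
import tensorflow as tf
from tensorflow.python.keras._impl import keras
from tensorflow.python.training.rmsprop import RMSPropOptimizer

tf.enable_eager_execution()  # tf.contrib.eager on older builds; assumption

model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(3,))])
model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001))

x = np.random.random((4, 3)).astype('float32')
y = np.random.random((4, 1)).astype('float32')
try:
  model.fit(x, y, epochs=1, verbose=0, sample_weight=np.ones((4,)))
except ValueError as e:
  print(e)  # `sample_weight` and `class_weight` are not supported when eager...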