From 07cd8f2565cd1c7a44be681379eb7dfc64a77b1c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 23 May 2018 16:02:19 -0700
Subject: [PATCH] added support for calling fit on Dataset objects

PiperOrigin-RevId: 197805615
---
 tensorflow/python/keras/engine/training.py       |  88 +++++++++++---------
 tensorflow/python/keras/engine/training_test.py  | 101 +++++++++++++++++++++--
 tensorflow/python/keras/engine/training_utils.py |  15 ++--
 3 files changed, 151 insertions(+), 53 deletions(-)

diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index ff50d0b..0db805c 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -112,6 +112,8 @@ class Model(Network):
     super(Model, self).__init__(*args, **kwargs)
     # Create a cache for iterator get_next op.
     self._iterator_get_next = weakref.WeakKeyDictionary()
+    # Create a cache for uninitialized dataset iterators.
+    self._dataset_iterator_cache = weakref.WeakKeyDictionary()
 
   def compile(self,
               optimizer,
@@ -670,12 +672,12 @@ class Model(Network):
             (in case the model has multiple inputs).
           - A dict mapping input names to the corresponding array/tensors,
             if the model has named inputs.
-          - A `tf.data` dataset iterator.
+          - A `tf.data` dataset or a dataset iterator.
         y: Target data. Like the input data `x`,
           it could be either Numpy array(s) or TensorFlow tensor(s).
           It should be consistent with `x` (you cannot have Numpy inputs and
-          tensor targets, or inversely). If `x` is a dataset iterator,
-          `y` should not be specified
+          tensor targets, or inversely). If `x` is a dataset or a
+          dataset iterator, `y` should not be specified
           (since targets will be obtained from the iterator).
         sample_weight: An optional sample-weight array passed by the user
           to weight the importance of each sample in `x`.
@@ -706,11 +708,16 @@ class Model(Network):
       RuntimeError: If the model was never compiled.
     """
     if isinstance(x, dataset_ops.Dataset):
-      raise ValueError('You passed a `Dataset` instance to your model (%s), '
-                       'which is not supported. Instead, pass an `Iterator`, '
-                       'which you can obtain e.g. via '
-                       '`dataset.make_one_shot_iterator()` (the exact method '
-                       'to use will depend on your specific dataset).' % x)
+      if context.executing_eagerly():
+        x = x.make_one_shot_iterator()
+      else:
+        if x in self._dataset_iterator_cache:
+          x = self._dataset_iterator_cache[x]
+        else:
+          iterator = x.make_initializable_iterator()
+          self._dataset_iterator_cache[x] = iterator
+          x = iterator
+        K.get_session().run(x.initializer)
 
     # Validates `steps` argument based on x's type.
     if check_steps:
@@ -719,7 +726,7 @@ class Model(Network):
     is_x_eager_iterator = isinstance(x, iterator_ops.EagerIterator)
     is_x_iterator = isinstance(x, iterator_ops.Iterator)
 
-    # Validate user inputs when data is given as a dataset iterator.
+    # Validate user inputs when data is given as a dataset or dataset iterator.
     if is_x_iterator or is_x_eager_iterator:
       training_utils.validate_iterator_input(x, y, sample_weight,
                                              validation_split)
@@ -1130,19 +1137,19 @@ class Model(Network):
             (in case the model has multiple inputs).
           - A dict mapping input names to the corresponding array/tensors,
             if the model has named inputs.
-          - A `tf.data` dataset iterator.
+          - A `tf.data` dataset or a dataset iterator.
         y: Target data. Like the input data `x`,
           it could be either Numpy array(s) or TensorFlow tensor(s).
          It should be consistent with `x` (you cannot have Numpy inputs and
-          tensor targets, or inversely). If `x` is a dataset iterator,
-          `y` should not be specified
+          tensor targets, or inversely). If `x` is a dataset or dataset
+          iterator, `y` should not be specified
           (since targets will be obtained from the iterator).
         batch_size: Integer or `None`.
           Number of samples per gradient update.
           If unspecified, `batch_size` will default to 32.
           Do not specify the `batch_size` if your data is in the
-          form of symbolic tensors or dataset iterators (since they generate
-          batches).
+          form of symbolic tensors, datasets, or dataset iterators
+          (since they generate batches).
         epochs: Integer. Number of epochs to train the model.
           An epoch is an iteration over the entire `x` and `y`
           data provided.
@@ -1164,7 +1171,7 @@ class Model(Network):
           on this data at the end of each epoch.
           The validation data is selected from the last samples
           in the `x` and `y` data provided, before shuffling. This argument is
-          not supported when `x` is a dataset iterator.
+          not supported when `x` is a dataset or a dataset iterator.
         validation_data: Data on which to evaluate
           the loss and any model metrics at the end of each epoch.
           The model will not be trained on this data.
@@ -1172,7 +1179,7 @@ class Model(Network):
           `validation_data` could be:
             - tuple `(x_val, y_val)` of Numpy arrays or tensors
             - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays
-            - dataset iterator
+            - dataset or a dataset iterator
         shuffle: Boolean (whether to shuffle the training data
           before each epoch) or str (for 'batch').
           'batch' is a special option for dealing with the
@@ -1195,7 +1202,7 @@ class Model(Network):
           to apply a different weight to every timestep of every sample.
           In this case you should make sure to specify
           `sample_weight_mode="temporal"` in `compile()`. This argument is not
-          supported when `x` is a dataset iterator.
+          supported when `x` is a dataset or a dataset iterator.
         initial_epoch: Integer.
           Epoch at which to start training
           (useful for resuming a previous training run).
@@ -1252,7 +1259,8 @@ class Model(Network):
     # Prepare validation data.
     if validation_data:
       if (isinstance(validation_data, iterator_ops.Iterator) or
-          isinstance(validation_data, iterator_ops.EagerIterator)):
+          isinstance(validation_data, iterator_ops.EagerIterator) or
+          isinstance(validation_data, dataset_ops.Dataset)):
         val_x = validation_data
         val_y = None
         val_sample_weight = None
@@ -1266,8 +1274,9 @@ class Model(Network):
             'When passing a `validation_data` argument, '
             'it must contain either 2 items (x_val, y_val), '
             'or 3 items (x_val, y_val, val_sample_weights), '
-            'or alternatively it could be a dataset iterator. However we '
-            'received `validation_data=%s`' % validation_data)
+            'or alternatively it could be a dataset or a '
+            'dataset iterator. '
+            'However we received `validation_data=%s`' % validation_data)
 
       # Validate and standardize validation data.
       val_x, val_y, val_sample_weights = self._standardize_user_data(
@@ -1351,19 +1360,19 @@ class Model(Network):
             (in case the model has multiple inputs).
           - A dict mapping input names to the corresponding array/tensors,
             if the model has named inputs.
-          - A `tf.data` dataset iterator.
+          - A `tf.data` dataset or a dataset iterator.
         y: Target data. Like the input data `x`,
           it could be either Numpy array(s) or TensorFlow tensor(s).
           It should be consistent with `x` (you cannot have Numpy inputs and
-          tensor targets, or inversely). If `x` is a dataset iterator,
-          `y` should not be specified
-          (since targets will be obtained from the iterator).
+          tensor targets, or inversely).
+          If `x` is a dataset or a dataset iterator, `y` should not be specified
+          (since targets will be obtained from the iterator/dataset).
         batch_size: Integer or `None`.
           Number of samples per gradient update.
           If unspecified, `batch_size` will default to 32.
           Do not specify the `batch_size` if your data is in the
-          form of symbolic tensors or dataset iterators (since they generate
-          batches).
+          form of symbolic tensors, datasets, or dataset iterators
+          (since they generate batches).
         verbose: 0 or 1. Verbosity mode.
           0 = silent, 1 = progress bar.
         sample_weight: Optional Numpy array of weights for
@@ -1377,7 +1386,7 @@ class Model(Network):
           to apply a different weight to every timestep of every sample.
           In this case you should make sure to specify
           `sample_weight_mode="temporal"` in `compile()`. This argument is not
-          supported when `x` is a dataset iterator.
+          supported when `x` is a dataset or a dataset iterator.
         steps: Integer or `None`.
           Total number of steps (batches of samples)
           before declaring the evaluation round finished.
@@ -1426,13 +1435,13 @@ class Model(Network):
             (in case the model has multiple inputs).
           - A TensorFlow tensor, or a list of tensors
             (in case the model has multiple inputs).
-          - A `tf.data` dataset iterator.
+          - A `tf.data` dataset or a dataset iterator.
         batch_size: Integer or `None`.
           Number of samples per gradient update.
           If unspecified, `batch_size` will default to 32.
           Do not specify the `batch_size` if your data is in the
-          form of symbolic tensors or dataset iterators (since they generate
-          batches).
+          form of symbolic tensors, datasets, or dataset iterators
+          (since they generate batches).
         verbose: Verbosity mode, 0 or 1.
         steps: Total number of steps (batches of samples)
           before declaring the prediction round finished.
@@ -1473,12 +1482,12 @@ class Model(Network):
             (in case the model has multiple inputs).
           - A dict mapping input names to the corresponding array/tensors,
             if the model has named inputs.
-          - A `tf.data` dataset iterator.
+          - A `tf.data` dataset or a dataset iterator.
         y: Target data. Like the input data `x`,
           it could be either Numpy array(s) or TensorFlow tensor(s).
           It should be consistent with `x` (you cannot have Numpy inputs and
-          tensor targets, or inversely). If `x` is a dataset iterator,
-          `y` should not be specified
+          tensor targets, or inversely). If `x` is a dataset or a
+          dataset iterator, `y` should not be specified
           (since targets will be obtained from the iterator).
         sample_weight: Optional array of the same length as x, containing
           weights to apply to the model's loss for each sample.
@@ -1487,8 +1496,7 @@ class Model(Network):
           to apply a different weight to every timestep of every sample.
           In this case you should make sure to specify
           sample_weight_mode="temporal" in compile(). This argument is not
-          supported when `x` is a dataset iterator.
-
+          supported when `x` is a dataset or a dataset iterator.
         class_weight: Optional dictionary mapping class indices (integers) to
           a weight (float) to apply to the model's loss for the samples
@@ -1537,12 +1545,12 @@ class Model(Network):
             (in case the model has multiple inputs).
           - A dict mapping input names to the corresponding array/tensors,
             if the model has named inputs.
-          - A `tf.data` dataset iterator.
+          - A `tf.data` dataset or a dataset iterator.
         y: Target data. Like the input data `x`,
           it could be either Numpy array(s) or TensorFlow tensor(s).
          It should be consistent with `x` (you cannot have Numpy inputs and
-          tensor targets, or inversely). If `x` is a dataset iterator,
-          `y` should not be specified
+          tensor targets, or inversely). If `x` is a dataset or a
+          dataset iterator, `y` should not be specified
           (since targets will be obtained from the iterator).
         sample_weight: Optional array of the same length as x, containing
           weights to apply to the model's loss for each sample.
@@ -1551,7 +1559,7 @@ class Model(Network):
           to apply a different weight to every timestep of every sample.
           In this case you should make sure to specify
           sample_weight_mode="temporal" in compile(). This argument is not
-          supported when `x` is a dataset iterator.
+          supported when `x` is a dataset or a dataset iterator.
 
     Returns:
         Scalar test loss (if the model has a single output and no metrics)
@@ -1590,7 +1598,7 @@ class Model(Network):
             (in case the model has multiple inputs).
           - A TensorFlow tensor, or a list of tensors
             (in case the model has multiple inputs).
-          - A `tf.data` dataset iterator.
+          - A `tf.data` dataset or a dataset iterator.
 
     Returns:
         Numpy array(s) of predictions.
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 7dec0bb..222e349 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -1742,7 +1742,7 @@ class TestTrainingWithDatasetIterators(test.TestCase):
     # Test with validation split
     with self.assertRaisesRegexp(
         ValueError, '`validation_split` argument is not supported '
-        'when input `x` is a dataset iterator'):
+        'when input `x` is a dataset or a dataset iterator'):
       model.fit(iterator,
                 epochs=1, steps_per_epoch=2, verbose=0,
                 validation_split=0.5, validation_steps=2)
@@ -1751,7 +1751,7 @@ class TestTrainingWithDatasetIterators(test.TestCase):
     sample_weight = np.random.random((10,))
     with self.assertRaisesRegexp(
         ValueError, '`sample_weight` argument is not supported '
-        'when input `x` is a dataset iterator'):
+        'when input `x` is a dataset or a dataset iterator'):
       model.fit(
           iterator,
           epochs=1,
@@ -1761,10 +1761,6 @@ class TestTrainingWithDatasetIterators(test.TestCase):
 
     # Test invalid usage
     with self.assertRaisesRegexp(ValueError,
-                                 'Instead, pass an `Iterator`'):
-      model.fit(dataset,
-                epochs=1, steps_per_epoch=2, verbose=0)
-    with self.assertRaisesRegexp(ValueError,
                                  'you should not specify a target'):
       model.fit(iterator, iterator,
                 epochs=1, steps_per_epoch=2, verbose=0)
@@ -1829,5 +1825,98 @@ class TestTrainingWithDatasetIterators(test.TestCase):
                          'dataset iterator ran out of data')
 
 
+class TestTrainingWithDataset(test.TestCase):
+
+  def test_calling_model_on_same_dataset(self):
+    with self.test_session():
+      x = keras.layers.Input(shape=(3,), name='input')
+      y = keras.layers.Dense(4, name='dense')(x)
+      model = keras.Model(x, y)
+
+      optimizer = RMSPropOptimizer(learning_rate=0.001)
+      loss = 'mse'
+      metrics = ['mae']
+      model.compile(optimizer, loss, metrics=metrics)
+
+      inputs = np.zeros((10, 3), dtype=np.float32)
+      targets = np.zeros((10, 4), dtype=np.float32)
+      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+      dataset = dataset.repeat(100)
+      dataset = dataset.batch(10)
+
+      # Call fit with validation data
+      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+                validation_data=dataset, validation_steps=2)
+      # Finalize the graph to make sure new ops aren't added when calling on
+      # the same dataset
+      ops.get_default_graph().finalize()
+      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+                validation_data=dataset, validation_steps=2)
+
+  @tf_test_util.run_in_graph_and_eager_modes()
+  def test_training_and_eval_methods_on_dataset(self):
+    with self.test_session():
+      x = keras.layers.Input(shape=(3,), name='input')
+      y = keras.layers.Dense(4, name='dense')(x)
+      model = keras.Model(x, y)
+
+      optimizer = RMSPropOptimizer(learning_rate=0.001)
+      loss = 'mse'
+      metrics = ['mae']
+      model.compile(optimizer, loss, metrics=metrics)
+
+      inputs = np.zeros((10, 3), dtype=np.float32)
+      targets = np.zeros((10, 4), dtype=np.float32)
+      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+      dataset = dataset.repeat(100)
+      dataset = dataset.batch(10)
+
+      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
+      model.evaluate(dataset, steps=2, verbose=1)
+      model.predict(dataset, steps=2)
+      model.train_on_batch(dataset)
+      model.predict_on_batch(dataset)
+
+      # Test with validation data
+      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+                validation_data=dataset, validation_steps=2)
+
+      # Test with validation split
+      with self.assertRaisesRegexp(
+          ValueError, '`validation_split` argument is not supported '
+          'when input `x` is a dataset or a dataset iterator'):
+        model.fit(dataset,
+                  epochs=1, steps_per_epoch=2, verbose=0,
+                  validation_split=0.5, validation_steps=2)
+
+      # Test with sample weight.
+      sample_weight = np.random.random((10,))
+      with self.assertRaisesRegexp(
+          ValueError, '`sample_weight` argument is not supported '
+          'when input `x` is a dataset or a dataset iterator'):
+        model.fit(
+            dataset,
+            epochs=1,
+            steps_per_epoch=2,
+            verbose=0,
+            sample_weight=sample_weight)
+
+      # Test invalid usage
+      with self.assertRaisesRegexp(ValueError,
+                                   'you should not specify a target'):
+        model.fit(dataset, dataset,
+                  epochs=1, steps_per_epoch=2, verbose=0)
+
+      with self.assertRaisesRegexp(
+          ValueError, 'you should specify the `steps_per_epoch` argument'):
+        model.fit(dataset, epochs=1, verbose=0)
+      with self.assertRaisesRegexp(ValueError,
+                                   'you should specify the `steps` argument'):
+        model.evaluate(dataset, verbose=0)
+      with self.assertRaisesRegexp(ValueError,
+                                   'you should specify the `steps` argument'):
+        model.predict(dataset, verbose=0)
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index 7d214d6..c53948b 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -632,19 +632,20 @@ def validate_iterator_input(x, y, sample_weight, validation_split=None):
       provided by user.
   """
   if y is not None:
-    raise ValueError('You passed a dataset iterator (%s) as input `x` to '
-                     'your model. In that case, you should not specify '
-                     'a target (`y`) argument, since the dataset iterator '
-                     'generates both input data and target data. '
+    raise ValueError('You passed a dataset or dataset iterator (%s) as '
+                     'input `x` to your model. In that case, you should '
+                     'not specify a target (`y`) argument, since the dataset '
+                     'or dataset iterator generates both input data and '
+                     'target data. '
                      'Received: %s' % (x, y))
   if sample_weight is not None:
-    raise ValueError('`sample_weight` argument is not supported when input'
-                     ' `x` is a dataset iterator. '
+    raise ValueError('`sample_weight` argument is not supported when input '
+                     '`x` is a dataset or a dataset iterator. '
                      'Received: x=%s, sample_weight=%s' % (x, sample_weight))
   if validation_split is not None and validation_split != 0.0:
     raise ValueError(
         '`validation_split` argument is not supported when '
-        'input `x` is a dataset iterator. '
+        'input `x` is a dataset or a dataset iterator. '
         'Received: x=%s, validation_split=%f' % (x, validation_split))
-- 
2.7.4
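
For reference, here is a minimal sketch of the usage this patch enables, written against the public `tf.data` / `tf.keras` API rather than the internal `dataset_ops` and `keras` modules the tests import. It assumes a TensorFlow build that includes this change; the model, shapes, and hyperparameters are illustrative only and mirror the new tests above.

import numpy as np
import tensorflow as tf

# Toy data matching the shapes used in the new tests.
inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.zeros((10, 4), dtype=np.float32)

# The dataset must already be batched and repeated: fit() draws batches
# straight from it, so `batch_size` must not be passed and `steps_per_epoch`
# (or `steps`) is required.
dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100).batch(10)

x = tf.keras.layers.Input(shape=(3,))
y = tf.keras.layers.Dense(4)(x)
model = tf.keras.Model(x, y)
model.compile(tf.train.RMSPropOptimizer(learning_rate=0.001),
              loss='mse', metrics=['mae'])

# No manual make_one_shot_iterator() call is needed anymore: in graph mode
# the model creates an initializable iterator and runs its initializer; in
# eager mode it creates a one-shot iterator.
model.fit(dataset, epochs=1, steps_per_epoch=2,
          validation_data=dataset, validation_steps=2)
model.evaluate(dataset, steps=2)
model.predict(dataset, steps=2)

Because the iterator is cached per `Dataset` object in graph mode, repeated fit/evaluate/predict calls on the same dataset reuse it instead of adding new ops, which is what the finalized-graph test above verifies.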