added support for calling fit on Dataset objects
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 23 May 2018 23:02:19 +0000 (16:02 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 23 May 2018 23:05:14 +0000 (16:05 -0700)
PiperOrigin-RevId: 197805615

tensorflow/python/keras/engine/training.py
tensorflow/python/keras/engine/training_test.py
tensorflow/python/keras/engine/training_utils.py

index ff50d0b..0db805c 100644 (file)
@@ -112,6 +112,8 @@ class Model(Network):
     super(Model, self).__init__(*args, **kwargs)
     # Create a cache for iterator get_next op.
     self._iterator_get_next = weakref.WeakKeyDictionary()
+    # Create a cache for dataset - uninitialized iterators
+    self._dataset_iterator_cache = weakref.WeakKeyDictionary()
 
   def compile(self,
               optimizer,
@@ -670,12 +672,12 @@ class Model(Network):
           (in case the model has multiple inputs).
         - A dict mapping input names to the corresponding array/tensors,
           if the model has named inputs.
-        - A `tf.data` dataset iterator.
+        - A `tf.data` dataset or a dataset iterator.
       y: Target data. Like the input data `x`,
         it could be either Numpy array(s) or TensorFlow tensor(s).
         It should be consistent with `x` (you cannot have Numpy inputs and
-        tensor targets, or inversely). If `x` is a dataset iterator,
-        `y` should not be specified
+        tensor targets, or inversely). If `x` is a dataset or a
+        dataset iterator, `y` should not be specified
         (since targets will be obtained from the iterator).
       sample_weight: An optional sample-weight array passed by the user to
         weight the importance of each sample in `x`.
@@ -706,11 +708,16 @@ class Model(Network):
       RuntimeError: If the model was never compiled.
     """
     if isinstance(x, dataset_ops.Dataset):
-      raise ValueError('You passed a `Dataset` instance to your model (%s), '
-                       'which is not supported. Instead, pass an `Iterator`, '
-                       'which you can obtain e.g. via '
-                       '`dataset.make_one_shot_iterator()` (the exact method '
-                       'to use will depend on your specific dataset).' % x)
+      if context.executing_eagerly():
+        x = x.make_one_shot_iterator()
+      else:
+        if x in self._dataset_iterator_cache:
+          x = self._dataset_iterator_cache[x]
+        else:
+          iterator = x.make_initializable_iterator()
+          self._dataset_iterator_cache[x] = iterator
+          x = iterator
+        K.get_session().run(x.initializer)
 
     # Validates `steps` argument based on x's type.
     if check_steps:
@@ -719,7 +726,7 @@ class Model(Network):
     is_x_eager_iterator = isinstance(x, iterator_ops.EagerIterator)
     is_x_iterator = isinstance(x, iterator_ops.Iterator)
 
-    # Validate user inputs when data is given as a dataset iterator.
+    # Validate user inputs when data is given as a dataset or dataset iterator.
     if is_x_iterator or is_x_eager_iterator:
       training_utils.validate_iterator_input(x, y, sample_weight,
                                              validation_split)
@@ -1130,19 +1137,19 @@ class Model(Network):
             (in case the model has multiple inputs).
           - A dict mapping input names to the corresponding array/tensors,
             if the model has named inputs.
-          - A `tf.data` dataset iterator.
+          - A `tf.data` dataset or a dataset iterator.
         y: Target data. Like the input data `x`,
           it could be either Numpy array(s) or TensorFlow tensor(s).
           It should be consistent with `x` (you cannot have Numpy inputs and
-          tensor targets, or inversely). If `x` is a dataset iterator,
-          `y` should not be specified
+          tensor targets, or inversely). If `x` is a dataset or dataset
+          iterator, `y` should not be specified
           (since targets will be obtained from the iterator).
         batch_size: Integer or `None`.
             Number of samples per gradient update.
             If unspecified, `batch_size` will default to 32.
             Do not specify the `batch_size` if your data is in the
-            form of symbolic tensors or dataset iterators (since they generate
-            batches).
+            form of symbolic tensors, datasets, or dataset iterators
+            (since they generate batches).
         epochs: Integer. Number of epochs to train the model.
             An epoch is an iteration over the entire `x` and `y`
             data provided.
@@ -1164,7 +1171,7 @@ class Model(Network):
             on this data at the end of each epoch.
             The validation data is selected from the last samples
             in the `x` and `y` data provided, before shuffling. This argument is
-            not supported when `x` is a dataset iterator.
+            not supported when `x` is a dataset or a dataset iterator.
         validation_data: Data on which to evaluate
             the loss and any model metrics at the end of each epoch.
             The model will not be trained on this data.
@@ -1172,7 +1179,7 @@ class Model(Network):
             `validation_data` could be:
               - tuple `(x_val, y_val)` of Numpy arrays or tensors
               - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays
-              - dataset iterator
+              - dataset or a dataset iterator
         shuffle: Boolean (whether to shuffle the training data
             before each epoch) or str (for 'batch').
             'batch' is a special option for dealing with the
@@ -1195,7 +1202,7 @@ class Model(Network):
             to apply a different weight to every timestep of every sample.
             In this case you should make sure to specify
             `sample_weight_mode="temporal"` in `compile()`. This argument is not
-            supported when `x` is a dataset iterator.
+            supported when `x` is a dataset or a dataset iterator.
         initial_epoch: Integer.
             Epoch at which to start training
             (useful for resuming a previous training run).
@@ -1252,7 +1259,8 @@ class Model(Network):
     # Prepare validation data.
     if validation_data:
       if (isinstance(validation_data, iterator_ops.Iterator) or
-          isinstance(validation_data, iterator_ops.EagerIterator)):
+          isinstance(validation_data, iterator_ops.EagerIterator) or
+          isinstance(validation_data, dataset_ops.Dataset)):
         val_x = validation_data
         val_y = None
         val_sample_weight = None
@@ -1266,8 +1274,9 @@ class Model(Network):
             'When passing a `validation_data` argument, '
             'it must contain either 2 items (x_val, y_val), '
             'or 3 items (x_val, y_val, val_sample_weights), '
-            'or alternatively it could be a dataset iterator. However we '
-            'received `validation_data=%s`' % validation_data)
+            'or alternatively it could be a dataset or a '
+            'dataset or a dataset iterator. '
+            'However we received `validation_data=%s`' % validation_data)
 
       # Validate and standardize validation data.
       val_x, val_y, val_sample_weights = self._standardize_user_data(
@@ -1351,19 +1360,19 @@ class Model(Network):
             (in case the model has multiple inputs).
           - A dict mapping input names to the corresponding array/tensors,
             if the model has named inputs.
-          - A `tf.data` dataset iterator.
+          - A `tf.data` dataset or a dataset iterator.
         y: Target data. Like the input data `x`,
           it could be either Numpy array(s) or TensorFlow tensor(s).
           It should be consistent with `x` (you cannot have Numpy inputs and
-          tensor targets, or inversely). If `x` is a dataset iterator,
-          `y` should not be specified
-          (since targets will be obtained from the iterator).
+          tensor targets, or inversely).
+          If `x` is a dataset or a dataset iterator, `y` should not be specified
+          (since targets will be obtained from the iterator/dataset).
         batch_size: Integer or `None`.
             Number of samples per gradient update.
             If unspecified, `batch_size` will default to 32.
             Do not specify the `batch_size` is your data is in the
-            form of symbolic tensors or dataset iterators (since they generate
-            batches).
+            form of symbolic tensors, datasets, or dataset iterators
+            (since they generate batches).
         verbose: 0 or 1. Verbosity mode.
             0 = silent, 1 = progress bar.
         sample_weight: Optional Numpy array of weights for
@@ -1377,7 +1386,7 @@ class Model(Network):
             to apply a different weight to every timestep of every sample.
             In this case you should make sure to specify
             `sample_weight_mode="temporal"` in `compile()`. This argument is not
-            supported when `x` is a dataset iterator.
+            supported when `x` is a dataset or a dataset iterator.
         steps: Integer or `None`.
             Total number of steps (batches of samples)
             before declaring the evaluation round finished.
@@ -1426,13 +1435,13 @@ class Model(Network):
             (in case the model has multiple inputs).
           - A TensorFlow tensor, or a list of tensors
             (in case the model has multiple inputs).
-          - A `tf.data` dataset iterator.
+          - A `tf.data` dataset or a dataset iterator.
         batch_size: Integer or `None`.
             Number of samples per gradient update.
             If unspecified, `batch_size` will default to 32.
             Do not specify the `batch_size` is your data is in the
-            form of symbolic tensors or dataset iterators (since they generate
-            batches).
+            form of symbolic tensors, dataset, or dataset iterators
+            (since they generate batches).
         verbose: Verbosity mode, 0 or 1.
         steps: Total number of steps (batches of samples)
             before declaring the prediction round finished.
@@ -1473,12 +1482,12 @@ class Model(Network):
             (in case the model has multiple inputs).
           - A dict mapping input names to the corresponding array/tensors,
             if the model has named inputs.
-          - A `tf.data` dataset iterator.
+          - A `tf.data` dataset or a dataset iterator.
         y: Target data. Like the input data `x`,
           it could be either Numpy array(s) or TensorFlow tensor(s).
           It should be consistent with `x` (you cannot have Numpy inputs and
-          tensor targets, or inversely). If `x` is a dataset iterator,
-          `y` should not be specified
+          tensor targets, or inversely). If `x` is a dataset or a
+          dataset iterator, `y` should not be specified
           (since targets will be obtained from the iterator).
         sample_weight: Optional array of the same length as x, containing
             weights to apply to the model's loss for each sample.
@@ -1487,8 +1496,7 @@ class Model(Network):
             to apply a different weight to every timestep of every sample.
             In this case you should make sure to specify
             sample_weight_mode="temporal" in compile(). This argument is not
-            supported when `x` is a dataset iterator.
-
+            supported when `x` is a dataset or a dataset iterator.
         class_weight: Optional dictionary mapping
             class indices (integers) to
             a weight (float) to apply to the model's loss for the samples
@@ -1537,12 +1545,12 @@ class Model(Network):
             (in case the model has multiple inputs).
           - A dict mapping input names to the corresponding array/tensors,
             if the model has named inputs.
-          - A `tf.data` dataset iterator.
+          - A `tf.data` dataset or a dataset iterator.
         y: Target data. Like the input data `x`,
           it could be either Numpy array(s) or TensorFlow tensor(s).
           It should be consistent with `x` (you cannot have Numpy inputs and
-          tensor targets, or inversely). If `x` is a dataset iterator,
-          `y` should not be specified
+          tensor targets, or inversely). If `x` is a dataset or a
+          dataset iterator, `y` should not be specified
           (since targets will be obtained from the iterator).
         sample_weight: Optional array of the same length as x, containing
             weights to apply to the model's loss for each sample.
@@ -1551,7 +1559,7 @@ class Model(Network):
             to apply a different weight to every timestep of every sample.
             In this case you should make sure to specify
             sample_weight_mode="temporal" in compile(). This argument is not
-            supported when `x` is a dataset iterator.
+            supported when `x` is a dataset or a dataset iterator.
 
     Returns:
         Scalar test loss (if the model has a single output and no metrics)
@@ -1590,7 +1598,7 @@ class Model(Network):
             (in case the model has multiple inputs).
           - A TensorFlow tensor, or a list of tensors
             (in case the model has multiple inputs).
-          - A `tf.data` dataset iterator.
+          - A `tf.data` dataset or a dataset iterator.
 
     Returns:
         Numpy array(s) of predictions.
index 7dec0bb..222e349 100644 (file)
@@ -1742,7 +1742,7 @@ class TestTrainingWithDatasetIterators(test.TestCase):
       # Test with validation split
       with self.assertRaisesRegexp(
           ValueError, '`validation_split` argument is not supported '
-          'when input `x` is a dataset iterator'):
+          'when input `x` is a dataset or a dataset iterator'):
         model.fit(iterator,
                   epochs=1, steps_per_epoch=2, verbose=0,
                   validation_split=0.5, validation_steps=2)
@@ -1751,7 +1751,7 @@ class TestTrainingWithDatasetIterators(test.TestCase):
       sample_weight = np.random.random((10,))
       with self.assertRaisesRegexp(
           ValueError, '`sample_weight` argument is not supported '
-          'when input `x` is a dataset iterator'):
+          'when input `x` is a dataset or a dataset iterator'):
         model.fit(
             iterator,
             epochs=1,
@@ -1761,10 +1761,6 @@ class TestTrainingWithDatasetIterators(test.TestCase):
 
       # Test invalid usage
       with self.assertRaisesRegexp(ValueError,
-                                   'Instead, pass an `Iterator`'):
-        model.fit(dataset,
-                  epochs=1, steps_per_epoch=2, verbose=0)
-      with self.assertRaisesRegexp(ValueError,
                                    'you should not specify a target'):
         model.fit(iterator, iterator,
                   epochs=1, steps_per_epoch=2, verbose=0)
@@ -1829,5 +1825,98 @@ class TestTrainingWithDatasetIterators(test.TestCase):
             'dataset iterator ran out of data')
 
 
+class TestTrainingWithDataset(test.TestCase):
+
+  def test_calling_model_on_same_dataset(self):
+    with self.test_session():
+      x = keras.layers.Input(shape=(3,), name='input')
+      y = keras.layers.Dense(4, name='dense')(x)
+      model = keras.Model(x, y)
+
+      optimizer = RMSPropOptimizer(learning_rate=0.001)
+      loss = 'mse'
+      metrics = ['mae']
+      model.compile(optimizer, loss, metrics=metrics)
+
+      inputs = np.zeros((10, 3), dtype=np.float32)
+      targets = np.zeros((10, 4), dtype=np.float32)
+      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+      dataset = dataset.repeat(100)
+      dataset = dataset.batch(10)
+
+      # Call fit with validation data
+      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+                validation_data=dataset, validation_steps=2)
+      # Finalize the graph to make sure new ops aren't added when calling on the
+      # same dataset
+      ops.get_default_graph().finalize()
+      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+                validation_data=dataset, validation_steps=2)
+
+  @tf_test_util.run_in_graph_and_eager_modes()
+  def test_training_and_eval_methods_on_dataset(self):
+    with self.test_session():
+      x = keras.layers.Input(shape=(3,), name='input')
+      y = keras.layers.Dense(4, name='dense')(x)
+      model = keras.Model(x, y)
+
+      optimizer = RMSPropOptimizer(learning_rate=0.001)
+      loss = 'mse'
+      metrics = ['mae']
+      model.compile(optimizer, loss, metrics=metrics)
+
+      inputs = np.zeros((10, 3), dtype=np.float32)
+      targets = np.zeros((10, 4), dtype=np.float32)
+      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+      dataset = dataset.repeat(100)
+      dataset = dataset.batch(10)
+
+      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
+      model.evaluate(dataset, steps=2, verbose=1)
+      model.predict(dataset, steps=2)
+      model.train_on_batch(dataset)
+      model.predict_on_batch(dataset)
+
+      # Test with validation data
+      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+                validation_data=dataset, validation_steps=2)
+
+      # Test with validation split
+      with self.assertRaisesRegexp(
+          ValueError, '`validation_split` argument is not supported '
+          'when input `x` is a dataset or a dataset iterator'):
+        model.fit(dataset,
+                  epochs=1, steps_per_epoch=2, verbose=0,
+                  validation_split=0.5, validation_steps=2)
+
+      # Test with sample weight.
+      sample_weight = np.random.random((10,))
+      with self.assertRaisesRegexp(
+          ValueError, '`sample_weight` argument is not supported '
+          'when input `x` is a dataset or a dataset iterator'):
+        model.fit(
+            dataset,
+            epochs=1,
+            steps_per_epoch=2,
+            verbose=0,
+            sample_weight=sample_weight)
+
+      # Test invalid usage
+      with self.assertRaisesRegexp(ValueError,
+                                   'you should not specify a target'):
+        model.fit(dataset, dataset,
+                  epochs=1, steps_per_epoch=2, verbose=0)
+
+      with self.assertRaisesRegexp(
+          ValueError, 'you should specify the `steps_per_epoch` argument'):
+        model.fit(dataset, epochs=1, verbose=0)
+      with self.assertRaisesRegexp(ValueError,
+                                   'you should specify the `steps` argument'):
+        model.evaluate(dataset, verbose=0)
+      with self.assertRaisesRegexp(ValueError,
+                                   'you should specify the `steps` argument'):
+        model.predict(dataset, verbose=0)
+
+
 if __name__ == '__main__':
   test.main()
index 7d214d6..c53948b 100644 (file)
@@ -632,19 +632,20 @@ def validate_iterator_input(x, y, sample_weight, validation_split=None):
         provided by user.
   """
   if y is not None:
-    raise ValueError('You passed a dataset iterator (%s) as input `x` to '
-                     'your model. In that case, you should not specify '
-                     'a target (`y`) argument, since the dataset iterator '
-                     'generates both input data and target data. '
+    raise ValueError('You passed a dataset or dataset iterator (%s) as '
+                     'input `x` to your model. In that case, you should '
+                     'not specify a target (`y`) argument, since the dataset '
+                     'or dataset iterator generates both input data and '
+                     'target data. '
                      'Received: %s' % (x, y))
   if sample_weight is not None:
-    raise ValueError('`sample_weight` argument is not supported when input'
-                     ' `x` is a dataset iterator. '
+    raise ValueError('`sample_weight` argument is not supported when input '
+                     '`x` is a dataset or a dataset iterator. '
                      'Received: x=%s, sample_weight=%s' % (x, sample_weight))
   if validation_split is not None and validation_split != 0.0:
     raise ValueError(
         '`validation_split` argument is not supported when '
-        'input `x` is a dataset iterator. '
+        'input `x` is a dataset or a dataset iterator. '
         'Received: x=%s, validation_split=%f' % (x, validation_split))