Add shape validation for symbolic tensors passed to fit (only graph mode).
authorFrancois Chollet <fchollet@google.com>
Thu, 24 May 2018 18:11:42 +0000 (11:11 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 24 May 2018 18:14:25 +0000 (11:14 -0700)
PiperOrigin-RevId: 197921675

tensorflow/python/keras/engine/training.py
tensorflow/python/keras/engine/training_test.py
tensorflow/python/keras/engine/training_utils.py

index 0db805c..6d625f1 100644 (file)
@@ -846,7 +846,8 @@ class Model(Network):
     # in the case where all inputs are value arrays.
 
     if context.executing_eagerly():
-      # In eager mode, do not do shape validation.
+      # In eager mode, do not do shape validation
+      # since the network has no input nodes (placeholders) to be fed.
       feed_input_names = self.input_names
       feed_input_shapes = None
     elif not self._is_graph_network:
index 222e349..5c02d36 100644 (file)
@@ -1917,6 +1917,37 @@ class TestTrainingWithDataset(test.TestCase):
                                    'you should specify the `steps` argument'):
         model.predict(dataset, verbose=0)
 
+  def test_dataset_input_shape_validation(self):
+    with self.test_session():
+      x = keras.layers.Input(shape=(3,), name='input')
+      y = keras.layers.Dense(4, name='dense')(x)
+      model = keras.Model(x, y)
+
+      optimizer = RMSPropOptimizer(learning_rate=0.001)
+      loss = 'mse'
+      model.compile(optimizer, loss)
+
+      # User forgets to batch the dataset
+      inputs = np.zeros((10, 3), dtype=np.float32)
+      targets = np.zeros((10, 4), dtype=np.float32)
+      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+      dataset = dataset.repeat(100)
+
+      with self.assertRaisesRegexp(ValueError,
+                                   'expected input to have 2 dimensions'):
+        model.train_on_batch(dataset)
+
+      # Wrong input shape
+      inputs = np.zeros((10, 5), dtype=np.float32)
+      targets = np.zeros((10, 4), dtype=np.float32)
+      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+      dataset = dataset.repeat(100)
+      dataset = dataset.batch(10)
+
+      with self.assertRaisesRegexp(ValueError,
+                                   'expected input to have shape'):
+        model.train_on_batch(dataset)
+
 
 if __name__ == '__main__':
   test.main()
index c53948b..b93f999 100644 (file)
@@ -166,10 +166,16 @@ def standardize_input_data(data,
   # Check shapes compatibility.
   if shapes:
     for i in range(len(names)):
-      if shapes[i] is not None and not tensor_util.is_tensor(data[i]):
-        data_shape = data[i].shape
+      if shapes[i] is not None:
+        if tensor_util.is_tensor(data[i]):
+          tensorshape = data[i].get_shape()
+          if not tensorshape:
+            continue
+          data_shape = tuple(tensorshape.as_list())
+        else:
+          data_shape = data[i].shape
         shape = shapes[i]
-        if data[i].ndim != len(shape):
+        if len(data_shape) != len(shape):
           raise ValueError('Error when checking ' + exception_prefix +
                            ': expected ' + names[i] + ' to have ' +
                            str(len(shape)) + ' dimensions, but got array '
@@ -178,7 +184,7 @@ def standardize_input_data(data,
           data_shape = data_shape[1:]
           shape = shape[1:]
         for dim, ref_dim in zip(data_shape, shape):
-          if ref_dim != dim and ref_dim:
+          if ref_dim != dim and ref_dim is not None and dim is not None:
             raise ValueError(
                 'Error when checking ' + exception_prefix + ': expected ' +
                 names[i] + ' to have shape ' + str(shape) +