Keras: Supply `maximum_iterations` to the TF backend when possible.
author    Russell Power <power@google.com>
          Sun, 29 Apr 2018 22:37:12 +0000 (15:37 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Sun, 29 Apr 2018 22:39:50 +0000 (15:39 -0700)
PiperOrigin-RevId: 194723199
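
The TF backend lowers symbolic RNN loops to `tf.while_loop`, which accepts a
`maximum_iterations` argument; when the trip count has a static upper bound,
XLA can compile the loop for TPU execution. A minimal sketch of the argument
this change threads through from Keras (the loop body below is a hypothetical
stand-in, not code from this change):

    import tensorflow as tf

    n = 8  # hypothetical static bound on the trip count
    # With `maximum_iterations` set, the loop carries a static upper
    # bound on its trip count, which XLA needs for TPU compilation.
    i, acc = tf.while_loop(
        cond=lambda i, acc: i < n,
        body=lambda i, acc: (i + 1, acc + i),
        loop_vars=(tf.constant(0), tf.constant(0)),
        maximum_iterations=n)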

tensorflow/contrib/tpu/python/tpu/keras_support.py
tensorflow/python/keras/_impl/keras/backend.py
tensorflow/python/keras/_impl/keras/layers/wrappers.py

diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index e86ca0a..b1d8d38 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -66,7 +66,6 @@ from tensorflow.python.keras._impl.keras.layers import embeddings
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import training_util
 
 
 class TPUEmbedding(embeddings.Embedding):
@@ -126,7 +125,9 @@ class TPUFunction(object):
     """Specialize `self.model` (a Keras model) for the given input shapes."""
     # Re-create our input and output layers inside our subgraph.  They will be
     # attached to the true computation when we clone our model in `tpu_fn`.
-    K.set_learning_phase(self.execution_mode == model_fn_lib.ModeKeys.TRAIN)
+    K.set_learning_phase(
+        self.execution_mode == model_fn_lib.ModeKeys.TRAIN
+    )
 
     # functools.partial and callable objects are not supported by tpu.rewrite
     def _model_fn():
@@ -161,9 +162,6 @@ class TPUFunction(object):
         if layer in self.model._output_layers:
           tpu_targets.append(tensor)
 
-      optimizer = self.model.optimizer
-      optimizer.iterations = training_util.get_or_create_global_step()
-
       # Call our model with our infeed inputs (re-using the weights).
       model_outputs = self.model(tpu_inputs)
       child_model = models.Model(inputs=tpu_inputs, outputs=model_outputs)
@@ -219,8 +217,6 @@ class TPUFunction(object):
 
     tpu_execute_op = tpu.rewrite(_model_fn)
 
-    K._initialize_variables(K.get_session())  # pylint-disable: protected-access
-
     # Generate CPU side operations to enqueue features/labels and dequeue
     # outputs from the model call.
     with ops.device('/device:TPU:0'):
@@ -296,7 +292,6 @@ def setup_tpu_session(master):
       target=master, config=config_pb2.ConfigProto(isolate_session_state=True))
   K.set_session(session)
   K.get_session().run(tpu.initialize_system())
-  K.manual_variable_initialization(True)
   return session
 
 
@@ -357,10 +352,6 @@ class KerasTPUModel(models.Model):
       raise ValueError(
           'Optimizer must be a TFOptimizer, got: %s' % self.optimizer)
 
-  def train_on_batch(self, x, y, sample_weight=None, class_weight=None):
-    return super(KerasTPUModel, self).train_on_batch(x, y, sample_weight,
-                                                     class_weight)
-
   def _make_train_function(self):
     if not self.train_function:
       self.train_function = TPUFunction(self, model_fn_lib.ModeKeys.TRAIN)
@@ -378,14 +369,58 @@ class KerasTPUModel(models.Model):
     return self.predict_function
 
   def cpu_model(self):
-    return models.Model(
+    cpu_model = models.Model(
         inputs=self.inputs,
         outputs=self.outputs,
         name=self.name,
     )
 
+    if self.optimizer:
+      cpu_model.compile(
+          optimizer=self.optimizer,
+          loss=self.loss,
+          metrics=self.metrics,
+          loss_weights=self.loss_weights,
+      )
+
+    return cpu_model
+
+
+def _validate_shapes(model):
+  """Validate that all layers in `model` have constant shape."""
+  for layer in model.layers:
+    if isinstance(layer.input_shape, tuple):
+      input_shapes = [layer.input_shape]
+    else:
+      input_shapes = layer.input_shape
+
+    if isinstance(layer.output_shape, tuple):
+      output_shapes = [layer.output_shape]
+    else:
+      output_shapes = layer.output_shape
+
+    for shape in input_shapes + output_shapes:
+      for dim in shape[1:]:
+        if dim is None:
+          raise ValueError(
+              """
+Layer %(layer)s has a variable shape in a non-batch dimension.  TPU models must
+have constant shapes for all operations.
+
+You may have to specify `input_length` for RNN/TimeDistributed layers.
+
+Layer: %(layer)s
+Input shape: %(input_shape)s
+Output shape: %(output_shape)s
+  """ % {
+      'layer': layer,
+      'input_shape': layer.input_shape,
+      'output_shape': layer.output_shape
+      })
+
 
 @experimental
 def tpu_model(model):
+  _validate_shapes(model)
   return KerasTPUModel(
       inputs=model.inputs, outputs=model.outputs, name=model.name)
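
A hypothetical usage sketch of the validation added above (the `tf.keras`
layer names and shapes are assumptions for illustration, not part of this
change): a model with a variable time dimension now fails fast in
`tpu_model`, while a constant-shape model passes.

    import tensorflow as tf

    # Variable time dimension: _validate_shapes raises ValueError,
    # since TPU models need constant shapes outside the batch dim.
    inp = tf.keras.layers.Input(shape=(None, 32))
    bad = tf.keras.models.Model(inp, tf.keras.layers.LSTM(16)(inp))
    # tpu_model(bad)   -> ValueError

    # Constant time dimension: validation passes.
    inp = tf.keras.layers.Input(shape=(10, 32))
    good = tf.keras.models.Model(inp, tf.keras.layers.LSTM(16)(inp))
    # tpu_model(good)  -> KerasTPUModel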
diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py
index 449410f..b1f1270 100644
--- a/tensorflow/python/keras/_impl/keras/backend.py
+++ b/tensorflow/python/keras/_impl/keras/backend.py
@@ -2998,7 +2998,7 @@ def rnn(step_function,
       constants: a list of constant values passed at each step.
       unroll: whether to unroll the RNN or to use a symbolic loop
           (`while_loop` or `scan` depending on backend).
-      input_length: Unused; exists for API compatibility.
+      input_length: If specified, assume the time dimension has this length.
 
   Returns:
       A tuple, `(last_output, outputs, new_states)`.
@@ -3016,7 +3016,6 @@ def rnn(step_function,
       ValueError: if `mask` is provided (not `None`) but states is not provided
           (`len(states)` == 0).
   """
-  del input_length
   ndim = len(inputs.get_shape())
   if ndim < 3:
     raise ValueError('Input should be at least 3D.')
@@ -3194,6 +3193,7 @@ def rnn(step_function,
         cond=lambda time, *_: time < time_steps,
         body=_step,
         loop_vars=(time, output_ta) + states,
+        maximum_iterations=input_length,
         parallel_iterations=32,
         swap_memory=True)
     last_time = final_outputs[0]
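
A minimal sketch of the new `rnn` code path (the identity step function is an
assumption for illustration): `input_length` is now forwarded to the
`while_loop` as `maximum_iterations`, bounding the symbolic loop.

    import tensorflow as tf
    from tensorflow.python.keras._impl.keras import backend as K

    x = tf.zeros((2, 10, 4))  # (batch, time, features)

    def step(inputs, states):
      # Identity step with no recurrent state, for illustration only.
      return inputs, []

    # input_length=10 becomes while_loop(..., maximum_iterations=10).
    last_output, outputs, new_states = K.rnn(
        step, x, initial_states=[], input_length=10)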
diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers.py b/tensorflow/python/keras/_impl/keras/layers/wrappers.py
index 34a8eee..91b8c11 100644
--- a/tensorflow/python/keras/_impl/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/_impl/keras/layers/wrappers.py
@@ -201,6 +201,7 @@ class TimeDistributed(Wrapper):
           step,
           inputs,
           initial_states=[],
+          input_length=input_shape[1],
           unroll=False)
       y = outputs
     else:
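
On the `TimeDistributed` side, a sketch under assumed shapes (not from this
change): with a known batch size the wrapper takes the `K.rnn` path, and the
static time dimension now bounds the underlying loop.

    import tensorflow as tf

    # batch_shape fixes the batch size, so TimeDistributed uses K.rnn;
    # the static time dimension (10) caps the while_loop trip count.
    inp = tf.keras.layers.Input(batch_shape=(32, 10, 4))
    out = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(8))(inp)
    model = tf.keras.models.Model(inp, out)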