Fix bug in which the ConvLSTM2D layer could not be cloned.
authorFrancois Chollet <fchollet@google.com>
Wed, 9 May 2018 22:36:34 +0000 (15:36 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 9 May 2018 22:39:24 +0000 (15:39 -0700)
PiperOrigin-RevId: 196040413

tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent_test.py

index be25bbc..5e20042 100644 (file)
@@ -609,16 +609,25 @@ class ConvLSTM2DCell(Layer):
         name='recurrent_kernel',
         regularizer=self.recurrent_regularizer,
         constraint=self.recurrent_constraint)
+
     if self.use_bias:
-      self.bias = self.add_weight(shape=(self.filters * 4,),
-                                  initializer=self.bias_initializer,
-                                  name='bias',
-                                  regularizer=self.bias_regularizer,
-                                  constraint=self.bias_constraint)
       if self.unit_forget_bias:
-        bias_value = np.zeros((self.filters * 4,))
-        bias_value[self.filters: self.filters * 2] = 1.
-        K.set_value(self.bias, bias_value)
+
+        def bias_initializer(_, *args, **kwargs):
+          return K.concatenate([
+              self.bias_initializer((self.filters,), *args, **kwargs),
+              initializers.Ones()((self.filters,), *args, **kwargs),
+              self.bias_initializer((self.filters * 2,), *args, **kwargs),
+          ])
+      else:
+        bias_initializer = self.bias_initializer
+      self.bias = self.add_weight(
+          shape=(self.filters * 4,),
+          name='bias',
+          initializer=bias_initializer,
+          regularizer=self.bias_regularizer,
+          constraint=self.bias_constraint)
+
     else:
       self.bias = None
 
index 9e768b4..827a7ff 100644 (file)
@@ -180,6 +180,23 @@ class ConvLSTMTest(test.TestCase):
                   'recurrent_dropout': 0.1},
           input_shape=(1, 2, 5, 5, 2))
 
+  def test_conv_lstm_cloning(self):
+    with self.test_session():
+      model = keras.models.Sequential()
+      model.add(keras.layers.ConvLSTM2D(5, 3, input_shape=(None, 5, 5, 3)))
+
+      test_inputs = np.random.random((2, 4, 5, 5, 3))
+      reference_outputs = model.predict(test_inputs)
+      weights = model.get_weights()
+
+    # Use a new graph to clone the model
+    with self.test_session():
+      clone = keras.models.clone_model(model)
+      clone.set_weights(weights)
+
+      outputs = clone.predict(test_inputs)
+      self.assertAllClose(reference_outputs, outputs, atol=1e-5)
+
 
 if __name__ == '__main__':
   test.main()