Sync only the convolutional_recurrent file to Keras 2.1.5.
author: Anjali Sridhar <anjalisridhar@google.com>
Thu, 5 Apr 2018 17:41:40 +0000 (10:41 -0700)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Thu, 5 Apr 2018 17:44:33 +0000 (10:44 -0700)
PiperOrigin-RevId: 191763101

tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent_test.py
tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt

index b78962d..6b2a1d9 100644 (file)
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+# pylint: disable=protected-access
 """Convolutional-recurrent layers.
 """
 from __future__ import absolute_import
@@ -26,181 +27,456 @@ from tensorflow.python.keras._impl.keras import constraints
 from tensorflow.python.keras._impl.keras import initializers
 from tensorflow.python.keras._impl.keras import regularizers
 from tensorflow.python.keras._impl.keras.engine import InputSpec
+from tensorflow.python.keras._impl.keras.engine import Layer
 from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
-from tensorflow.python.keras._impl.keras.layers.recurrent import Recurrent
+from tensorflow.python.keras._impl.keras.layers.recurrent import _generate_dropout_mask
+from tensorflow.python.keras._impl.keras.layers.recurrent import RNN
 from tensorflow.python.keras._impl.keras.utils import conv_utils
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
+from tensorflow.python.keras._impl.keras.utils import generic_utils
 from tensorflow.python.util.tf_export import tf_export
 
 
-class ConvRecurrent2D(Recurrent):
-  """Abstract base class for convolutional recurrent layers.
-
-  Do not use in a model -- it's not a functional layer!
+class ConvRNN2D(RNN):
+  """Base class for convolutional-recurrent layers.
 
   Arguments:
-      filters: Integer, the dimensionality of the output space
-          (i.e. the number of output filters in the convolution).
-      kernel_size: An integer or tuple/list of n integers, specifying the
-          dimensions of the convolution window.
-      strides: An integer or tuple/list of n integers,
-          specifying the strides of the convolution.
-          Specifying any stride value != 1 is incompatible with specifying
-          any `dilation_rate` value != 1.
-      padding: One of `"valid"` or `"same"` (case-insensitive).
-      data_format: A string,
-          one of `channels_last` (default) or `channels_first`.
-          The ordering of the dimensions in the inputs.
-          `channels_last` corresponds to inputs with shape
-          `(batch, time, ..., channels)`
-          while `channels_first` corresponds to
-          inputs with shape `(batch, time, channels, ...)`.
-          It defaults to the `image_data_format` value found in your
-          Keras config file at `~/.keras/keras.json`.
-          If you never set it, then it will be "channels_last".
-      dilation_rate: An integer or tuple/list of n integers, specifying
-          the dilation rate to use for dilated convolution.
-          Currently, specifying any `dilation_rate` value != 1 is
-          incompatible with specifying any `strides` value != 1.
-      return_sequences: Boolean. Whether to return the last output
-          in the output sequence, or the full sequence.
-      go_backwards: Boolean (default False).
-          If True, rocess the input sequence backwards.
-      stateful: Boolean (default False). If True, the last state
-          for each sample at index i in a batch will be used as initial
-          state for the sample of index i in the following batch.
+    cell: A RNN cell instance. A RNN cell is a class that has:
+        - a `call(input_at_t, states_at_t)` method, returning
+            `(output_at_t, states_at_t_plus_1)`. The call method of the
+            cell can also take the optional argument `constants`, see
+            section "Note on passing external constants" below.
+        - a `state_size` attribute. This can be a single integer
+            (single state) in which case it is
+            the number of channels of the recurrent state
+            (which should be the same as the number of channels of the cell
+            output). This can also be a list/tuple of integers
+            (one size per state). In this case, the first entry
+            (`state_size[0]`) should be the same as
+            the size of the cell output.
+    return_sequences: Boolean. Whether to return the last output
+        in the output sequence, or the full sequence.
+    return_state: Boolean. Whether to return the last state
+        in addition to the output.
+    go_backwards: Boolean (default False).
+        If True, process the input sequence backwards and return the
+        reversed sequence.
+    stateful: Boolean (default False). If True, the last state
+        for each sample at index i in a batch will be used as initial
+        state for the sample of index i in the following batch.
+    input_shape: Use this argument to specify the shape of the
+        input when this layer is the first one in a model.
 
   Input shape:
-      5D tensor with shape `(num_samples, timesteps, channels, rows, cols)`.
+    5D tensor with shape:
+    `(samples, timesteps, channels, rows, cols)`
+    if data_format='channels_first' or 5D tensor with shape:
+    `(samples, timesteps, rows, cols, channels)`
+    if data_format='channels_last'.
 
   Output shape:
-      - if `return_sequences`: 5D tensor with shape
-          `(num_samples, timesteps, channels, rows, cols)`.
-      - else, 4D tensor with shape `(num_samples, channels, rows, cols)`.
-
-  # Masking
-      This layer supports masking for input data with a variable number
-      of timesteps. To introduce masks to your data,
-      use an `Embedding` layer with the `mask_zero` parameter
-      set to `True`.
-      **Note:** for the time being, masking is only supported with Theano.
-
-  # Note on using statefulness in RNNs
-      You can set RNN layers to be 'stateful', which means that the states
-      computed for the samples in one batch will be reused as initial states
-      for the samples in the next batch.
-      This assumes a one-to-one mapping between
-      samples in different successive batches.
-
-      To enable statefulness:
-          - specify `stateful=True` in the layer constructor.
-          - specify a fixed batch size for your model, by passing
-              a `batch_input_size=(...)` to the first layer in your model.
-              This is the expected shape of your inputs *including the batch
-              size*.
-              It should be a tuple of integers, e.g. `(32, 10, 100)`.
-
-      To reset the states of your model, call `.reset_states()` on either
-      a specific layer, or on your entire model.
+    - if `return_state`: a list of tensors. The first tensor is
+        the output. The remaining tensors are the last states,
+        each 4D tensor with shape:
+        `(samples, filters, new_rows, new_cols)`
+        if data_format='channels_first'
+        or 4D tensor with shape:
+        `(samples, new_rows, new_cols, filters)`
+        if data_format='channels_last'.
+        `rows` and `cols` values might have changed due to padding.
+    - if `return_sequences`: 5D tensor with shape:
+        `(samples, timesteps, filters, new_rows, new_cols)`
+        if data_format='channels_first'
+        or 5D tensor with shape:
+        `(samples, timesteps, new_rows, new_cols, filters)`
+        if data_format='channels_last'.
+    - else, 4D tensor with shape:
+        `(samples, filters, new_rows, new_cols)`
+        if data_format='channels_first'
+        or 4D tensor with shape:
+        `(samples, new_rows, new_cols, filters)`
+        if data_format='channels_last'.
+
+  Masking:
+    This layer supports masking for input data with a variable number
+    of timesteps. To introduce masks to your data,
+    use an Embedding layer with the `mask_zero` parameter
+    set to `True`.
+
+  Note on using statefulness in RNNs:
+    You can set RNN layers to be 'stateful', which means that the states
+    computed for the samples in one batch will be reused as initial states
+    for the samples in the next batch. This assumes a one-to-one mapping
+    between samples in different successive batches.
+    To enable statefulness:
+        - specify `stateful=True` in the layer constructor.
+        - specify a fixed batch size for your model, by passing
+             - if sequential model:
+                `batch_input_shape=(...)` to the first layer in your model.
+             - if functional model with 1 or more Input layers:
+                `batch_shape=(...)` to all the first layers in your model.
+                This is the expected shape of your inputs
+                *including the batch size*.
+                It should be a tuple of integers,
+                e.g. `(32, 10, 100, 100, 32)`.
+                Note that the number of rows and columns should be specified
+                too.
+        - specify `shuffle=False` when calling fit().
+    To reset the states of your model, call `.reset_states()` on either
+    a specific layer, or on your entire model.
+
+  Note on specifying the initial state of RNNs:
+    You can specify the initial state of RNN layers symbolically by
+    calling them with the keyword argument `initial_state`. The value of
+    `initial_state` should be a tensor or list of tensors representing
+    the initial state of the RNN layer.
+    You can specify the initial state of RNN layers numerically by
+    calling `reset_states` with the keyword argument `states`. The value of
+    `states` should be a numpy array or list of numpy arrays representing
+    the initial state of the RNN layer.
+
+  Note on passing external constants to RNNs:
+    You can pass "external" constants to the cell using the `constants`
+    keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This
+    requires that the `cell.call` method accepts the same keyword argument
+    `constants`. Such constants can be used to condition the cell
+    transformation on additional static inputs (not changing over time),
+    a.k.a. an attention mechanism.
   """
 
   def __init__(self,
-               filters,
-               kernel_size,
-               strides=(1, 1),
-               padding='valid',
-               data_format=None,
-               dilation_rate=(1, 1),
+               cell,
                return_sequences=False,
+               return_state=False,
                go_backwards=False,
                stateful=False,
+               unroll=False,
                **kwargs):
-    super(ConvRecurrent2D, self).__init__(**kwargs)
-    self.filters = filters
-    self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
-    self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
-    self.padding = conv_utils.normalize_padding(padding)
-    self.data_format = conv_utils.normalize_data_format(data_format)
-    self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2,
-                                                    'dilation_rate')
-    self.return_sequences = return_sequences
-    self.go_backwards = go_backwards
-    self.stateful = stateful
+    if unroll:
+      raise TypeError('Unrolling isn\'t possible with '
+                      'convolutional RNNs.')
+    if isinstance(cell, (list, tuple)):
+      # The StackedConvRNN2DCells isn't implemented yet.
+      raise TypeError('It is not possible at the moment to '
+                      'stack convolutional cells.')
+    super(ConvRNN2D, self).__init__(cell,
+                                    return_sequences,
+                                    return_state,
+                                    go_backwards,
+                                    stateful,
+                                    unroll,
+                                    **kwargs)
     self.input_spec = [InputSpec(ndim=5)]
-    self.state_spec = None
+    self.states = None
 
   @shape_type_conversion
   def compute_output_shape(self, input_shape):
     if isinstance(input_shape, list):
       input_shape = input_shape[0]
-    if self.data_format == 'channels_first':
+
+    cell = self.cell
+    if cell.data_format == 'channels_first':
       rows = input_shape[3]
       cols = input_shape[4]
-    elif self.data_format == 'channels_last':
+    elif cell.data_format == 'channels_last':
       rows = input_shape[2]
       cols = input_shape[3]
-    rows = conv_utils.conv_output_length(
-        rows,
-        self.kernel_size[0],
-        padding=self.padding,
-        stride=self.strides[0],
-        dilation=self.dilation_rate[0])
-    cols = conv_utils.conv_output_length(
-        cols,
-        self.kernel_size[1],
-        padding=self.padding,
-        stride=self.strides[1],
-        dilation=self.dilation_rate[1])
+    rows = conv_utils.conv_output_length(rows,
+                                         cell.kernel_size[0],
+                                         padding=cell.padding,
+                                         stride=cell.strides[0],
+                                         dilation=cell.dilation_rate[0])
+    cols = conv_utils.conv_output_length(cols,
+                                         cell.kernel_size[1],
+                                         padding=cell.padding,
+                                         stride=cell.strides[1],
+                                         dilation=cell.dilation_rate[1])
+
+    if cell.data_format == 'channels_first':
+      output_shape = input_shape[:2] + (cell.filters, rows, cols)
+    elif cell.data_format == 'channels_last':
+      output_shape = input_shape[:2] + (rows, cols, cell.filters)
+
+    if not self.return_sequences:
+      output_shape = output_shape[:1] + output_shape[2:]
+
+    if self.return_state:
+      output_shape = [output_shape]
+      if cell.data_format == 'channels_first':
+        output_shape += [(input_shape[0], cell.filters, rows, cols)
+                         for _ in range(2)]
+      elif cell.data_format == 'channels_last':
+        output_shape += [(input_shape[0], rows, cols, cell.filters)
+                         for _ in range(2)]
+    return output_shape
+
+  @shape_type_conversion
+  def build(self, input_shape):
+    # Note input_shape will be list of shapes of initial states and
+    # constants if these are passed in __call__.
+    if self._num_constants is not None:
+      constants_shape = input_shape[-self._num_constants:]
+    else:
+      constants_shape = None
+
+    if isinstance(input_shape, list):
+      input_shape = input_shape[0]
+
+    batch_size = input_shape[0] if self.stateful else None
+    self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_shape[2:5])
+
+    # allow cell (if layer) to build before we set or validate state_spec
+    if isinstance(self.cell, Layer):
+      step_input_shape = (input_shape[0],) + input_shape[2:]
+      if constants_shape is not None:
+        self.cell.build([step_input_shape] + constants_shape)
+      else:
+        self.cell.build(step_input_shape)
+
+    # set or validate state_spec
+    if hasattr(self.cell.state_size, '__len__'):
+      state_size = list(self.cell.state_size)
+    else:
+      state_size = [self.cell.state_size]
+
+    if self.state_spec is not None:
+      # initial_state was passed in call, check compatibility
+      if self.cell.data_format == 'channels_first':
+        ch_dim = 1
+      elif self.cell.data_format == 'channels_last':
+        ch_dim = 3
+      if [spec.shape[ch_dim] for spec in self.state_spec] != state_size:
+        raise ValueError(
+            'An initial_state was passed that is not compatible with '
+            '`cell.state_size`. Received `state_spec`={}; '
+            'However `cell.state_size` is '
+            '{}'.format([spec.shape for spec in self.state_spec],
+                        self.cell.state_size))
+    else:
+      if self.cell.data_format == 'channels_first':
+        self.state_spec = [InputSpec(shape=(None, dim, None, None))
+                           for dim in state_size]
+      elif self.cell.data_format == 'channels_last':
+        self.state_spec = [InputSpec(shape=(None, None, None, dim))
+                           for dim in state_size]
+    if self.stateful:
+      self.reset_states()
+    self.built = True
+
+  def get_initial_state(self, inputs):
+    # (samples, timesteps, rows, cols, filters)
+    initial_state = K.zeros_like(inputs)
+    # (samples, rows, cols, filters)
+    initial_state = K.sum(initial_state, axis=1)
+    shape = list(self.cell.kernel_shape)
+    shape[-1] = self.cell.filters
+    initial_state = self.cell.input_conv(initial_state,
+                                         K.zeros(tuple(shape)),
+                                         padding=self.cell.padding)
+
+    if hasattr(self.cell.state_size, '__len__'):
+      return [initial_state for _ in self.cell.state_size]
+    else:
+      return [initial_state]
+
+  def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
+    inputs, initial_state, constants = self._standardize_args(
+        inputs, initial_state, constants)
+
+    if initial_state is None and constants is None:
+      return super(ConvRNN2D, self).__call__(inputs, **kwargs)
+
+    # If any of `initial_state` or `constants` are specified and are Keras
+    # tensors, then add them to the inputs and temporarily modify the
+    # input_spec to include them.
+
+    additional_inputs = []
+    additional_specs = []
+    if initial_state is not None:
+      kwargs['initial_state'] = initial_state
+      additional_inputs += initial_state
+      self.state_spec = []
+      for state in initial_state:
+        shape = K.int_shape(state)
+        self.state_spec.append(InputSpec(shape=shape))
+
+      additional_specs += self.state_spec
+    if constants is not None:
+      kwargs['constants'] = constants
+      additional_inputs += constants
+      self.constants_spec = [InputSpec(shape=K.int_shape(constant))
+                             for constant in constants]
+      self._num_constants = len(constants)
+      additional_specs += self.constants_spec
+    # at this point additional_inputs cannot be empty
+    for tensor in additional_inputs:
+      if K.is_keras_tensor(tensor) != K.is_keras_tensor(additional_inputs[0]):
+        raise ValueError('The initial state or constants of an RNN'
+                         ' layer cannot be specified with a mix of'
+                         ' Keras tensors and non-Keras tensors')
+
+    if K.is_keras_tensor(additional_inputs[0]):
+      # Compute the full input spec, including state and constants
+      full_input = [inputs] + additional_inputs
+      full_input_spec = self.input_spec + additional_specs
+      # Perform the call with temporarily replaced input_spec
+      original_input_spec = self.input_spec
+      self.input_spec = full_input_spec
+      output = super(ConvRNN2D, self).__call__(full_input, **kwargs)
+      self.input_spec = original_input_spec
+      return output
+    else:
+      return super(ConvRNN2D, self).__call__(inputs, **kwargs)
+
+  def call(self,
+           inputs,
+           mask=None,
+           training=None,
+           initial_state=None,
+           constants=None):
+    # note that the .build() method of subclasses MUST define
+    # self.input_spec and self.state_spec with complete input shapes.
+    if isinstance(inputs, list):
+      inputs = inputs[0]
+    if initial_state is not None:
+      pass
+    elif self.stateful:
+      initial_state = self.states
+    else:
+      initial_state = self.get_initial_state(inputs)
+
+    if isinstance(mask, list):
+      mask = mask[0]
+
+    if len(initial_state) != len(self.states):
+      raise ValueError('Layer has ' + str(len(self.states)) +
+                       ' states but was passed ' +
+                       str(len(initial_state)) +
+                       ' initial states.')
+    timesteps = K.int_shape(inputs)[1]
+
+    kwargs = {}
+    if generic_utils.has_arg(self.cell.call, 'training'):
+      kwargs['training'] = training
+
+    if constants:
+      if not generic_utils.has_arg(self.cell.call, 'constants'):
+        raise ValueError('RNN cell does not support constants')
+
+      def step(inputs, states):
+        constants = states[-self._num_constants:]
+        states = states[:-self._num_constants]
+        return self.cell.call(inputs, states, constants=constants,
+                              **kwargs)
+    else:
+      def step(inputs, states):
+        return self.cell.call(inputs, states, **kwargs)
+
+    last_output, outputs, states = K.rnn(step,
+                                         inputs,
+                                         initial_state,
+                                         constants=constants,
+                                         go_backwards=self.go_backwards,
+                                         mask=mask,
+                                         input_length=timesteps)
+    if self.stateful:
+      updates = []
+      for i in range(len(states)):
+        updates.append(K.update(self.states[i], states[i]))
+      self.add_update(updates, inputs=True)
+
     if self.return_sequences:
-      if self.data_format == 'channels_first':
-        output_shape = (input_shape[0], input_shape[1], self.filters, rows,
-                        cols)
-      elif self.data_format == 'channels_last':
-        output_shape = (input_shape[0], input_shape[1], rows, cols,
-                        self.filters)
+      output = outputs
     else:
-      if self.data_format == 'channels_first':
-        output_shape = (input_shape[0], self.filters, rows, cols)
-      elif self.data_format == 'channels_last':
-        output_shape = (input_shape[0], rows, cols, self.filters)
+      output = last_output
+
+    # Properly set learning phase
+    if getattr(last_output, '_uses_learning_phase', False):
+      output._uses_learning_phase = True
 
     if self.return_state:
-      if self.data_format == 'channels_first':
-        output_shape = [output_shape] + [
-            (input_shape[0], self.filters, rows, cols) for _ in range(2)
-        ]
-      elif self.data_format == 'channels_last':
-        output_shape = [output_shape] + [
-            (input_shape[0], rows, cols, self.filters) for _ in range(2)
-        ]
+      if not isinstance(states, (list, tuple)):
+        states = [states]
+      else:
+        states = list(states)
+      return [output] + states
+    else:
+      return output
 
-    return output_shape
+  def reset_states(self, states=None):
+    if not self.stateful:
+      raise AttributeError('Layer must be stateful.')
+    input_shape = self.input_spec[0].shape
+    state_shape = self.compute_output_shape(input_shape)
+    if self.return_state:
+      state_shape = state_shape[0]
+    if self.return_sequences:
+      state_shape = state_shape[:1].concatenate(state_shape[2:])
+    if None in state_shape:
+      raise ValueError('If a RNN is stateful, it needs to know '
+                       'its batch size. Specify the batch size '
+                       'of your input tensors: \n'
+                       '- If using a Sequential model, '
+                       'specify the batch size by passing '
+                       'a `batch_input_shape` '
+                       'argument to your first layer.\n'
+                       '- If using the functional API, specify '
+                       'the time dimension by passing a '
+                       '`batch_shape` argument to your Input layer.\n'
+                       'The same thing goes for the number of rows and '
+                       'columns.')
 
-  def get_config(self):
-    config = {
-        'filters': self.filters,
-        'kernel_size': self.kernel_size,
-        'strides': self.strides,
-        'padding': self.padding,
-        'data_format': self.data_format,
-        'dilation_rate': self.dilation_rate,
-        'return_sequences': self.return_sequences,
-        'go_backwards': self.go_backwards,
-        'stateful': self.stateful
-    }
-    base_config = super(ConvRecurrent2D, self).get_config()
-    return dict(list(base_config.items()) + list(config.items()))
+    # helper function
+    def get_tuple_shape(nb_channels):
+      result = list(state_shape)
+      if self.cell.data_format == 'channels_first':
+        result[1] = nb_channels
+      elif self.cell.data_format == 'channels_last':
+        result[3] = nb_channels
+      else:
+        raise KeyError
+      return tuple(result)
 
+    # initialize state if None
+    if self.states[0] is None:
+      if hasattr(self.cell.state_size, '__len__'):
+        self.states = [K.zeros(get_tuple_shape(dim))
+                       for dim in self.cell.state_size]
+      else:
+        self.states = [K.zeros(get_tuple_shape(self.cell.state_size))]
+    elif states is None:
+      if hasattr(self.cell.state_size, '__len__'):
+        for state, dim in zip(self.states, self.cell.state_size):
+          K.set_value(state, np.zeros(get_tuple_shape(dim)))
+      else:
+        K.set_value(self.states[0],
+                    np.zeros(get_tuple_shape(self.cell.state_size)))
+    else:
+      if not isinstance(states, (list, tuple)):
+        states = [states]
+      if len(states) != len(self.states):
+        raise ValueError('Layer ' + self.name + ' expects ' +
+                         str(len(self.states)) + ' states, ' +
+                         'but it received ' + str(len(states)) +
+                         ' state values. Input received: ' + str(states))
+      for index, (value, state) in enumerate(zip(states, self.states)):
+        if hasattr(self.cell.state_size, '__len__'):
+          dim = self.cell.state_size[index]
+        else:
+          dim = self.cell.state_size
+        if value.shape != get_tuple_shape(dim):
+          raise ValueError('State ' + str(index) +
+                           ' is incompatible with layer ' +
+                           self.name + ': expected shape=' +
+                           str(get_tuple_shape(dim)) +
+                           ', found shape=' + str(value.shape))
+        # TODO(anjalisridhar): consider batch calls to `set_value`.
+        K.set_value(state, value)
 
-@tf_export('keras.layers.ConvLSTM2D')
-class ConvLSTM2D(ConvRecurrent2D):
-  """Convolutional LSTM.
 
-  It is similar to an LSTM layer, but the input transformations
-  and recurrent transformations are both convolutional.
+class ConvLSTM2DCell(Layer):
+  """Cell class for the ConvLSTM2D layer.
 
-  Arguments:
+  # Arguments
       filters: Integer, the dimensionality of the output space
           (i.e. the number of output filters in the convolution).
       kernel_size: An integer or tuple/list of n integers, specifying the
@@ -212,11 +488,6 @@ class ConvLSTM2D(ConvRecurrent2D):
       padding: One of `"valid"` or `"same"` (case-insensitive).
       data_format: A string,
           one of `channels_last` (default) or `channels_first`.
-          The ordering of the dimensions in the inputs.
-          `channels_last` corresponds to inputs with shape
-          `(batch, time, ..., channels)`
-          while `channels_first` corresponds to
-          inputs with shape `(batch, time, channels, ...)`.
           It defaults to the `image_data_format` value found in your
           Keras config file at `~/.keras/keras.json`.
           If you never set it, then it will be "channels_last".
@@ -231,71 +502,32 @@ class ConvLSTM2D(ConvRecurrent2D):
           for the recurrent step.
       use_bias: Boolean, whether the layer uses a bias vector.
       kernel_initializer: Initializer for the `kernel` weights matrix,
-          used for the linear transformation of the inputs..
+          used for the linear transformation of the inputs.
       recurrent_initializer: Initializer for the `recurrent_kernel`
           weights matrix,
-          used for the linear transformation of the recurrent state..
+          used for the linear transformation of the recurrent state.
       bias_initializer: Initializer for the bias vector.
       unit_forget_bias: Boolean.
           If True, add 1 to the bias of the forget gate at initialization.
           Use in combination with `bias_initializer="zeros"`.
-          This is recommended in [Jozefowicz et
-            al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
+          This is recommended in [Jozefowicz et al.]
+          (http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
       kernel_regularizer: Regularizer function applied to
           the `kernel` weights matrix.
       recurrent_regularizer: Regularizer function applied to
           the `recurrent_kernel` weights matrix.
       bias_regularizer: Regularizer function applied to the bias vector.
-      activity_regularizer: Regularizer function applied to
-          the output of the layer (its "activation")..
       kernel_constraint: Constraint function applied to
           the `kernel` weights matrix.
       recurrent_constraint: Constraint function applied to
           the `recurrent_kernel` weights matrix.
       bias_constraint: Constraint function applied to the bias vector.
-      return_sequences: Boolean. Whether to return the last output
-          in the output sequence, or the full sequence.
-      go_backwards: Boolean (default False).
-          If True, rocess the input sequence backwards.
-      stateful: Boolean (default False). If True, the last state
-          for each sample at index i in a batch will be used as initial
-          state for the sample of index i in the following batch.
       dropout: Float between 0 and 1.
           Fraction of the units to drop for
           the linear transformation of the inputs.
       recurrent_dropout: Float between 0 and 1.
           Fraction of the units to drop for
           the linear transformation of the recurrent state.
-
-  Input shape:
-      - if data_format='channels_first'
-          5D tensor with shape:
-          `(samples,time, channels, rows, cols)`
-      - if data_format='channels_last'
-          5D tensor with shape:
-          `(samples,time, rows, cols, channels)`
-
-   Output shape:
-      - if `return_sequences`
-           - if data_format='channels_first'
-              5D tensor with shape:
-              `(samples, time, filters, output_row, output_col)`
-           - if data_format='channels_last'
-              5D tensor with shape:
-              `(samples, time, output_row, output_col, filters)`
-      - else
-          - if data_format ='channels_first'
-              4D tensor with shape:
-              `(samples, filters, output_row, output_col)`
-          - if data_format='channels_last'
-              4D tensor with shape:
-              `(samples, output_row, output_col, filters)`
-          where o_row and o_col depend on the shape of the filter and
-          the padding
-
-  Raises:
-      ValueError: in case of invalid constructor arguments.
-
   """
 
   def __init__(self,
@@ -315,27 +547,20 @@ class ConvLSTM2D(ConvRecurrent2D):
                kernel_regularizer=None,
                recurrent_regularizer=None,
                bias_regularizer=None,
-               activity_regularizer=None,
                kernel_constraint=None,
                recurrent_constraint=None,
                bias_constraint=None,
-               return_sequences=False,
-               go_backwards=False,
-               stateful=False,
                dropout=0.,
                recurrent_dropout=0.,
                **kwargs):
-    super(ConvLSTM2D, self).__init__(
-        filters,
-        kernel_size,
-        strides=strides,
-        padding=padding,
-        data_format=data_format,
-        dilation_rate=dilation_rate,
-        return_sequences=return_sequences,
-        go_backwards=go_backwards,
-        stateful=stateful,
-        **kwargs)
+    super(ConvLSTM2DCell, self).__init__(**kwargs)
+    self.filters = filters
+    self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
+    self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
+    self.padding = conv_utils.normalize_padding(padding)
+    self.data_format = conv_utils.normalize_data_format(data_format)
+    self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2,
+                                                    'dilation_rate')
     self.activation = activations.get(activation)
     self.recurrent_activation = activations.get(recurrent_activation)
     self.use_bias = use_bias
@@ -348,7 +573,6 @@ class ConvLSTM2D(ConvRecurrent2D):
     self.kernel_regularizer = regularizers.get(kernel_regularizer)
     self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
     self.bias_regularizer = regularizers.get(bias_regularizer)
-    self.activity_regularizer = regularizers.get(activity_regularizer)
 
     self.kernel_constraint = constraints.get(kernel_constraint)
     self.recurrent_constraint = constraints.get(recurrent_constraint)
@@ -356,45 +580,29 @@ class ConvLSTM2D(ConvRecurrent2D):
 
     self.dropout = min(1., max(0., dropout))
     self.recurrent_dropout = min(1., max(0., recurrent_dropout))
-    self.state_spec = [InputSpec(ndim=4), InputSpec(ndim=4)]
+    self.state_size = (self.filters, self.filters)
+    self._dropout_mask = None
+    self._recurrent_dropout_mask = None
 
-  @shape_type_conversion
   def build(self, input_shape):
-    if isinstance(input_shape, list):
-      input_shape = input_shape[0]
-    batch_size = input_shape[0] if self.stateful else None
-    self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_shape[2:])
-    if self.stateful:
-      self.reset_states()
-    else:
-      # initial states: 2 all-zero tensor of shape (filters)
-      self.states = [None, None]
 
     if self.data_format == 'channels_first':
-      channel_axis = 2
+      channel_axis = 1
     else:
       channel_axis = -1
     if input_shape[channel_axis] is None:
       raise ValueError('The channel dimension of the inputs '
                        'should be defined. Found `None`.')
     input_dim = input_shape[channel_axis]
-    state_shape = [None] * 4
-    state_shape[channel_axis] = input_dim
-    state_shape = tuple(state_shape)
-    self.state_spec = [
-        InputSpec(shape=state_shape),
-        InputSpec(shape=state_shape)
-    ]
     kernel_shape = self.kernel_size + (input_dim, self.filters * 4)
     self.kernel_shape = kernel_shape
     recurrent_kernel_shape = self.kernel_size + (self.filters, self.filters * 4)
 
-    self.kernel = self.add_weight(
-        shape=kernel_shape,
-        initializer=self.kernel_initializer,
-        name='kernel',
-        regularizer=self.kernel_regularizer,
-        constraint=self.kernel_constraint)
+    self.kernel = self.add_weight(shape=kernel_shape,
+                                  initializer=self.kernel_initializer,
+                                  name='kernel',
+                                  regularizer=self.kernel_regularizer,
+                                  constraint=self.kernel_constraint)
     self.recurrent_kernel = self.add_weight(
         shape=recurrent_kernel_shape,
         initializer=self.recurrent_initializer,
@@ -402,25 +610,24 @@ class ConvLSTM2D(ConvRecurrent2D):
         regularizer=self.recurrent_regularizer,
         constraint=self.recurrent_constraint)
     if self.use_bias:
-      self.bias = self.add_weight(
-          shape=(self.filters * 4,),
-          initializer=self.bias_initializer,
-          name='bias',
-          regularizer=self.bias_regularizer,
-          constraint=self.bias_constraint)
+      self.bias = self.add_weight(shape=(self.filters * 4,),
+                                  initializer=self.bias_initializer,
+                                  name='bias',
+                                  regularizer=self.bias_regularizer,
+                                  constraint=self.bias_constraint)
       if self.unit_forget_bias:
         bias_value = np.zeros((self.filters * 4,))
-        bias_value[self.filters:self.filters * 2] = 1.
+        bias_value[self.filters: self.filters * 2] = 1.
         K.set_value(self.bias, bias_value)
     else:
       self.bias = None
 
     self.kernel_i = self.kernel[:, :, :, :self.filters]
     self.recurrent_kernel_i = self.recurrent_kernel[:, :, :, :self.filters]
-    self.kernel_f = self.kernel[:, :, :, self.filters:self.filters * 2]
+    self.kernel_f = self.kernel[:, :, :, self.filters: self.filters * 2]
     self.recurrent_kernel_f = self.recurrent_kernel[:, :, :, self.filters:
                                                     self.filters * 2]
-    self.kernel_c = self.kernel[:, :, :, self.filters * 2:self.filters * 3]
+    self.kernel_c = self.kernel[:, :, :, self.filters * 2: self.filters * 3]
     self.recurrent_kernel_c = self.recurrent_kernel[:, :, :, self.filters * 2:
                                                     self.filters * 3]
     self.kernel_o = self.kernel[:, :, :, self.filters * 3:]
@@ -428,8 +635,8 @@ class ConvLSTM2D(ConvRecurrent2D):
 
     if self.use_bias:
       self.bias_i = self.bias[:self.filters]
-      self.bias_f = self.bias[self.filters:self.filters * 2]
-      self.bias_c = self.bias[self.filters * 2:self.filters * 3]
+      self.bias_f = self.bias[self.filters: self.filters * 2]
+      self.bias_c = self.bias[self.filters * 2: self.filters * 3]
       self.bias_o = self.bias[self.filters * 3:]
     else:
       self.bias_i = None
@@ -438,166 +645,419 @@ class ConvLSTM2D(ConvRecurrent2D):
       self.bias_o = None
     self.built = True
 
-  def get_initial_state(self, inputs):
-    # (samples, timesteps, rows, cols, filters)
-    initial_state = array_ops.zeros_like(inputs)
-    # (samples, rows, cols, filters)
-    initial_state = math_ops.reduce_sum(initial_state, axis=1)
-    shape = list(self.kernel_shape)
-    shape[-1] = self.filters
-    initial_state = self.input_conv(
-        initial_state, K.zeros(tuple(shape)), padding=self.padding)
-
-    initial_states = [initial_state for _ in range(2)]
-    return initial_states
+  def call(self, inputs, states, training=None):
+    if 0 < self.dropout < 1 and self._dropout_mask is None:
+      self._dropout_mask = _generate_dropout_mask(
+          K.ones_like(inputs),
+          self.dropout,
+          training=training,
+          count=4)
+    if (0 < self.recurrent_dropout < 1 and
+        self._recurrent_dropout_mask is None):
+      self._recurrent_dropout_mask = _generate_dropout_mask(
+          K.ones_like(states[1]),
+          self.recurrent_dropout,
+          training=training,
+          count=4)
 
-  def reset_states(self):
-    if not self.stateful:
-      raise RuntimeError('Layer must be stateful.')
-    input_shape = self.input_spec[0].shape
+    # dropout matrices for input units
+    dp_mask = self._dropout_mask
+    # dropout matrices for recurrent units
+    rec_dp_mask = self._recurrent_dropout_mask
 
-    if not input_shape[0]:
-      raise ValueError('If a RNN is stateful, a complete '
-                       'input_shape must be provided '
-                       '(including batch size). '
-                       'Got input shape: ' + str(input_shape))
+    h_tm1 = states[0]  # previous memory state
+    c_tm1 = states[1]  # previous carry state
 
-    if self.return_state:
-      output_shape = tuple(self.compute_output_shape(input_shape)[0].as_list())
-    else:
-      output_shape = tuple(self.compute_output_shape(input_shape).as_list())
-    if self.return_sequences:
-      output_shape = (input_shape[0],) + output_shape[2:]
+    if 0 < self.dropout < 1.:
+      inputs_i = inputs * dp_mask[0]
+      inputs_f = inputs * dp_mask[1]
+      inputs_c = inputs * dp_mask[2]
+      inputs_o = inputs * dp_mask[3]
     else:
-      output_shape = (input_shape[0],) + output_shape[1:]
+      inputs_i = inputs
+      inputs_f = inputs
+      inputs_c = inputs
+      inputs_o = inputs
 
-    if hasattr(self, 'states'):
-      K.set_value(self.states[0],
-                  np.zeros(output_shape))
-      K.set_value(self.states[1],
-                  np.zeros(output_shape))
+    if 0 < self.recurrent_dropout < 1.:
+      h_tm1_i = h_tm1 * rec_dp_mask[0]
+      h_tm1_f = h_tm1 * rec_dp_mask[1]
+      h_tm1_c = h_tm1 * rec_dp_mask[2]
+      h_tm1_o = h_tm1 * rec_dp_mask[3]
     else:
-      self.states = [
-          K.zeros(output_shape),
-          K.zeros(output_shape)
-      ]
-
-  def get_constants(self, inputs, training=None):
-    constants = []
-    if self.implementation == 0 and 0 < self.dropout < 1:
-      ones = array_ops.zeros_like(inputs)
-      ones = math_ops.reduce_sum(ones, axis=1)
-      ones += 1
-
-      def dropped_inputs():
-        return K.dropout(ones, self.dropout)
-
-      dp_mask = [
-          K.in_train_phase(dropped_inputs, ones, training=training)
-          for _ in range(4)
-      ]
-      constants.append(dp_mask)
-    else:
-      constants.append([K.cast_to_floatx(1.) for _ in range(4)])
-
-    if 0 < self.recurrent_dropout < 1:
-      shape = list(self.kernel_shape)
-      shape[-1] = self.filters
-      ones = array_ops.zeros_like(inputs)
-      ones = math_ops.reduce_sum(ones, axis=1)
-      ones = self.input_conv(ones, K.zeros(shape), padding=self.padding)
-      ones += 1.
-
-      def dropped_inputs():  # pylint: disable=function-redefined
-        return K.dropout(ones, self.recurrent_dropout)
-
-      rec_dp_mask = [
-          K.in_train_phase(dropped_inputs, ones, training=training)
-          for _ in range(4)
-      ]
-      constants.append(rec_dp_mask)
-    else:
-      constants.append([K.cast_to_floatx(1.) for _ in range(4)])
-    return constants
+      h_tm1_i = h_tm1
+      h_tm1_f = h_tm1
+      h_tm1_c = h_tm1
+      h_tm1_o = h_tm1
+
+    x_i = self.input_conv(inputs_i, self.kernel_i, self.bias_i,
+                          padding=self.padding)
+    x_f = self.input_conv(inputs_f, self.kernel_f, self.bias_f,
+                          padding=self.padding)
+    x_c = self.input_conv(inputs_c, self.kernel_c, self.bias_c,
+                          padding=self.padding)
+    x_o = self.input_conv(inputs_o, self.kernel_o, self.bias_o,
+                          padding=self.padding)
+    h_i = self.recurrent_conv(h_tm1_i,
+                              self.recurrent_kernel_i)
+    h_f = self.recurrent_conv(h_tm1_f,
+                              self.recurrent_kernel_f)
+    h_c = self.recurrent_conv(h_tm1_c,
+                              self.recurrent_kernel_c)
+    h_o = self.recurrent_conv(h_tm1_o,
+                              self.recurrent_kernel_o)
+
+    i = self.recurrent_activation(x_i + h_i)
+    f = self.recurrent_activation(x_f + h_f)
+    c = f * c_tm1 + i * self.activation(x_c + h_c)
+    o = self.recurrent_activation(x_o + h_o)
+    h = o * self.activation(c)
+
+    if 0 < self.dropout + self.recurrent_dropout:
+      if training is None:
+        h._uses_learning_phase = True
+
+    return h, [h, c]
 
   def input_conv(self, x, w, b=None, padding='valid'):
-    conv_out = K.conv2d(
-        x,
-        w,
-        strides=self.strides,
-        padding=padding,
-        data_format=self.data_format,
-        dilation_rate=self.dilation_rate)
+    conv_out = K.conv2d(x, w, strides=self.strides,
+                        padding=padding,
+                        data_format=self.data_format,
+                        dilation_rate=self.dilation_rate)
     if b is not None:
-      conv_out = K.bias_add(conv_out, b, data_format=self.data_format)
+      conv_out = K.bias_add(conv_out, b,
+                            data_format=self.data_format)
     return conv_out
 
   def recurrent_conv(self, x, w):
-    conv_out = K.conv2d(
-        x, w, strides=(1, 1), padding='same', data_format=self.data_format)
+    conv_out = K.conv2d(x, w, strides=(1, 1),
+                        padding='same',
+                        data_format=self.data_format)
     return conv_out
 
-  def step(self, inputs, states):
-    assert len(states) == 4
-    h_tm1 = states[0]
-    c_tm1 = states[1]
-    dp_mask = states[2]
-    rec_dp_mask = states[3]
-
-    x_i = self.input_conv(
-        inputs * dp_mask[0], self.kernel_i, self.bias_i, padding=self.padding)
-    x_f = self.input_conv(
-        inputs * dp_mask[1], self.kernel_f, self.bias_f, padding=self.padding)
-    x_c = self.input_conv(
-        inputs * dp_mask[2], self.kernel_c, self.bias_c, padding=self.padding)
-    x_o = self.input_conv(
-        inputs * dp_mask[3], self.kernel_o, self.bias_o, padding=self.padding)
-    h_i = self.recurrent_conv(h_tm1 * rec_dp_mask[0], self.recurrent_kernel_i)
-    h_f = self.recurrent_conv(h_tm1 * rec_dp_mask[1], self.recurrent_kernel_f)
-    h_c = self.recurrent_conv(h_tm1 * rec_dp_mask[2], self.recurrent_kernel_c)
-    h_o = self.recurrent_conv(h_tm1 * rec_dp_mask[3], self.recurrent_kernel_o)
+  def get_config(self):
+    config = {'filters': self.filters,
+              'kernel_size': self.kernel_size,
+              'strides': self.strides,
+              'padding': self.padding,
+              'data_format': self.data_format,
+              'dilation_rate': self.dilation_rate,
+              'activation': activations.serialize(self.activation),
+              'recurrent_activation': activations.serialize(
+                  self.recurrent_activation),
+              'use_bias': self.use_bias,
+              'kernel_initializer': initializers.serialize(
+                  self.kernel_initializer),
+              'recurrent_initializer': initializers.serialize(
+                  self.recurrent_initializer),
+              'bias_initializer': initializers.serialize(self.bias_initializer),
+              'unit_forget_bias': self.unit_forget_bias,
+              'kernel_regularizer': regularizers.serialize(
+                  self.kernel_regularizer),
+              'recurrent_regularizer': regularizers.serialize(
+                  self.recurrent_regularizer),
+              'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+              'kernel_constraint': constraints.serialize(
+                  self.kernel_constraint),
+              'recurrent_constraint': constraints.serialize(
+                  self.recurrent_constraint),
+              'bias_constraint': constraints.serialize(self.bias_constraint),
+              'dropout': self.dropout,
+              'recurrent_dropout': self.recurrent_dropout}
+    base_config = super(ConvLSTM2DCell, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
 
-    i = self.recurrent_activation(x_i + h_i)
-    f = self.recurrent_activation(x_f + h_f)
-    c = f * c_tm1 + i * self.activation(x_c + h_c)
-    o = self.recurrent_activation(x_o + h_o)
-    h = o * self.activation(c)
-    return h, [h, c]
+@tf_export('keras.layers.ConvLSTM2D')
+class ConvLSTM2D(ConvRNN2D):
+  """Convolutional LSTM.
+
+  It is similar to an LSTM layer, but the input transformations
+  and recurrent transformations are both convolutional.
+
+  Arguments:
+    filters: Integer, the dimensionality of the output space
+        (i.e. the number of output filters in the convolution).
+    kernel_size: An integer or tuple/list of n integers, specifying the
+        dimensions of the convolution window.
+    strides: An integer or tuple/list of n integers,
+        specifying the strides of the convolution.
+        Specifying any stride value != 1 is incompatible with specifying
+        any `dilation_rate` value != 1.
+    padding: One of `"valid"` or `"same"` (case-insensitive).
+    data_format: A string,
+        one of `channels_last` (default) or `channels_first`.
+        The ordering of the dimensions in the inputs.
+        `channels_last` corresponds to inputs with shape
+        `(batch, time, ..., channels)`
+        while `channels_first` corresponds to
+        inputs with shape `(batch, time, channels, ...)`.
+        It defaults to the `image_data_format` value found in your
+        Keras config file at `~/.keras/keras.json`.
+        If you never set it, then it will be "channels_last".
+    dilation_rate: An integer or tuple/list of n integers, specifying
+        the dilation rate to use for dilated convolution.
+        Currently, specifying any `dilation_rate` value != 1 is
+        incompatible with specifying any `strides` value != 1.
+    activation: Activation function to use.
+        If you don't specify anything, no activation is applied
+        (ie. "linear" activation: `a(x) = x`).
+    recurrent_activation: Activation function to use
+        for the recurrent step.
+    use_bias: Boolean, whether the layer uses a bias vector.
+    kernel_initializer: Initializer for the `kernel` weights matrix,
+        used for the linear transformation of the inputs.
+    recurrent_initializer: Initializer for the `recurrent_kernel`
+        weights matrix,
+        used for the linear transformation of the recurrent state.
+    bias_initializer: Initializer for the bias vector.
+    unit_forget_bias: Boolean.
+        If True, add 1 to the bias of the forget gate at initialization.
+        Use in combination with `bias_initializer="zeros"`.
+        This is recommended in [Jozefowicz et al.]
+        (http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
+    kernel_regularizer: Regularizer function applied to
+        the `kernel` weights matrix.
+    recurrent_regularizer: Regularizer function applied to
+        the `recurrent_kernel` weights matrix.
+    bias_regularizer: Regularizer function applied to the bias vector.
+    activity_regularizer: Regularizer function applied to
+        the output of the layer (its "activation").
+    kernel_constraint: Constraint function applied to
+        the `kernel` weights matrix.
+    recurrent_constraint: Constraint function applied to
+        the `recurrent_kernel` weights matrix.
+    bias_constraint: Constraint function applied to the bias vector.
+    return_sequences: Boolean. Whether to return the last output
+        in the output sequence, or the full sequence.
+    go_backwards: Boolean (default False).
+        If True, process the input sequence backwards.
+    stateful: Boolean (default False). If True, the last state
+        for each sample at index i in a batch will be used as initial
+        state for the sample of index i in the following batch.
+    dropout: Float between 0 and 1.
+        Fraction of the units to drop for
+        the linear transformation of the inputs.
+    recurrent_dropout: Float between 0 and 1.
+        Fraction of the units to drop for
+        the linear transformation of the recurrent state.
+
+  Input shape:
+    - if data_format='channels_first'
+        5D tensor with shape:
+        `(samples, time, channels, rows, cols)`
+    - if data_format='channels_last'
+        5D tensor with shape:
+        `(samples, time, rows, cols, channels)`
+
+  Output shape:
+    - if `return_sequences`
+         - if data_format='channels_first'
+            5D tensor with shape:
+            `(samples, time, filters, output_row, output_col)`
+         - if data_format='channels_last'
+            5D tensor with shape:
+            `(samples, time, output_row, output_col, filters)`
+    - else
+        - if data_format='channels_first'
+            4D tensor with shape:
+            `(samples, filters, output_row, output_col)`
+        - if data_format='channels_last'
+            4D tensor with shape:
+            `(samples, output_row, output_col, filters)`
+        where output_row and output_col depend on the shape of the filter
+        and the padding.
+
+  Raises:
+    ValueError: in case of invalid constructor arguments.
+
+  References:
+    - [Convolutional LSTM Network: A Machine Learning Approach for
+    Precipitation Nowcasting](http://arxiv.org/abs/1506.04214v1)
+    The current implementation does not include the feedback loop on the
+    cells output.
+
+  """
+
+  def __init__(self,
+               filters,
+               kernel_size,
+               strides=(1, 1),
+               padding='valid',
+               data_format=None,
+               dilation_rate=(1, 1),
+               activation='tanh',
+               recurrent_activation='hard_sigmoid',
+               use_bias=True,
+               kernel_initializer='glorot_uniform',
+               recurrent_initializer='orthogonal',
+               bias_initializer='zeros',
+               unit_forget_bias=True,
+               kernel_regularizer=None,
+               recurrent_regularizer=None,
+               bias_regularizer=None,
+               activity_regularizer=None,
+               kernel_constraint=None,
+               recurrent_constraint=None,
+               bias_constraint=None,
+               return_sequences=False,
+               go_backwards=False,
+               stateful=False,
+               dropout=0.,
+               recurrent_dropout=0.,
+               **kwargs):
+    cell = ConvLSTM2DCell(filters=filters,
+                          kernel_size=kernel_size,
+                          strides=strides,
+                          padding=padding,
+                          data_format=data_format,
+                          dilation_rate=dilation_rate,
+                          activation=activation,
+                          recurrent_activation=recurrent_activation,
+                          use_bias=use_bias,
+                          kernel_initializer=kernel_initializer,
+                          recurrent_initializer=recurrent_initializer,
+                          bias_initializer=bias_initializer,
+                          unit_forget_bias=unit_forget_bias,
+                          kernel_regularizer=kernel_regularizer,
+                          recurrent_regularizer=recurrent_regularizer,
+                          bias_regularizer=bias_regularizer,
+                          kernel_constraint=kernel_constraint,
+                          recurrent_constraint=recurrent_constraint,
+                          bias_constraint=bias_constraint,
+                          dropout=dropout,
+                          recurrent_dropout=recurrent_dropout)
+    super(ConvLSTM2D, self).__init__(cell,
+                                     return_sequences=return_sequences,
+                                     go_backwards=go_backwards,
+                                     stateful=stateful,
+                                     **kwargs)
+    self.activity_regularizer = regularizers.get(activity_regularizer)
+
+  def call(self, inputs, mask=None, training=None, initial_state=None):
+    return super(ConvLSTM2D, self).call(inputs,
+                                        mask=mask,
+                                        training=training,
+                                        initial_state=initial_state)
+
+  @property
+  def filters(self):
+    return self.cell.filters
+
+  @property
+  def kernel_size(self):
+    return self.cell.kernel_size
+
+  @property
+  def strides(self):
+    return self.cell.strides
+
+  @property
+  def padding(self):
+    return self.cell.padding
+
+  @property
+  def data_format(self):
+    return self.cell.data_format
+
+  @property
+  def dilation_rate(self):
+    return self.cell.dilation_rate
+
+  @property
+  def activation(self):
+    return self.cell.activation
+
+  @property
+  def recurrent_activation(self):
+    return self.cell.recurrent_activation
+
+  @property
+  def use_bias(self):
+    return self.cell.use_bias
+
+  @property
+  def kernel_initializer(self):
+    return self.cell.kernel_initializer
+
+  @property
+  def recurrent_initializer(self):
+    return self.cell.recurrent_initializer
+
+  @property
+  def bias_initializer(self):
+    return self.cell.bias_initializer
+
+  @property
+  def unit_forget_bias(self):
+    return self.cell.unit_forget_bias
+
+  @property
+  def kernel_regularizer(self):
+    return self.cell.kernel_regularizer
+
+  @property
+  def recurrent_regularizer(self):
+    return self.cell.recurrent_regularizer
+
+  @property
+  def bias_regularizer(self):
+    return self.cell.bias_regularizer
+
+  @property
+  def kernel_constraint(self):
+    return self.cell.kernel_constraint
+
+  @property
+  def recurrent_constraint(self):
+    return self.cell.recurrent_constraint
+
+  @property
+  def bias_constraint(self):
+    return self.cell.bias_constraint
+
+  @property
+  def dropout(self):
+    return self.cell.dropout
+
+  @property
+  def recurrent_dropout(self):
+    return self.cell.recurrent_dropout
 
   def get_config(self):
-    config = {
-        'activation':
-            activations.serialize(self.activation),
-        'recurrent_activation':
-            activations.serialize(self.recurrent_activation),
-        'use_bias':
-            self.use_bias,
-        'kernel_initializer':
-            initializers.serialize(self.kernel_initializer),
-        'recurrent_initializer':
-            initializers.serialize(self.recurrent_initializer),
-        'bias_initializer':
-            initializers.serialize(self.bias_initializer),
-        'unit_forget_bias':
-            self.unit_forget_bias,
-        'kernel_regularizer':
-            regularizers.serialize(self.kernel_regularizer),
-        'recurrent_regularizer':
-            regularizers.serialize(self.recurrent_regularizer),
-        'bias_regularizer':
-            regularizers.serialize(self.bias_regularizer),
-        'activity_regularizer':
-            regularizers.serialize(self.activity_regularizer),
-        'kernel_constraint':
-            constraints.serialize(self.kernel_constraint),
-        'recurrent_constraint':
-            constraints.serialize(self.recurrent_constraint),
-        'bias_constraint':
-            constraints.serialize(self.bias_constraint),
-        'dropout':
-            self.dropout,
-        'recurrent_dropout':
-            self.recurrent_dropout
-    }
+    config = {'filters': self.filters,
+              'kernel_size': self.kernel_size,
+              'strides': self.strides,
+              'padding': self.padding,
+              'data_format': self.data_format,
+              'dilation_rate': self.dilation_rate,
+              'activation': activations.serialize(self.activation),
+              'recurrent_activation': activations.serialize(
+                  self.recurrent_activation),
+              'use_bias': self.use_bias,
+              'kernel_initializer': initializers.serialize(
+                  self.kernel_initializer),
+              'recurrent_initializer': initializers.serialize(
+                  self.recurrent_initializer),
+              'bias_initializer': initializers.serialize(self.bias_initializer),
+              'unit_forget_bias': self.unit_forget_bias,
+              'kernel_regularizer': regularizers.serialize(
+                  self.kernel_regularizer),
+              'recurrent_regularizer': regularizers.serialize(
+                  self.recurrent_regularizer),
+              'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+              'activity_regularizer': regularizers.serialize(
+                  self.activity_regularizer),
+              'kernel_constraint': constraints.serialize(
+                  self.kernel_constraint),
+              'recurrent_constraint': constraints.serialize(
+                  self.recurrent_constraint),
+              'bias_constraint': constraints.serialize(self.bias_constraint),
+              'dropout': self.dropout,
+              'recurrent_dropout': self.recurrent_dropout}
     base_config = super(ConvLSTM2D, self).get_config()
+    del base_config['cell']
     return dict(list(base_config.items()) + list(config.items()))
+
+  @classmethod
+  def from_config(cls, config):
+    return cls(**config)
index 60137bd..9e768b4 100644 (file)
@@ -64,6 +64,7 @@ class ConvLSTMTest(test.TestCase):
           self.assertEqual(len(states), 2)
           model = keras.models.Model(x, states[0])
           state = model.predict(inputs)
+
           self.assertAllClose(
               keras.backend.eval(layer.states[0]), state, atol=1e-4)
 
index 6a7da1a..a535f18 100644 (file)
@@ -1,21 +1,53 @@
 path: "tensorflow.keras.layers.ConvLSTM2D"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional_recurrent.ConvLSTM2D\'>"
-  is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional_recurrent.ConvRecurrent2D\'>"
-  is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.Recurrent\'>"
+  is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional_recurrent.ConvRNN2D\'>"
+  is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.RNN\'>"
   is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
   is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
   is_instance: "<class \'tensorflow.python.training.checkpointable.CheckpointableBase\'>"
   is_instance: "<type \'object\'>"
   member {
+    name: "activation"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "activity_regularizer"
     mtype: "<type \'property\'>"
   }
   member {
+    name: "bias_constraint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "bias_initializer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "bias_regularizer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "data_format"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dilation_rate"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dropout"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "dtype"
     mtype: "<type \'property\'>"
   }
   member {
+    name: "filters"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "graph"
     mtype: "<type \'property\'>"
   }
@@ -36,6 +68,22 @@ tf_class {
     mtype: "<type \'property\'>"
   }
   member {
+    name: "kernel_constraint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "kernel_initializer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "kernel_regularizer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "kernel_size"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "losses"
     mtype: "<type \'property\'>"
   }
@@ -68,10 +116,42 @@ tf_class {
     mtype: "<type \'property\'>"
   }
   member {
+    name: "padding"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "recurrent_activation"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "recurrent_constraint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "recurrent_dropout"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "recurrent_initializer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "recurrent_regularizer"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "scope_name"
     mtype: "<type \'property\'>"
   }
   member {
+    name: "states"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "strides"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "trainable_variables"
     mtype: "<type \'property\'>"
   }
@@ -80,10 +160,18 @@ tf_class {
     mtype: "<type \'property\'>"
   }
   member {
+    name: "unit_forget_bias"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "updates"
     mtype: "<type \'property\'>"
   }
   member {
+    name: "use_bias"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "variables"
     mtype: "<type \'property\'>"
   }
@@ -144,10 +232,6 @@ tf_class {
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
-    name: "get_constants"
-    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
     name: "get_initial_state"
     argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
   }
@@ -188,27 +272,11 @@ tf_class {
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
-    name: "input_conv"
-    argspec: "args=[\'self\', \'x\', \'w\', \'b\', \'padding\'], varargs=None, keywords=None, defaults=[\'None\', \'valid\'], "
-  }
-  member_method {
-    name: "preprocess_input"
-    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "recurrent_conv"
-    argspec: "args=[\'self\', \'x\', \'w\'], varargs=None, keywords=None, defaults=None"
-  }
-  member_method {
     name: "reset_states"
-    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "set_weights"
     argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
   }
-  member_method {
-    name: "step"
-    argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=None, defaults=None"
-  }
 }