Add RNNClassifier
author A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 6 Apr 2018 21:32:28 +0000 (14:32 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 21:35:11 +0000 (14:35 -0700)
PiperOrigin-RevId: 191941174

tensorflow/contrib/estimator/BUILD
tensorflow/contrib/estimator/__init__.py
tensorflow/contrib/estimator/python/estimator/rnn.py [new file with mode: 0644]
tensorflow/contrib/estimator/python/estimator/rnn_test.py [new file with mode: 0644]
tensorflow/python/ops/rnn_cell_impl.py

index bec0329..9f4cd44 100644 (file)
@@ -23,6 +23,7 @@ py_library(
         ":logit_fns",
         ":multi_head",
         ":replicate_model_fn",
+        ":rnn",
         "//tensorflow/python:util",
     ],
 )
@@ -412,3 +413,57 @@ cuda_py_test(
         "notap",
     ],
 )
+
+py_library(
+    name = "rnn",
+    srcs = ["python/estimator/rnn.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":extenders",
+        "//tensorflow/contrib/feature_column:feature_column_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:check_ops",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:init_ops",
+        "//tensorflow/python:layers",
+        "//tensorflow/python:partitioned_variables",
+        "//tensorflow/python:rnn",
+        "//tensorflow/python:rnn_cell",
+        "//tensorflow/python:summary",
+        "//tensorflow/python:training",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python/estimator",
+        "//tensorflow/python/estimator:head",
+        "//tensorflow/python/estimator:optimizers",
+        "//tensorflow/python/feature_column",
+        "@six_archive//:six",
+    ],
+)
+
+py_test(
+    name = "rnn_test",
+    size = "medium",
+    srcs = ["python/estimator/rnn_test.py"],
+    srcs_version = "PY2AND3",
+    tags = [
+        "no_pip",
+        "notsan",
+    ],
+    deps = [
+        ":rnn",
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:check_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:state_ops",
+        "//tensorflow/python:summary",
+        "//tensorflow/python:training",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/estimator:numpy_io",
+        "//tensorflow/python/feature_column",
+        "//third_party/py/numpy",
+        "@six_archive//:six",
+    ],
+)
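
With these rules in place, the new estimator and its tests can be exercised from a TensorFlow source checkout, e.g. `bazel test //tensorflow/contrib/estimator:rnn_test` (the target name follows from the package path and the `py_test` rule above).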
index d2fc2c4..9a87fa9 100644 (file)
@@ -52,6 +52,7 @@ _allowed_symbols = [
     'linear_logit_fn_builder',
     'replicate_model_fn',
     'TowerOptimizer',
+    'RNNClassifier',
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
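
Listing the symbol in `_allowed_symbols` is what keeps `remove_undocumented` from stripping it, so after this change the class is reachable from the contrib namespace. A minimal sanity check (assumes a TF 1.x build that includes this commit):

```python
# Assumes a TF 1.x build that includes this commit.
from tensorflow.contrib.estimator import RNNClassifier
```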
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py
new file mode 100644 (file)
index 0000000..b475c12
--- /dev/null
@@ -0,0 +1,481 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Recurrent Neural Network estimators."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.contrib.estimator.python.estimator import extenders
+from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as seq_fc
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator.canned import head as head_lib
+from tensorflow.python.estimator.canned import optimizers
+from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.layers import core as core_layers
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.summary import summary
+from tensorflow.python.training import optimizer as optimizer_lib
+from tensorflow.python.training import training_util
+
+
+# The defaults are historical artifacts of the initial implementation, but seem
+# reasonable choices.
+_DEFAULT_LEARNING_RATE = 0.05
+_DEFAULT_CLIP_NORM = 5.0
+
+_CELL_TYPES = {'basic_rnn': rnn_cell.BasicRNNCell,
+               'lstm': rnn_cell.BasicLSTMCell,
+               'gru': rnn_cell.GRUCell}
+
+# Indicates no value was provided by the user to a kwarg.
+USE_DEFAULT = object()
+
+
+def _single_rnn_cell(num_units, cell_type):
+  cell_type = _CELL_TYPES.get(cell_type, cell_type)
+  if not cell_type or not (isinstance(cell_type, type) and
+                           issubclass(cell_type, rnn_cell.RNNCell)):
+    raise ValueError('Supported cell types are {}; got {}'.format(
+        list(_CELL_TYPES.keys()), cell_type))
+  return cell_type(num_units=num_units)
+
+
+def _make_rnn_cell_fn(num_units, cell_type='basic_rnn'):
+  """Convenience function to create `rnn_cell_fn` for canned RNN Estimators.
+
+  Args:
+    num_units: Iterable of integer number of hidden units per RNN layer.
+    cell_type: A subclass of `tf.nn.rnn_cell.RNNCell` or a string specifying
+      the cell type. Supported strings are: `'basic_rnn'`, `'lstm'`, and
+      `'gru'`.
+
+  Returns:
+    A function that takes a single argument, an instance of
+    `tf.estimator.ModeKeys`, and returns an instance derived from
+    `tf.nn.rnn_cell.RNNCell`.
+
+  Raises:
+    ValueError: If cell_type is not supported.
+  """
+  def rnn_cell_fn(mode):
+    # Unused. Part of the rnn_cell_fn interface, since user-specified functions
+    # may need different behavior across modes (e.g. dropout).
+    del mode
+    cells = [_single_rnn_cell(n, cell_type) for n in num_units]
+    if len(cells) == 1:
+      return cells[0]
+    return rnn_cell.MultiRNNCell(cells)
+  return rnn_cell_fn
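
As a concrete illustration, the default cell function built by `_make_rnn_cell_fn([32, 16], 'lstm')` constructs the equivalent of the following sketch (same `rnn_cell` module as imported above; the `mode` argument is ignored by this default implementation):

```python
from tensorflow.python.ops import rnn_cell

# One BasicLSTMCell per entry in num_units; more than one layer means the
# cells are wrapped in a MultiRNNCell so state flows layer to layer.
cells = [rnn_cell.BasicLSTMCell(num_units=n) for n in [32, 16]]
stacked_cell = rnn_cell.MultiRNNCell(cells)
```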
+
+
+def _concatenate_context_input(sequence_input, context_input):
+  """Replicates `context_input` across all timesteps of `sequence_input`.
+
+  Expands dimension 1 of `context_input` then tiles it `sequence_length` times.
+  This value is appended to `sequence_input` on dimension 2 and the result is
+  returned.
+
+  Args:
+    sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
+      padded_length, d0]`.
+    context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.
+
+  Returns:
+    A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
+    d0 + d1]`.
+
+  Raises:
+    ValueError: If `sequence_input` does not have rank 3 or `context_input` does
+      not have rank 2.
+  """
+  seq_rank_check = check_ops.assert_rank(
+      sequence_input,
+      3,
+      message='sequence_input must have rank 3',
+      data=[array_ops.shape(sequence_input)])
+  seq_type_check = check_ops.assert_type(
+      sequence_input,
+      dtypes.float32,
+      message='sequence_input must have dtype float32; got {}.'.format(
+          sequence_input.dtype))
+  ctx_rank_check = check_ops.assert_rank(
+      context_input,
+      2,
+      message='context_input must have rank 2',
+      data=[array_ops.shape(context_input)])
+  ctx_type_check = check_ops.assert_type(
+      context_input,
+      dtypes.float32,
+      message='context_input must have dtype float32; got {}.'.format(
+          context_input.dtype))
+  with ops.control_dependencies(
+      [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]):
+    padded_length = array_ops.shape(sequence_input)[1]
+    tiled_context_input = array_ops.tile(
+        array_ops.expand_dims(context_input, 1),
+        array_ops.concat([[1], [padded_length], [1]], 0))
+  return array_ops.concat([sequence_input, tiled_context_input], 2)
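
The tile-and-concat step can be pictured with a small standalone sketch (hypothetical shapes: `batch_size=2`, `padded_length=3`, `d0=4`, `d1=2`; the real code reads `padded_length` dynamically via `tf.shape`):

```python
import tensorflow as tf

sequence_input = tf.zeros([2, 3, 4])  # [batch_size, padded_length, d0]
context_input = tf.ones([2, 2])       # [batch_size, d1]
# Expand dim 1, tile across padded_length, then append on dim 2.
tiled = tf.tile(tf.expand_dims(context_input, 1), [1, 3, 1])  # [2, 3, 2]
combined = tf.concat([sequence_input, tiled], axis=2)         # [2, 3, 6]
```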
+
+
+def _select_last_activations(activations, sequence_lengths):
+  """Selects the nth set of activations for each n in `sequence_length`.
+
+  Returns a `Tensor` of shape `[batch_size, k]`. If `sequence_length` is not
+  `None`, then `output[i, :] = activations[i, sequence_length[i] - 1, :]`. If
+  `sequence_length` is `None`, then `output[i, :] = activations[i, -1, :]`.
+
+  Args:
+    activations: A `Tensor` with shape `[batch_size, padded_length, k]`.
+    sequence_lengths: A `Tensor` with shape `[batch_size]` or `None`.
+  Returns:
+    A `Tensor` of shape `[batch_size, k]`.
+  """
+  with ops.name_scope(
+      'select_last_activations', values=[activations, sequence_lengths]):
+    activations_shape = array_ops.shape(activations)
+    batch_size = activations_shape[0]
+    padded_length = activations_shape[1]
+    output_units = activations_shape[2]
+    if sequence_lengths is None:
+      sequence_lengths = padded_length
+    start_indices = math_ops.to_int64(
+        math_ops.range(batch_size) * padded_length)
+    last_indices = start_indices + sequence_lengths - 1
+    reshaped_activations = array_ops.reshape(
+        activations, [batch_size * padded_length, output_units])
+
+    last_activations = array_ops.gather(reshaped_activations, last_indices)
+    last_activations.set_shape([activations.shape[0], activations.shape[2]])
+    return last_activations
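
The gather trick above can be checked with plain NumPy (a sketch with `batch_size=2`, `padded_length=3`, `k=2`, and sequence lengths `[3, 1]`):

```python
import numpy as np

activations = np.arange(12).reshape([2, 3, 2])  # [batch, padded_length, k]
sequence_lengths = np.array([3, 1])
start_indices = np.arange(2) * 3                      # flattened row offsets: [0, 3]
last_indices = start_indices + sequence_lengths - 1   # [2, 3]
flat = activations.reshape([6, 2])
print(flat[last_indices])  # [[4 5], [6 7]]: last valid step of each example
```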
+
+
+def _rnn_logit_fn_builder(output_units, rnn_cell_fn, sequence_feature_columns,
+                          context_feature_columns, input_layer_partitioner):
+  """Function builder for a rnn logit_fn.
+
+  Args:
+    output_units: An int indicating the dimension of the logit layer.
+    rnn_cell_fn: A function that takes one argument, a `tf.estimator.ModeKeys`,
+      and returns an object of type `tf.nn.rnn_cell.RNNCell`.
+    sequence_feature_columns: An iterable containing the `FeatureColumn`s
+      that represent sequential input.
+    context_feature_columns: An iterable containing the `FeatureColumn`s
+      that represent contextual input.
+    input_layer_partitioner: Partitioner for input layer.
+
+  Returns:
+    A logit_fn (see below).
+
+  Raises:
+    ValueError: If output_units is not an int.
+  """
+  if not isinstance(output_units, int):
+    raise ValueError('output_units must be an int.  Given type: {}'.format(
+        type(output_units)))
+
+  def rnn_logit_fn(features, mode):
+    """Recurrent Neural Network logit_fn.
+
+    Args:
+      features: This is the first item returned from the `input_fn`
+                passed to `train`, `evaluate`, and `predict`. This should be a
+                single `Tensor` or `dict` of same.
+      mode: Optional. Specifies if this is training, evaluation, or prediction.
+            See `ModeKeys`.
+
+    Returns:
+      A `Tensor` representing the logits.
+    """
+    with variable_scope.variable_scope(
+        'sequence_input_layer',
+        values=tuple(six.itervalues(features)),
+        partitioner=input_layer_partitioner):
+      sequence_input, sequence_length = seq_fc.sequence_input_layer(
+          features=features, feature_columns=sequence_feature_columns)
+      summary.histogram('sequence_length', sequence_length)
+
+      if context_feature_columns:
+        context_input = feature_column_lib.input_layer(
+            features=features,
+            feature_columns=context_feature_columns)
+        sequence_input = _concatenate_context_input(sequence_input,
+                                                    context_input)
+
+    cell = rnn_cell_fn(mode)
+    # Ignore output state.
+    rnn_outputs, _ = rnn.dynamic_rnn(
+        cell=cell,
+        inputs=sequence_input,
+        dtype=dtypes.float32,
+        time_major=False)
+    last_activations = _select_last_activations(rnn_outputs, sequence_length)
+
+    with variable_scope.variable_scope('logits', values=(rnn_outputs,)):
+      logits = core_layers.dense(
+          last_activations,
+          units=output_units,
+          activation=None,
+          kernel_initializer=init_ops.glorot_uniform_initializer())
+    return logits
+
+  return rnn_logit_fn
+
+
+def _rnn_model_fn(features,
+                  labels,
+                  mode,
+                  head,
+                  rnn_cell_fn,
+                  sequence_feature_columns,
+                  context_feature_columns,
+                  optimizer='Adagrad',
+                  input_layer_partitioner=None,
+                  config=None):
+  """Recurrent Neural Net model_fn.
+
+  Args:
+    features: dict of `Tensor` and `SparseTensor` objects returned from
+      `input_fn`.
+    labels: `Tensor` of shape [batch_size, 1] or [batch_size] with labels.
+    mode: Defines whether this is training, evaluation or prediction.
+      See `ModeKeys`.
+    head: A `head_lib._Head` instance.
+    rnn_cell_fn: A function that takes one argument, a `tf.estimator.ModeKeys`,
+      and returns an object of type `tf.nn.rnn_cell.RNNCell`.
+    sequence_feature_columns: Iterable containing `FeatureColumn`s that
+      represent sequential model inputs.
+    context_feature_columns: Iterable containing `FeatureColumn`s that
+      represent model inputs not associated with a specific timestep.
+    optimizer: String, `tf.Optimizer` object, or callable that creates the
+      optimizer to use for training. If not specified, will use the Adagrad
+      optimizer with a default learning rate of 0.05 and gradient clip norm of
+      5.0.
+    input_layer_partitioner: Partitioner for input layer. Defaults
+      to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
+    config: `RunConfig` object to configure the runtime settings.
+
+  Returns:
+    An `EstimatorSpec` instance.
+
+  Raises:
+    ValueError: If mode or optimizer is invalid, or features has the wrong type.
+  """
+  if not isinstance(features, dict):
+    raise ValueError('features should be a dictionary of `Tensor`s. '
+                     'Given type: {}'.format(type(features)))
+
+  # If user does not provide an optimizer instance, use the optimizer specified
+  # by the string with default learning rate and gradient clipping.
+  if not isinstance(optimizer, optimizer_lib.Optimizer):
+    optimizer = optimizers.get_optimizer_instance(
+        optimizer, learning_rate=_DEFAULT_LEARNING_RATE)
+    optimizer = extenders.clip_gradients_by_norm(optimizer, _DEFAULT_CLIP_NORM)
+
+  num_ps_replicas = config.num_ps_replicas if config else 0
+  partitioner = partitioned_variables.min_max_variable_partitioner(
+      max_partitions=num_ps_replicas)
+  with variable_scope.variable_scope(
+      'rnn',
+      values=tuple(six.itervalues(features)),
+      partitioner=partitioner):
+    input_layer_partitioner = input_layer_partitioner or (
+        partitioned_variables.min_max_variable_partitioner(
+            max_partitions=num_ps_replicas,
+            min_slice_size=64 << 20))
+
+    logit_fn = _rnn_logit_fn_builder(
+        output_units=head.logits_dimension,
+        rnn_cell_fn=rnn_cell_fn,
+        sequence_feature_columns=sequence_feature_columns,
+        context_feature_columns=context_feature_columns,
+        input_layer_partitioner=input_layer_partitioner)
+    logits = logit_fn(features=features, mode=mode)
+
+    def _train_op_fn(loss):
+      """Returns the op to optimize the loss."""
+      return optimizer.minimize(
+          loss,
+          global_step=training_util.get_global_step())
+
+    return head.create_estimator_spec(
+        features=features,
+        mode=mode,
+        labels=labels,
+        train_op_fn=_train_op_fn,
+        logits=logits)
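
When the caller leaves `optimizer` as the default string, the branch above amounts to the following sketch (public TF 1.x contrib API; `clip_gradients_by_norm` is the exported name of the `extenders` helper used here):

```python
import tensorflow as tf

# Adagrad with the module-level defaults, then gradient clipping by norm.
optimizer = tf.train.AdagradOptimizer(learning_rate=0.05)  # _DEFAULT_LEARNING_RATE
optimizer = tf.contrib.estimator.clip_gradients_by_norm(
    optimizer, clip_norm=5.0)                              # _DEFAULT_CLIP_NORM
```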
+
+
+class RNNClassifier(estimator.Estimator):
+  """A classifier for TensorFlow RNN models.
+
+  Trains a recurrent neural network model to classify instances into one of
+  multiple classes.
+
+  Example:
+
+  ```python
+  token_sequence = sequence_categorical_column_with_hash_bucket(...)
+  token_emb = embedding_column(categorical_column=token_sequence, ...)
+
+  estimator = RNNClassifier(
+      num_units=[32, 16], cell_type='lstm',
+      sequence_feature_columns=[token_emb])
+
+  # Input builders
+  def input_fn_train(): # returns x, y
+    pass
+  estimator.train(input_fn=input_fn_train, steps=100)
+
+  def input_fn_eval(): # returns x, y
+    pass
+  metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)
+  def input_fn_predict(): # returns x, None
+    pass
+  predictions = estimator.predict(input_fn=input_fn_predict)
+  ```
+
+  Input of `train` and `evaluate` should have the following features,
+  otherwise there will be a `KeyError`:
+
+  * if `weight_column` is not `None`, a feature with
+    `key=weight_column` whose value is a `Tensor`.
+  * for each `column` in `sequence_feature_columns`:
+    - a feature with `key=column.name` whose `value` is a `SparseTensor`.
+  * for each `column` in `context_feature_columns`:
+    - if `column` is a `_CategoricalColumn`, a feature with `key=column.name`
+      whose `value` is a `SparseTensor`.
+    - if `column` is a `_WeightedCategoricalColumn`, two features: the first
+      with `key` the id column name, the second with `key` the weight column
+      name. Both features' `value` must be a `SparseTensor`.
+    - if `column` is a `_DenseColumn`, a feature with `key=column.name`
+      whose `value` is a `Tensor`.
+
+  Loss is calculated using softmax cross entropy (sigmoid cross entropy in
+  the binary case).
+
+  @compatibility(eager)
+  Estimators are not compatible with eager execution.
+  @end_compatibility
+  """
+
+  def __init__(self,
+               sequence_feature_columns,
+               context_feature_columns=None,
+               num_units=None,
+               cell_type=USE_DEFAULT,
+               rnn_cell_fn=None,
+               model_dir=None,
+               n_classes=2,
+               weight_column=None,
+               label_vocabulary=None,
+               optimizer='Adagrad',
+               input_layer_partitioner=None,
+               config=None):
+    """Initializes a `RNNClassifier` instance.
+
+    Args:
+      sequence_feature_columns: An iterable containing the `FeatureColumn`s
+        that represent sequential input. All items in the set should either be
+        sequence columns (e.g. `sequence_numeric_column`) or constructed from
+        one (e.g. `embedding_column` with `sequence_categorical_column_*` as
+        input).
+      context_feature_columns: An iterable containing the `FeatureColumn`s
+        for contextual input. The data represented by these columns will be
+        replicated and given to the RNN at each timestep. These columns must be
+        instances of classes derived from `_DenseColumn` such as
+        `numeric_column`, not the sequential variants.
+      num_units: Iterable of integer number of hidden units per RNN layer. If
+        set, `cell_type` must also be specified and `rnn_cell_fn` must be
+        `None`.
+      cell_type: A subclass of `tf.nn.rnn_cell.RNNCell` or a string specifying
+        the cell type. Supported strings are: `'basic_rnn'`, `'lstm'`, and
+        `'gru'`. If set, `num_units` must also be specified and `rnn_cell_fn`
+        must be `None`.
+      rnn_cell_fn: A function that takes one argument, a
+        `tf.estimator.ModeKeys`, and returns an object of type
+        `tf.nn.rnn_cell.RNNCell` that will be used to construct the RNN. If
+        set, `num_units` and `cell_type` cannot be set. This is for advanced
+        users who need additional customization beyond `num_units` and
+        `cell_type`. Note that `tf.nn.rnn_cell.MultiRNNCell` is needed for
+        stacked RNNs.
+      model_dir: Directory to save model parameters, graph, etc. This can
+        also be used to load checkpoints from the directory into an estimator
+        to continue training a previously saved model.
+      n_classes: Number of label classes. Defaults to 2, namely binary
+        classification. Must be > 1.
+      weight_column: A string or a `_NumericColumn` created by
+        `tf.feature_column.numeric_column` defining the feature column
+        representing weights. It is used to down-weight or boost examples
+        during training, and is multiplied by the loss of the example. If it
+        is a string, it is used as a key to fetch the weight tensor from the
+        `features`. If it is a `_NumericColumn`, the raw tensor is fetched by
+        key `weight_column.key`, then `weight_column.normalizer_fn` is applied
+        to it to produce the weight tensor.
+      label_vocabulary: A list of strings representing possible label values.
+        If given, labels must be of string type and take values in
+        `label_vocabulary`. If it is not given, labels must already be encoded
+        as an integer or float within [0, 1] for `n_classes=2`, or as integer
+        values in {0, 1, ..., n_classes-1} for `n_classes` > 2. An error will
+        be raised if the vocabulary is not provided and the labels are
+        strings.
+      optimizer: An instance of `tf.Optimizer` used to train the model.
+        Defaults to the Adagrad optimizer.
+      input_layer_partitioner: Optional. Partitioner for input layer. Defaults
+        to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
+      config: `RunConfig` object to configure the runtime settings.
+
+    Raises:
+      ValueError: If `num_units`, `cell_type`, and `rnn_cell_fn` are not
+        compatible.
+    """
+    if rnn_cell_fn and (num_units or cell_type != USE_DEFAULT):
+      raise ValueError(
+          'num_units and cell_type must not be specified when using rnn_cell_fn'
+      )
+    if not rnn_cell_fn:
+      if cell_type == USE_DEFAULT:
+        cell_type = 'basic_rnn'
+      rnn_cell_fn = _make_rnn_cell_fn(num_units, cell_type)
+
+    if n_classes == 2:
+      head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(  # pylint: disable=protected-access
+          weight_column=weight_column,
+          label_vocabulary=label_vocabulary)
+    else:
+      head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(  # pylint: disable=protected-access
+          n_classes, weight_column=weight_column,
+          label_vocabulary=label_vocabulary)
+    def _model_fn(features, labels, mode, config):
+      return _rnn_model_fn(
+          features=features,
+          labels=labels,
+          mode=mode,
+          head=head,
+          rnn_cell_fn=rnn_cell_fn,
+          sequence_feature_columns=tuple(sequence_feature_columns or []),
+          context_feature_columns=tuple(context_feature_columns or []),
+          optimizer=optimizer,
+          input_layer_partitioner=input_layer_partitioner,
+          config=config)
+    super(RNNClassifier, self).__init__(
+        model_fn=_model_fn, model_dir=model_dir, config=config)
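
Putting it together, here is a minimal end-to-end sketch of the new estimator (it mirrors the unit tests below; assumes a TF 1.x build with this commit, and the token data is made up):

```python
import tensorflow as tf
from tensorflow.contrib.estimator import RNNClassifier
from tensorflow.contrib.feature_column import (
    sequence_categorical_column_with_hash_bucket)

token_col = sequence_categorical_column_with_hash_bucket(
    'tokens', hash_bucket_size=10)
token_emb = tf.feature_column.embedding_column(token_col, dimension=2)

def train_input_fn():
  # One example: a three-token sequence with label 1.
  tokens = tf.SparseTensor(
      values=['the', 'cat', 'sat'],
      indices=[[0, 0], [0, 1], [0, 2]],
      dense_shape=[1, 3])
  return {'tokens': tokens}, [[1]]

classifier = RNNClassifier(
    sequence_feature_columns=[token_emb],
    num_units=[4],
    cell_type='basic_rnn')
classifier.train(input_fn=train_input_fn, steps=5)
```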
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
new file mode 100644 (file)
index 0000000..393f94f
--- /dev/null
@@ -0,0 +1,1131 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for rnn.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+import shutil
+import tempfile
+
+import numpy as np
+import six
+
+from tensorflow.contrib.estimator.python.estimator import rnn
+from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as seq_fc
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.estimator import model_fn
+from tensorflow.python.estimator.canned import metric_keys
+from tensorflow.python.estimator.canned import prediction_keys
+from tensorflow.python.estimator.export import export
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.feature_column import feature_column as fc
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import rnn_cell
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variables_lib
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.training import input as input_lib
+from tensorflow.python.training import monitored_session
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import training_util
+
+
+# Names of variables created by BasicRNNCell model.
+TOKEN_EMBEDDING_NAME = 'rnn/sequence_input_layer/input_layer/tokens_sequential_embedding/embedding_weights'
+CELL_WEIGHTS_NAME = 'rnn/rnn/basic_rnn_cell/kernel'
+CELL_BIAS_NAME = 'rnn/rnn/basic_rnn_cell/bias'
+MULTI_CELL_WEIGHTS_NAME_PATTERN = 'rnn/rnn/multi_rnn_cell/cell_%d/basic_rnn_cell/kernel'
+MULTI_CELL_BIAS_NAME_PATTERN = 'rnn/rnn/multi_rnn_cell/cell_%d/basic_rnn_cell/bias'
+LOGITS_WEIGHTS_NAME = 'rnn/logits/dense/kernel'
+LOGITS_BIAS_NAME = 'rnn/logits/dense/bias'
+
+
+def _assert_close(expected, actual, rtol=1e-04, name='assert_close'):
+  with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope:
+    expected = ops.convert_to_tensor(expected, name='expected')
+    actual = ops.convert_to_tensor(actual, name='actual')
+    rdiff = math_ops.abs(expected - actual, 'diff') / math_ops.abs(expected)
+    rtol = ops.convert_to_tensor(rtol, name='rtol')
+    return check_ops.assert_less(
+        rdiff,
+        rtol,
+        data=('Condition expected =~ actual did not hold element-wise:'
+              'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff,
+              'rtol = ', rtol,),
+        name=scope)
+
+
+def create_checkpoint(rnn_weights, rnn_biases, logits_weights, logits_biases,
+                      global_step, model_dir):
+  """Create checkpoint file with provided model weights.
+
+  Args:
+    rnn_weights: Iterable of values of weights for the RNN cell.
+    rnn_biases: Iterable of values of biases for the RNN cell.
+    logits_weights: Iterable of values for matrix connecting RNN output to
+      logits.
+    logits_biases: Iterable of values for logits bias term.
+    global_step: Initial global step to save in checkpoint.
+    model_dir: Directory into which checkpoint is saved.
+  """
+  model_weights = {}
+  model_weights[CELL_WEIGHTS_NAME] = rnn_weights
+  model_weights[CELL_BIAS_NAME] = rnn_biases
+  model_weights[LOGITS_WEIGHTS_NAME] = logits_weights
+  model_weights[LOGITS_BIAS_NAME] = logits_biases
+
+  with ops.Graph().as_default():
+    # Create model variables.
+    for k, v in six.iteritems(model_weights):
+      variables_lib.Variable(v, name=k, dtype=dtypes.float32)
+
+    # Create non-model variables.
+    global_step_var = training_util.create_global_step()
+    assign_op = global_step_var.assign(global_step)
+
+    # Initialize vars and save checkpoint.
+    with monitored_session.MonitoredTrainingSession(
+        checkpoint_dir=model_dir) as sess:
+      sess.run(assign_op)
+
+
+class RNNLogitFnTest(test.TestCase):
+  """Tests correctness of logits calculated from _rnn_logit_fn_builder."""
+
+  def setUp(self):
+    self._model_dir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    if self._model_dir:
+      writer_cache.FileWriterCache.clear()
+      shutil.rmtree(self._model_dir)
+
+  def _test_logits(self, mode, rnn_units, logits_dimension, features_fn,
+                   sequence_feature_columns, context_feature_columns,
+                   expected_logits):
+    """Tests that the expected logits are calculated."""
+    with ops.Graph().as_default():
+      # Global step needed for MonitoredSession, which is in turn used to
+      # explicitly set variable weights through a checkpoint.
+      training_util.create_global_step()
+      # Use a variable scope here with 'rnn', emulating the rnn model_fn, so
+      # the checkpoint naming is shared.
+      with variable_scope.variable_scope('rnn'):
+        input_layer_partitioner = (
+            partitioned_variables.min_max_variable_partitioner(
+                max_partitions=0, min_slice_size=64 << 20))
+        logit_fn = rnn._rnn_logit_fn_builder(
+            output_units=logits_dimension,
+            rnn_cell_fn=rnn._make_rnn_cell_fn(rnn_units),
+            sequence_feature_columns=sequence_feature_columns,
+            context_feature_columns=context_feature_columns,
+            input_layer_partitioner=input_layer_partitioner)
+        # Features are constructed within this function, otherwise the Tensors
+        # containing the features would be defined outside this graph.
+        logits = logit_fn(features=features_fn(), mode=mode)
+        with monitored_session.MonitoredTrainingSession(
+            checkpoint_dir=self._model_dir) as sess:
+          self.assertAllClose(expected_logits, sess.run(logits), atol=1e-4)
+
+  def testOneDimLogits(self):
+    """Tests one-dimensional logits.
+
+    Intermediate values are rounded for ease in reading.
+    input_layer = [[[10]], [[5]]]
+    initial_state = [0, 0]
+    rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2),
+                              tanh(-.2*10 - .3*0 - .4*0 +.5)]]
+                          = [[0.83, -0.91]]
+    rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2),
+                              tanh(-.2*5 - .3*.83 + .4*.91 +.5)]]
+                          = [[0.53, -0.37]]
+    logits = [[-1*0.53 - 1*0.37 + 0.3]] = [[-0.6033]]
+    """
+    base_global_step = 100
+    create_checkpoint(
+        rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+        rnn_biases=[.2, .5],
+        logits_weights=[[-1.], [1.]],
+        logits_biases=[0.3],
+        global_step=base_global_step,
+        model_dir=self._model_dir)
+
+    def features_fn():
+      return {
+          'price':
+              sparse_tensor.SparseTensor(
+                  values=[10., 5.],
+                  indices=[[0, 0], [0, 1]],
+                  dense_shape=[1, 2]),
+      }
+
+    sequence_feature_columns = [
+        seq_fc.sequence_numeric_column('price', shape=(1,))]
+    context_feature_columns = []
+    for mode in [
+        model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
+        model_fn.ModeKeys.PREDICT
+    ]:
+      self._test_logits(
+          mode,
+          rnn_units=[2],
+          logits_dimension=1,
+          features_fn=features_fn,
+          sequence_feature_columns=sequence_feature_columns,
+          context_feature_columns=context_feature_columns,
+          expected_logits=[[-0.6033]])
+
+  def testMultiDimLogits(self):
+    """Tests multi-dimensional logits.
+
+    Intermediate values are rounded for ease in reading.
+    input_layer = [[[10]], [[5]]]
+    initial_state = [0, 0]
+    rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2),
+                              tanh(-.2*10 - .3*0 - .4*0 +.5)]]
+                          = [[0.83, -0.91]]
+    rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2),
+                              tanh(-.2*5 - .3*.83 + .4*.91 +.5)]]
+                          = [[0.53, -0.37]]
+    logits = [[-1*0.53 - 1*0.37 + 0.3,
+               0.5*0.53 + 0.3*0.37 + 0.4,
+               0.2*0.53 - 0.1*0.37 + 0.5]]
+           = [[-0.6033, 0.7777, 0.5698]]
+    """
+    base_global_step = 100
+    create_checkpoint(
+        rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+        rnn_biases=[.2, .5],
+        logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]],
+        logits_biases=[0.3, 0.4, 0.5],
+        global_step=base_global_step,
+        model_dir=self._model_dir)
+
+    def features_fn():
+      return {
+          'price':
+              sparse_tensor.SparseTensor(
+                  values=[10., 5.],
+                  indices=[[0, 0], [0, 1]],
+                  dense_shape=[1, 2]),
+      }
+
+    sequence_feature_columns = [
+        seq_fc.sequence_numeric_column('price', shape=(1,))]
+    context_feature_columns = []
+
+    for mode in [
+        model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
+        model_fn.ModeKeys.PREDICT
+    ]:
+      self._test_logits(
+          mode,
+          rnn_units=[2],
+          logits_dimension=3,
+          features_fn=features_fn,
+          sequence_feature_columns=sequence_feature_columns,
+          context_feature_columns=context_feature_columns,
+          expected_logits=[[-0.6033, 0.7777, 0.5698]])
+
+  def testMultiExampleMultiDim(self):
+    """Tests multiple examples and multi-dimensional logits.
+
+    Intermediate values are rounded for ease in reading.
+    input_layer = [[[10], [5]], [[2], [7]]]
+    initial_state = [[0, 0], [0, 0]]
+    rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2),
+                              tanh(-.2*10 - .3*0 - .4*0 +.5)],
+                             [tanh(.1*2 + .2*0 + .3*0 +.2),
+                              tanh(-.2*2 - .3*0 - .4*0 +.5)]]
+                          = [[0.83, -0.91], [0.38, 0.10]]
+    rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2),
+                              tanh(-.2*5 - .3*.83 + .4*.91 +.5)],
+                             [tanh(.1*7 + .2*.38 + .3*.10 +.2),
+                              tanh(-.2*7 - .3*.38 - .4*.10 +.5)]]
+                          = [[0.53, -0.37], [0.76, -0.78]]
+    logits = [[-1*0.53 - 1*0.37 + 0.3,
+               0.5*0.53 + 0.3*0.37 + 0.4,
+               0.2*0.53 - 0.1*0.37 + 0.5],
+              [-1*0.76 - 1*0.78 + 0.3,
+               0.5*0.76 + 0.3*0.78 + 0.4,
+               0.2*0.76 - 0.1*0.78 + 0.5]]
+           = [[-0.6033, 0.7777, 0.5698], [-1.2473, 1.0170, 0.5745]]
+    """
+    base_global_step = 100
+    create_checkpoint(
+        rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+        rnn_biases=[.2, .5],
+        logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]],
+        logits_biases=[0.3, 0.4, 0.5],
+        global_step=base_global_step,
+        model_dir=self._model_dir)
+
+    def features_fn():
+      return {
+          'price':
+              sparse_tensor.SparseTensor(
+                  values=[10., 5., 2., 7.],
+                  indices=[[0, 0], [0, 1], [1, 0], [1, 1]],
+                  dense_shape=[2, 2]),
+      }
+
+    sequence_feature_columns = [
+        seq_fc.sequence_numeric_column('price', shape=(1,))
+    ]
+    context_feature_columns = []
+
+    for mode in [
+        model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
+        model_fn.ModeKeys.PREDICT
+    ]:
+      self._test_logits(
+          mode,
+          rnn_units=[2],
+          logits_dimension=3,
+          features_fn=features_fn,
+          sequence_feature_columns=sequence_feature_columns,
+          context_feature_columns=context_feature_columns,
+          expected_logits=[[-0.6033, 0.7777, 0.5698],
+                           [-1.2473, 1.0170, 0.5745]])
+
+  def testMultiExamplesDifferentLength(self):
+    """Tests multiple examples with different lengths.
+
+    Intermediate values are rounded for ease in reading.
+    input_layer = [[[10], [5]], [[2], [0]]]
+    initial_state = [[0, 0], [0, 0]]
+    rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2),
+                              tanh(-.2*10 - .3*0 - .4*0 +.5)],
+                             [tanh(.1*2 + .2*0 + .3*0 +.2),
+                              tanh(-.2*2 - .3*0 - .4*0 +.5)]]
+                          = [[0.83, -0.91], [0.38, 0.10]]
+    rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2),
+                              tanh(-.2*5 - .3*.83 + .4*.91 +.5)],
+                             [<ignored-padding>]]
+                          = [[0.53, -0.37], [<ignored-padding>]]
+    logits = [[-1*0.53 - 1*0.37 + 0.3],
+              [-1*0.38 + 1*0.10 + 0.3]]
+           = [[-0.6033], [0.0197]]
+    """
+    base_global_step = 100
+    create_checkpoint(
+        rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+        rnn_biases=[.2, .5],
+        logits_weights=[[-1.], [1.]],
+        logits_biases=[0.3],
+        global_step=base_global_step,
+        model_dir=self._model_dir)
+
+    def features_fn():
+      return {
+          'price':
+              sparse_tensor.SparseTensor(
+                  values=[10., 5., 2.],
+                  indices=[[0, 0], [0, 1], [1, 0]],
+                  dense_shape=[2, 2]),
+      }
+
+    sequence_feature_columns = [
+        seq_fc.sequence_numeric_column('price', shape=(1,))]
+    context_feature_columns = []
+
+    for mode in [
+        model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
+        model_fn.ModeKeys.PREDICT
+    ]:
+      self._test_logits(
+          mode,
+          rnn_units=[2],
+          logits_dimension=1,
+          features_fn=features_fn,
+          sequence_feature_columns=sequence_feature_columns,
+          context_feature_columns=context_feature_columns,
+          expected_logits=[[-0.6033], [0.0197]])
+
+  def testMultiExamplesWithContext(self):
+    """Tests multiple examples with context features.
+
+    Intermediate values are rounded for ease in reading.
+    input_layer = [[[10, -0.5], [5, -0.5]], [[2, 0.8], [0, 0]]]
+    initial_state = [[0, 0], [0, 0]]
+    rnn_output_timestep_1 = [[tanh(.1*10 - 1*.5 + .2*0 + .3*0 +.2),
+                              tanh(-.2*10 - 0.9*.5 - .3*0 - .4*0 +.5)],
+                             [tanh(.1*2 + 1*.8 + .2*0 + .3*0 +.2),
+                              tanh(-.2*2 + .9*.8 - .3*0 - .4*0 +.5)]]
+                          = [[0.60, -0.96], [0.83, 0.68]]
+    rnn_output_timestep_2 = [[tanh(.1*5 - 1*.5 + .2*.60 - .3*.96 +.2),
+                              tanh(-.2*5 - .9*.5 - .3*.60 + .4*.96 +.5)],
+                             [<ignored-padding>]]
+                          = [[0.03, -0.63], [<ignored-padding>]]
+    logits = [[-1*0.03 - 1*0.63 + 0.3],
+              [-1*0.83 + 1*0.68 + 0.3]]
+           = [[-0.3662], [0.1414]]
+    """
+    base_global_step = 100
+    create_checkpoint(
+        # Context feature weights are inserted between input and state weights.
+        rnn_weights=[[.1, -.2], [1., 0.9], [.2, -.3], [.3, -.4]],
+        rnn_biases=[.2, .5],
+        logits_weights=[[-1.], [1.]],
+        logits_biases=[0.3],
+        global_step=base_global_step,
+        model_dir=self._model_dir)
+
+    def features_fn():
+      return {
+          'price':
+              sparse_tensor.SparseTensor(
+                  values=[10., 5., 2.],
+                  indices=[[0, 0], [0, 1], [1, 0]],
+                  dense_shape=[2, 2]),
+          'context': [[-0.5], [0.8]],
+      }
+
+    sequence_feature_columns = [
+        seq_fc.sequence_numeric_column('price', shape=(1,))]
+    context_feature_columns = [fc.numeric_column('context', shape=(1,))]
+
+    for mode in [
+        model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
+        model_fn.ModeKeys.PREDICT
+    ]:
+      self._test_logits(
+          mode,
+          rnn_units=[2],
+          logits_dimension=1,
+          features_fn=features_fn,
+          sequence_feature_columns=sequence_feature_columns,
+          context_feature_columns=context_feature_columns,
+          expected_logits=[[-0.3662], [0.1414]])
+
+  def testMultiExamplesMultiFeatures(self):
+    """Tests examples with multiple sequential feature columns.
+
+    Intermediate values are rounded for ease in reading.
+    input_layer = [[[1, 0, 10], [0, 1, 5]], [[1, 0, 2], [0, 0, 0]]]
+    initial_state = [[0, 0], [0, 0]]
+    rnn_output_timestep_1 = [[tanh(.5*1 + 1*0 + .1*10 + .2*0 + .3*0 +.2),
+                              tanh(-.5*1 - 1*0 - .2*10 - .3*0 - .4*0 +.5)],
+                             [tanh(.5*1 + 1*0 + .1*2 + .2*0 + .3*0 +.2),
+                              tanh(-.5*1 - 1*0 - .2*2 - .3*0 - .4*0 +.5)]]
+                          = [[0.94, -0.96], [0.72, -0.38]]
+    rnn_output_timestep_2 = [[tanh(.5*0 + 1*1 + .1*5 + .2*.94 - .3*.96 +.2),
+                              tanh(-.5*0 - 1*1 - .2*5 - .3*.94 + .4*.96 +.5)],
+                             [<ignored-padding>]]
+                          = [[0.92, -0.88], [<ignored-padding>]]
+    logits = [[-1*0.92 - 1*0.88 + 0.3],
+              [-1*0.72 - 1*0.38 + 0.3]]
+           = [[-1.5056], [-0.7962]]
+    """
+    base_global_step = 100
+    create_checkpoint(
+        # FeatureColumns are sorted alphabetically, so on_sale weights are
+        # inserted before price.
+        rnn_weights=[[.5, -.5], [1., -1.], [.1, -.2], [.2, -.3], [.3, -.4]],
+        rnn_biases=[.2, .5],
+        logits_weights=[[-1.], [1.]],
+        logits_biases=[0.3],
+        global_step=base_global_step,
+        model_dir=self._model_dir)
+
+    def features_fn():
+      return {
+          'price':
+              sparse_tensor.SparseTensor(
+                  values=[10., 5., 2.],
+                  indices=[[0, 0], [0, 1], [1, 0]],
+                  dense_shape=[2, 2]),
+          'on_sale':
+              sparse_tensor.SparseTensor(
+                  values=[0, 1, 0],
+                  indices=[[0, 0], [0, 1], [1, 0]],
+                  dense_shape=[2, 2]),
+      }
+
+    price_column = seq_fc.sequence_numeric_column('price', shape=(1,))
+    on_sale_column = fc.indicator_column(
+        seq_fc.sequence_categorical_column_with_identity(
+            'on_sale', num_buckets=2))
+    sequence_feature_columns = [price_column, on_sale_column]
+    context_feature_columns = []
+
+    for mode in [
+        model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
+        model_fn.ModeKeys.PREDICT
+    ]:
+      self._test_logits(
+          mode,
+          rnn_units=[2],
+          logits_dimension=1,
+          features_fn=features_fn,
+          sequence_feature_columns=sequence_feature_columns,
+          context_feature_columns=context_feature_columns,
+          expected_logits=[[-1.5056], [-0.7962]])
+
+
+class RNNClassifierTrainingTest(test.TestCase):
+
+  def setUp(self):
+    self._model_dir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    if self._model_dir:
+      writer_cache.FileWriterCache.clear()
+      shutil.rmtree(self._model_dir)
+
+  def _assert_checkpoint(
+      self, n_classes, input_units, cell_units, expected_global_step):
+
+    shapes = {
+        name: shape for (name, shape) in
+        checkpoint_utils.list_variables(self._model_dir)
+    }
+
+    self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
+    self.assertEqual(
+        expected_global_step,
+        checkpoint_utils.load_variable(
+            self._model_dir, ops.GraphKeys.GLOBAL_STEP))
+
+    # RNN Cell variables.
+    if len(cell_units) > 1:
+      for i, cell_unit in enumerate(cell_units):
+        self.assertEqual([input_units + cell_unit, cell_unit],
+                         shapes[MULTI_CELL_WEIGHTS_NAME_PATTERN % i])
+        self.assertEqual([cell_unit],
+                         shapes[MULTI_CELL_BIAS_NAME_PATTERN % i])
+        input_units = cell_unit
+    elif len(cell_units) == 1:
+      self.assertEqual([input_units + cell_units[0], cell_units[0]],
+                       shapes[CELL_WEIGHTS_NAME])
+      self.assertEqual([cell_units[0]], shapes[CELL_BIAS_NAME])
+
+    # Logits variables.
+    logits_dimension = n_classes if n_classes > 2 else 1
+    self.assertEqual([cell_units[-1], logits_dimension],
+                     shapes[LOGITS_WEIGHTS_NAME])
+    self.assertEqual([logits_dimension], shapes[LOGITS_BIAS_NAME])
+
+  def _mock_optimizer(self, expected_loss=None):
+    expected_var_names = [
+        '%s/part_0:0' % CELL_BIAS_NAME,
+        '%s/part_0:0' % CELL_WEIGHTS_NAME,
+        '%s/part_0:0' % LOGITS_BIAS_NAME,
+        '%s/part_0:0' % LOGITS_WEIGHTS_NAME,
+    ]
+
+    def _minimize(loss, global_step):
+      trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+      self.assertItemsEqual(
+          expected_var_names,
+          [var.name for var in trainable_vars])
+
+      # Verify loss. We can't check the value directly, so we add an assert op.
+      self.assertEqual(0, loss.shape.ndims)
+      if expected_loss is None:
+        return state_ops.assign_add(global_step, 1).op
+      assert_loss = _assert_close(
+          math_ops.to_float(expected_loss, name='expected'),
+          loss,
+          name='assert_loss')
+      with ops.control_dependencies((assert_loss,)):
+        return state_ops.assign_add(global_step, 1).op
+
+    mock_optimizer = test.mock.NonCallableMock(
+        spec=optimizer.Optimizer,
+        wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer'))
+    mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize)
+
+    # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks.
+    # So, return mock_optimizer itself for deepcopy.
+    mock_optimizer.__deepcopy__ = lambda _: mock_optimizer
+    return mock_optimizer
+
+  def testConflictingRNNCellFn(self):
+    col = seq_fc.sequence_categorical_column_with_hash_bucket(
+        'tokens', hash_bucket_size=10)
+    embed = fc.embedding_column(col, dimension=2)
+    cell_units = [4, 2]
+
+    with self.assertRaisesRegexp(
+        ValueError,
+        'num_units and cell_type must not be specified when using rnn_cell_fn'):
+      rnn.RNNClassifier(
+          sequence_feature_columns=[embed],
+          rnn_cell_fn=lambda x: x,
+          num_units=cell_units)
+
+    with self.assertRaisesRegexp(
+        ValueError,
+        'num_units and cell_type must not be specified when using rnn_cell_fn'):
+      rnn.RNNClassifier(
+          sequence_feature_columns=[embed],
+          rnn_cell_fn=lambda x: x,
+          cell_type='lstm')
+
+  def _testFromScratchWithDefaultOptimizer(self, n_classes):
+    def train_input_fn():
+      return {
+          'tokens':
+              sparse_tensor.SparseTensor(
+                  values=['the', 'cat', 'sat'],
+                  indices=[[0, 0], [0, 1], [0, 2]],
+                  dense_shape=[1, 3]),
+      }, [[1]]
+
+    col = seq_fc.sequence_categorical_column_with_hash_bucket(
+        'tokens', hash_bucket_size=10)
+    embed = fc.embedding_column(col, dimension=2)
+    input_units = 2
+
+    cell_units = [4, 2]
+    est = rnn.RNNClassifier(
+        sequence_feature_columns=[embed],
+        num_units=cell_units,
+        n_classes=n_classes,
+        model_dir=self._model_dir)
+
+    # Train for a few steps, and validate final checkpoint.
+    num_steps = 10
+    est.train(input_fn=train_input_fn, steps=num_steps)
+    self._assert_checkpoint(n_classes, input_units, cell_units, num_steps)
+
+  def testBinaryClassFromScratchWithDefaultOptimizer(self):
+    self._testFromScratchWithDefaultOptimizer(n_classes=2)
+
+  def testMultiClassFromScratchWithDefaultOptimizer(self):
+    self._testFromScratchWithDefaultOptimizer(n_classes=4)
+
+  def testFromScratchWithCustomRNNCellFn(self):
+    def train_input_fn():
+      return {
+          'tokens':
+              sparse_tensor.SparseTensor(
+                  values=['the', 'cat', 'sat'],
+                  indices=[[0, 0], [0, 1], [0, 2]],
+                  dense_shape=[1, 3]),
+      }, [[1]]
+
+    col = seq_fc.sequence_categorical_column_with_hash_bucket(
+        'tokens', hash_bucket_size=10)
+    embed = fc.embedding_column(col, dimension=2)
+    input_units = 2
+    cell_units = [4, 2]
+    n_classes = 2
+
+    def rnn_cell_fn(mode):
+      del mode  # unused
+      cells = [rnn_cell.BasicRNNCell(num_units=n) for n in cell_units]
+      return rnn_cell.MultiRNNCell(cells)
+
+    est = rnn.RNNClassifier(
+        sequence_feature_columns=[embed],
+        rnn_cell_fn=rnn_cell_fn,
+        n_classes=n_classes,
+        model_dir=self._model_dir)
+
+    # Train for a few steps, and validate final checkpoint.
+    num_steps = 10
+    est.train(input_fn=train_input_fn, steps=num_steps)
+    self._assert_checkpoint(n_classes, input_units, cell_units, num_steps)
+
+  def _testExampleWeight(self, n_classes):
+    def train_input_fn():
+      return {
+          'tokens':
+              sparse_tensor.SparseTensor(
+                  values=['the', 'cat', 'sat', 'dog', 'barked'],
+                  indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],
+                  dense_shape=[2, 3]),
+          'w': [[1], [2]],
+      }, [[1], [0]]
+
+    col = seq_fc.sequence_categorical_column_with_hash_bucket(
+        'tokens', hash_bucket_size=10)
+    embed = fc.embedding_column(col, dimension=2)
+    input_units = 2
+
+    cell_units = [4, 2]
+    est = rnn.RNNClassifier(
+        num_units=cell_units,
+        sequence_feature_columns=[embed],
+        n_classes=n_classes,
+        weight_column='w',
+        model_dir=self._model_dir)
+
+    # Train for a few steps, and validate final checkpoint.
+    num_steps = 10
+    est.train(input_fn=train_input_fn, steps=num_steps)
+    self._assert_checkpoint(n_classes, input_units, cell_units, num_steps)
+
+  def testBinaryClassWithExampleWeight(self):
+    self._testExampleWeight(n_classes=2)
+
+  def testMultiClassWithExampleWeight(self):
+    self._testExampleWeight(n_classes=4)
+
+  def testBinaryClassFromCheckpoint(self):
+    initial_global_step = 100
+    create_checkpoint(
+        rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+        rnn_biases=[.2, .5],
+        logits_weights=[[-1.], [1.]],
+        logits_biases=[0.3],
+        global_step=initial_global_step,
+        model_dir=self._model_dir)
+
+    def train_input_fn():
+      return {
+          'price':
+              sparse_tensor.SparseTensor(
+                  values=[10., 5., 2.],
+                  indices=[[0, 0], [0, 1], [1, 0]],
+                  dense_shape=[2, 2]),
+      }, [[0], [1]]
+
+    # Uses same checkpoint and examples as testBinaryClassEvaluationMetrics.
+    # See that test for loss calculation.
+    mock_optimizer = self._mock_optimizer(expected_loss=1.119661)
+
+    sequence_feature_columns = [
+        seq_fc.sequence_numeric_column('price', shape=(1,))]
+    est = rnn.RNNClassifier(
+        num_units=[2],
+        sequence_feature_columns=sequence_feature_columns,
+        n_classes=2,
+        optimizer=mock_optimizer,
+        model_dir=self._model_dir)
+    self.assertEqual(0, mock_optimizer.minimize.call_count)
+    est.train(input_fn=train_input_fn, steps=10)
+    self.assertEqual(1, mock_optimizer.minimize.call_count)
+
+  def testMultiClassFromCheckpoint(self):
+    initial_global_step = 100
+    create_checkpoint(
+        rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+        rnn_biases=[.2, .5],
+        logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]],
+        logits_biases=[0.3, 0.4, 0.5],
+        global_step=initial_global_step,
+        model_dir=self._model_dir)
+
+    def train_input_fn():
+      return {
+          'price':
+              sparse_tensor.SparseTensor(
+                  values=[10., 5., 2., 7.],
+                  indices=[[0, 0], [0, 1], [1, 0], [1, 1]],
+                  dense_shape=[2, 2]),
+      }, [[0], [1]]
+
+    # Uses same checkpoint and examples as testMultiClassEvaluationMetrics.
+    # See that test for loss calculation.
+    mock_optimizer = self._mock_optimizer(expected_loss=2.662932)
+
+    sequence_feature_columns = [
+        seq_fc.sequence_numeric_column('price', shape=(1,))]
+    est = rnn.RNNClassifier(
+        num_units=[2],
+        sequence_feature_columns=sequence_feature_columns,
+        n_classes=3,
+        optimizer=mock_optimizer,
+        model_dir=self._model_dir)
+    self.assertEqual(0, mock_optimizer.minimize.call_count)
+    est.train(input_fn=train_input_fn, steps=10)
+    self.assertEqual(1, mock_optimizer.minimize.call_count)
+
+
+def sorted_key_dict(unsorted_dict):
+  return {k: unsorted_dict[k] for k in sorted(unsorted_dict)}
+
+
+class RNNClassifierEvaluationTest(test.TestCase):
+
+  def setUp(self):
+    self._model_dir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    if self._model_dir:
+      writer_cache.FileWriterCache.clear()
+      shutil.rmtree(self._model_dir)
+
+  def testBinaryClassEvaluationMetrics(self):
+    global_step = 100
+    create_checkpoint(
+        rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+        rnn_biases=[.2, .5],
+        logits_weights=[[-1.], [1.]],
+        logits_biases=[0.3],
+        global_step=global_step,
+        model_dir=self._model_dir)
+
+    def eval_input_fn():
+      return {
+          'price':
+              sparse_tensor.SparseTensor(
+                  values=[10., 5., 2.],
+                  indices=[[0, 0], [0, 1], [1, 0]],
+                  dense_shape=[2, 2]),
+      }, [[0], [1]]
+
+    sequence_feature_columns = [
+        seq_fc.sequence_numeric_column('price', shape=(1,))]
+
+    est = rnn.RNNClassifier(
+        num_units=[2],
+        sequence_feature_columns=sequence_feature_columns,
+        n_classes=2,
+        model_dir=self._model_dir)
+    eval_metrics = est.evaluate(eval_input_fn, steps=1)
+
+    # Uses identical numbers to testMultiExamplesDifferentLength.
+    # See that test for logits calculation.
+    # logits = [[-0.603282], [0.019719]]
+    # probability = exp(logits) / (1 + exp(logits)) = [[0.353593], [0.504930]]
+    # loss = -label * ln(p) - (1 - label) * ln(1 - p)
+    #      = [[0.436326], [0.683335]]
+    expected_metrics = {
+        ops.GraphKeys.GLOBAL_STEP: global_step,
+        metric_keys.MetricKeys.LOSS: 1.119661,
+        metric_keys.MetricKeys.LOSS_MEAN: 0.559831,
+        metric_keys.MetricKeys.ACCURACY: 1.0,
+        metric_keys.MetricKeys.PREDICTION_MEAN: 0.429262,
+        metric_keys.MetricKeys.LABEL_MEAN: 0.5,
+        metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,
+        # With default threshold of 0.5, the model is a perfect classifier.
+        metric_keys.MetricKeys.RECALL: 1.0,
+        metric_keys.MetricKeys.PRECISION: 1.0,
+        # Positive example is scored above negative, so AUC = 1.0.
+        metric_keys.MetricKeys.AUC: 1.0,
+        metric_keys.MetricKeys.AUC_PR: 1.0,
+    }
+    self.assertAllClose(
+        sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics))
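
For reference, the expected values above can be reproduced with a few lines of NumPy (a sketch of the sigmoid-loss arithmetic from the comment; not part of the test):

```python
import numpy as np

logits = np.array([-0.603282, 0.019719])
labels = np.array([0., 1.])
p = 1. / (1. + np.exp(-logits))  # [0.353593, 0.504930]
loss = -labels * np.log(p) - (1. - labels) * np.log(1. - p)
print(loss.sum(), loss.mean(), p.mean())  # ~1.119661, ~0.559831, ~0.429262
```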
+
+  def testMultiClassEvaluationMetrics(self):
+    global_step = 100
+    create_checkpoint(
+        rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+        rnn_biases=[.2, .5],
+        logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]],
+        logits_biases=[0.3, 0.4, 0.5],
+        global_step=global_step,
+        model_dir=self._model_dir)
+
+    def eval_input_fn():
+      return {
+          'price':
+              sparse_tensor.SparseTensor(
+                  values=[10., 5., 2., 7.],
+                  indices=[[0, 0], [0, 1], [1, 0], [1, 1]],
+                  dense_shape=[2, 2]),
+      }, [[0], [1]]
+
+    sequence_feature_columns = [
+        seq_fc.sequence_numeric_column('price', shape=(1,))]
+
+    est = rnn.RNNClassifier(
+        num_units=[2],
+        sequence_feature_columns=sequence_feature_columns,
+        n_classes=3,
+        model_dir=self._model_dir)
+    eval_metrics = est.evaluate(eval_input_fn, steps=1)
+
+    # Uses identical numbers to testMultiExampleMultiDim.
+    # See that test for logits calculation.
+    # logits = [[-0.603282, 0.777708, 0.569756],
+    #           [-1.247356, 1.017018, 0.574481]]
+    # logits_exp = exp(logits)
+    #            = [[0.547013, 2.176468, 1.767836],
+    #               [0.287263, 2.764937, 1.776208]]
+    # softmax_probabilities = logits_exp / logits_exp.sum()
+    #                       = [[0.121793, 0.484596, 0.393611],
+    #                          [0.059494, 0.572639, 0.367866]]
+    # loss = -1. * log(softmax[label])
+    #      = [[2.105432], [0.557500]]
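+    # LOSS = sum(loss) = 2.105432 + 0.557500 = 2.662932
+    # LOSS_MEAN = LOSS / 2 = 1.331466
+    # argmax(softmax) = [1, 1] vs. labels [0, 1], so ACCURACY = 0.5.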
+    expected_metrics = {
+        ops.GraphKeys.GLOBAL_STEP: global_step,
+        metric_keys.MetricKeys.LOSS: 2.662932,
+        metric_keys.MetricKeys.LOSS_MEAN: 1.331466,
+        metric_keys.MetricKeys.ACCURACY: 0.5,
+    }
+
+    self.assertAllClose(
+        sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics))
+
+
+class RNNClassifierPredictionTest(test.TestCase):
+
+  def setUp(self):
+    self._model_dir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    if self._model_dir:
+      writer_cache.FileWriterCache.clear()
+      shutil.rmtree(self._model_dir)
+
+  def testBinaryClassPredictions(self):
+    create_checkpoint(
+        rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+        rnn_biases=[.2, .5],
+        logits_weights=[[-1.], [1.]],
+        logits_biases=[0.3],
+        global_step=0,
+        model_dir=self._model_dir)
+
+    def predict_input_fn():
+      return {
+          'price':
+              sparse_tensor.SparseTensor(
+                  values=[10., 5.],
+                  indices=[[0, 0], [0, 1]],
+                  dense_shape=[1, 2]),
+      }
+
+    sequence_feature_columns = [
+        seq_fc.sequence_numeric_column('price', shape=(1,))]
+    label_vocabulary = ['class_0', 'class_1']
+
+    est = rnn.RNNClassifier(
+        num_units=[2],
+        sequence_feature_columns=sequence_feature_columns,
+        n_classes=2,
+        label_vocabulary=label_vocabulary,
+        model_dir=self._model_dir)
+    # Uses identical numbers to testOneDimLogits.
+    # See that test for logits calculation.
+    # logits = [-0.603282]
+    # logistic = exp(-0.6033) / (1 + exp(-0.6033)) = [0.353593]
+    # probabilities = [0.646407, 0.353593]
+    # class_ids = argmax(probabilities) = [0]
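+    # classes = label_vocabulary[class_ids] = ['class_0']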
+    predictions = next(est.predict(predict_input_fn))
+    self.assertAllClose([-0.603282],
+                        predictions[prediction_keys.PredictionKeys.LOGITS])
+    self.assertAllClose([0.353593],
+                        predictions[prediction_keys.PredictionKeys.LOGISTIC])
+    self.assertAllClose(
+        [0.646407, 0.353593],
+        predictions[prediction_keys.PredictionKeys.PROBABILITIES])
+    self.assertAllClose([0],
+                        predictions[prediction_keys.PredictionKeys.CLASS_IDS])
+    self.assertEqual([b'class_0'],
+                     predictions[prediction_keys.PredictionKeys.CLASSES])
+
+  def testMultiClassPredictions(self):
+    create_checkpoint(
+        rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+        rnn_biases=[.2, .5],
+        logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]],
+        logits_biases=[0.3, 0.4, 0.5],
+        global_step=0,
+        model_dir=self._model_dir)
+
+    def predict_input_fn():
+      return {
+          'price':
+              sparse_tensor.SparseTensor(
+                  values=[10., 5.],
+                  indices=[[0, 0], [0, 1]],
+                  dense_shape=[1, 2]),
+      }
+
+    sequence_feature_columns = [
+        seq_fc.sequence_numeric_column('price', shape=(1,))]
+    label_vocabulary = ['class_0', 'class_1', 'class_2']
+
+    est = rnn.RNNClassifier(
+        num_units=[2],
+        sequence_feature_columns=sequence_feature_columns,
+        n_classes=3,
+        label_vocabulary=label_vocabulary,
+        model_dir=self._model_dir)
+    # Uses identical numbers to testMultiDimLogits.
+    # See that test for logits calculation.
+    # logits = [-0.603282, 0.777708, 0.569756]
+    # logits_exp = exp(logits) = [0.547013, 2.176468, 1.767836]
+    # softmax_probabilities = logits_exp / logits_exp.sum()
+    #                       = [0.121793, 0.484596, 0.393611]
+    # class_ids = argmax(probabilities) = [1]
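+    # classes = label_vocabulary[class_ids] = ['class_1']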
+    predictions = next(est.predict(predict_input_fn))
+    self.assertAllClose([-0.603282, 0.777708, 0.569756],
+                        predictions[prediction_keys.PredictionKeys.LOGITS])
+    self.assertAllClose(
+        [0.121793, 0.484596, 0.393611],
+        predictions[prediction_keys.PredictionKeys.PROBABILITIES])
+    self.assertAllClose([1],
+                        predictions[prediction_keys.PredictionKeys.CLASS_IDS])
+    self.assertEqual([b'class_1'],
+                     predictions[prediction_keys.PredictionKeys.CLASSES])
+
+
+class RNNClassifierIntegrationTest(test.TestCase):
+
+  def setUp(self):
+    self._model_dir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    if self._model_dir:
+      writer_cache.FileWriterCache.clear()
+      shutil.rmtree(self._model_dir)
+
+  def _test_complete_flow(
+      self, train_input_fn, eval_input_fn, predict_input_fn, n_classes,
+      batch_size):
+    col = seq_fc.sequence_categorical_column_with_hash_bucket(
+        'tokens', hash_bucket_size=10)
+    embed = fc.embedding_column(col, dimension=2)
+    feature_columns = [embed]
+
+    cell_units = [4, 2]
+    est = rnn.RNNClassifier(
+        num_units=cell_units,
+        sequence_feature_columns=feature_columns,
+        n_classes=n_classes,
+        model_dir=self._model_dir)
+
+    # TRAIN
+    num_steps = 10
+    est.train(train_input_fn, steps=num_steps)
+
+    # EVALUATE
+    scores = est.evaluate(eval_input_fn)
+    self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
+    self.assertIn('loss', six.iterkeys(scores))
+
+    # PREDICT
+    predicted_proba = np.array([
+        x[prediction_keys.PredictionKeys.PROBABILITIES]
+        for x in est.predict(predict_input_fn)
+    ])
+    self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
+
+    # EXPORT
+    feature_spec = {
+        'tokens': parsing_ops.VarLenFeature(dtypes.string),
+        'label': parsing_ops.FixedLenFeature([1], dtypes.int64),
+    }
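+    # build_parsing_serving_input_receiver_fn returns a fn that accepts
+    # serialized tf.Example protos and parses them with this feature_spec.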
+    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+        feature_spec)
+    export_dir = est.export_savedmodel(tempfile.mkdtemp(),
+                                       serving_input_receiver_fn)
+    self.assertTrue(gfile.Exists(export_dir))
+
+  def testNumpyInputFn(self):
+    """Tests complete flow with numpy_input_fn."""
+    n_classes = 3
+    batch_size = 10
+    words = ['dog', 'cat', 'bird', 'the', 'a', 'sat', 'flew', 'slept']
+    # Numpy only supports dense input, so all examples will have the same
+    # length.
+    # TODO(b/73160931): Update test when support for prepadded data exists.
+    sequence_length = 3
+
+    features = []
+    for _ in range(batch_size):
+      sentence = random.sample(words, sequence_length)
+      features.append(sentence)
+
+    x_data = np.array(features)
+    y_data = np.random.randint(n_classes, size=batch_size)
+
+    train_input_fn = numpy_io.numpy_input_fn(
+        x={'tokens': x_data},
+        y=y_data,
+        batch_size=batch_size,
+        num_epochs=None,
+        shuffle=True)
+    eval_input_fn = numpy_io.numpy_input_fn(
+        x={'tokens': x_data},
+        y=y_data,
+        batch_size=batch_size,
+        shuffle=False)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x={'tokens': x_data},
+        batch_size=batch_size,
+        shuffle=False)
+
+    self._test_complete_flow(
+        train_input_fn=train_input_fn,
+        eval_input_fn=eval_input_fn,
+        predict_input_fn=predict_input_fn,
+        n_classes=n_classes,
+        batch_size=batch_size)
+
+  def testParseExampleInputFn(self):
+    """Tests complete flow with input_fn constructed from parse_example."""
+    n_classes = 3
+    batch_size = 10
+    words = [b'dog', b'cat', b'bird', b'the', b'a', b'sat', b'flew', b'slept']
+
+    serialized_examples = []
+    for _ in range(batch_size):
+      sequence_length = random.randint(1, len(words))
+      sentence = random.sample(words, sequence_length)
+      label = random.randint(0, n_classes - 1)
+      example = example_pb2.Example(features=feature_pb2.Features(
+          feature={
+              'tokens':
+                  feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+                      value=sentence)),
+              'label':
+                  feature_pb2.Feature(int64_list=feature_pb2.Int64List(
+                      value=[label])),
+          }))
+      serialized_examples.append(example.SerializeToString())
+
+    feature_spec = {
+        'tokens': parsing_ops.VarLenFeature(dtypes.string),
+        'label': parsing_ops.FixedLenFeature([1], dtypes.int64),
+    }
+    def _train_input_fn():
+      features = parsing_ops.parse_example(serialized_examples, feature_spec)
+      labels = features.pop('label')
+      return features, labels
+    def _eval_input_fn():
+      features = parsing_ops.parse_example(
+          input_lib.limit_epochs(serialized_examples, num_epochs=1),
+          feature_spec)
+      labels = features.pop('label')
+      return features, labels
+    def _predict_input_fn():
+      features = parsing_ops.parse_example(
+          input_lib.limit_epochs(serialized_examples, num_epochs=1),
+          feature_spec)
+      features.pop('label')
+      return features, None
+
+    self._test_complete_flow(
+        train_input_fn=_train_input_fn,
+        eval_input_fn=_eval_input_fn,
+        predict_input_fn=_predict_input_fn,
+        n_classes=n_classes,
+        batch_size=batch_size)
+
+
+if __name__ == '__main__':
+  test.main()
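
The tests above also double as usage documentation for the new estimator. Below is a minimal end-to-end sketch assembled from the test code (the seq_fc import path is an assumption about where the contrib sequence feature columns live; the feature column and estimator arguments mirror the tests):

```python
import tensorflow as tf
from tensorflow.contrib.feature_column.python.feature_column import (
    sequence_feature_column as seq_fc)

def input_fn():
  # Two variable-length 'price' sequences padded into a [2, 2] sparse batch,
  # with binary labels, mirroring the eval tests above.
  return {
      'price': tf.SparseTensor(values=[10., 5., 2.],
                               indices=[[0, 0], [0, 1], [1, 0]],
                               dense_shape=[2, 2]),
  }, [[0], [1]]

est = tf.contrib.estimator.RNNClassifier(
    num_units=[2],  # a single recurrent layer with 2 units
    sequence_feature_columns=[
        seq_fc.sequence_numeric_column('price', shape=(1,))],
    n_classes=2)
est.train(input_fn, steps=10)
metrics = est.evaluate(input_fn, steps=1)
```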
index fe380c4..cbc2dcf 100644 (file)
@@ -1206,7 +1206,22 @@ class DeviceWrapper(RNNCell):
 
 @tf_export("nn.rnn_cell.MultiRNNCell")
 class MultiRNNCell(RNNCell):
-  """RNN cell composed sequentially of multiple simple cells."""
+  """RNN cell composed sequentially of multiple simple cells.
+
+  Example:
+
+  ```python
+  num_units = [128, 64]
+  cells = [BasicLSTMCell(num_units=n) for n in num_units]
+  stacked_rnn_cell = MultiRNNCell(cells)
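+
+  # The stacked cell behaves like a single RNNCell and can be passed to,
+  # e.g., tf.nn.dynamic_rnn (here `inputs` is assumed to be a
+  # [batch_size, max_time, depth] tensor):
+  outputs, state = tf.nn.dynamic_rnn(stacked_rnn_cell, inputs,
+                                     dtype=tf.float32)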
+  ```
+  """
 
   def __init__(self, cells, state_is_tuple=True):
     """Create a RNN cell composed sequentially of a number of RNNCells.