From cb43bb37bfd5468efd92b03848edf6f3f06bfd5b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 6 Apr 2018 14:32:28 -0700 Subject: [PATCH] Add RNNClassifier PiperOrigin-RevId: 191941174 --- tensorflow/contrib/estimator/BUILD | 55 + tensorflow/contrib/estimator/__init__.py | 1 + .../contrib/estimator/python/estimator/rnn.py | 481 +++++++++ .../contrib/estimator/python/estimator/rnn_test.py | 1131 ++++++++++++++++++++ tensorflow/python/ops/rnn_cell_impl.py | 11 +- 5 files changed, 1678 insertions(+), 1 deletion(-) create mode 100644 tensorflow/contrib/estimator/python/estimator/rnn.py create mode 100644 tensorflow/contrib/estimator/python/estimator/rnn_test.py diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index bec0329..9f4cd44 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -23,6 +23,7 @@ py_library( ":logit_fns", ":multi_head", ":replicate_model_fn", + ":rnn", "//tensorflow/python:util", ], ) @@ -412,3 +413,57 @@ cuda_py_test( "notap", ], ) + +py_library( + name = "rnn", + srcs = ["python/estimator/rnn.py"], + srcs_version = "PY2AND3", + deps = [ + ":extenders", + "//tensorflow/contrib/feature_column:feature_column_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", + "//tensorflow/python:partitioned_variables", + "//tensorflow/python:rnn", + "//tensorflow/python:rnn_cell", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:head", + "//tensorflow/python/estimator:optimizers", + "//tensorflow/python/feature_column", + "@six_archive//:six", + ], +) + +py_test( + name = "rnn_test", + size = "medium", + srcs = ["python/estimator/rnn_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + ":rnn", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/feature_column", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index d2fc2c4..9a87fa9 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -52,6 +52,7 @@ _allowed_symbols = [ 'linear_logit_fn_builder', 'replicate_model_fn', 'TowerOptimizer', + 'RNNClassifier', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py new file mode 100644 index 0000000..b475c12 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/rnn.py @@ -0,0 +1,481 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Recurrent Neural Network estimators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.estimator.python.estimator import extenders +from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as seq_fc +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.canned import optimizers +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.layers import core as core_layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import variable_scope +from tensorflow.python.summary import summary +from tensorflow.python.training import optimizer as optimizer_lib +from tensorflow.python.training import training_util + + +# The defaults are historical artifacts of the initial implementation, but seem +# reasonable choices. +_DEFAULT_LEARNING_RATE = 0.05 +_DEFAULT_CLIP_NORM = 5.0 + +_CELL_TYPES = {'basic_rnn': rnn_cell.BasicRNNCell, + 'lstm': rnn_cell.BasicLSTMCell, + 'gru': rnn_cell.GRUCell} + +# Indicates no value was provided by the user to a kwarg. +USE_DEFAULT = object() + + +def _single_rnn_cell(num_units, cell_type): + cell_type = _CELL_TYPES.get(cell_type, cell_type) + if not cell_type or not issubclass(cell_type, rnn_cell.RNNCell): + raise ValueError('Supported cell types are {}; got {}'.format( + list(_CELL_TYPES.keys()), cell_type)) + return cell_type(num_units=num_units) + + +def _make_rnn_cell_fn(num_units, cell_type='basic_rnn'): + """Convenience function to create `rnn_cell_fn` for canned RNN Estimators. + + Args: + num_units: Iterable of integer number of hidden units per RNN layer. + cell_type: A subclass of `tf.nn.rnn_cell.RNNCell` or a string specifying + the cell type. Supported strings are: `'basic_rnn'`, `'lstm'`, and + `'gru'`. + + Returns: + A function that takes a single argument, an instance of + `tf.estimator.ModeKeys`, and returns an instance derived from + `tf.nn.rnn_cell.RNNCell`. + + Raises: + ValueError: If cell_type is not supported. + """ + def rnn_cell_fn(mode): + # Unused. Part of the rnn_cell_fn interface since user specified functions + # may need different behavior across modes (e.g. dropout). + del mode + cells = [_single_rnn_cell(n, cell_type) for n in num_units] + if len(cells) == 1: + return cells[0] + return rnn_cell.MultiRNNCell(cells) + return rnn_cell_fn + + +def _concatenate_context_input(sequence_input, context_input): + """Replicates `context_input` across all timesteps of `sequence_input`. + + Expands dimension 1 of `context_input` then tiles it `sequence_length` times. + This value is appended to `sequence_input` on dimension 2 and the result is + returned. + + Args: + sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size, + padded_length, d0]`. + context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`. + + Returns: + A `Tensor` of dtype `float32` and shape `[batch_size, padded_length, + d0 + d1]`. + + Raises: + ValueError: If `sequence_input` does not have rank 3 or `context_input` does + not have rank 2. + """ + seq_rank_check = check_ops.assert_rank( + sequence_input, + 3, + message='sequence_input must have rank 3', + data=[array_ops.shape(sequence_input)]) + seq_type_check = check_ops.assert_type( + sequence_input, + dtypes.float32, + message='sequence_input must have dtype float32; got {}.'.format( + sequence_input.dtype)) + ctx_rank_check = check_ops.assert_rank( + context_input, + 2, + message='context_input must have rank 2', + data=[array_ops.shape(context_input)]) + ctx_type_check = check_ops.assert_type( + context_input, + dtypes.float32, + message='context_input must have dtype float32; got {}.'.format( + context_input.dtype)) + with ops.control_dependencies( + [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]): + padded_length = array_ops.shape(sequence_input)[1] + tiled_context_input = array_ops.tile( + array_ops.expand_dims(context_input, 1), + array_ops.concat([[1], [padded_length], [1]], 0)) + return array_ops.concat([sequence_input, tiled_context_input], 2) + + +def _select_last_activations(activations, sequence_lengths): + """Selects the nth set of activations for each n in `sequence_length`. + + Returns a `Tensor` of shape `[batch_size, k]`. If `sequence_length` is not + `None`, then `output[i, :] = activations[i, sequence_length[i] - 1, :]`. If + `sequence_length` is `None`, then `output[i, :] = activations[i, -1, :]`. + + Args: + activations: A `Tensor` with shape `[batch_size, padded_length, k]`. + sequence_lengths: A `Tensor` with shape `[batch_size]` or `None`. + Returns: + A `Tensor` of shape `[batch_size, k]`. + """ + with ops.name_scope( + 'select_last_activations', values=[activations, sequence_lengths]): + activations_shape = array_ops.shape(activations) + batch_size = activations_shape[0] + padded_length = activations_shape[1] + output_units = activations_shape[2] + if sequence_lengths is None: + sequence_lengths = padded_length + start_indices = math_ops.to_int64( + math_ops.range(batch_size) * padded_length) + last_indices = start_indices + sequence_lengths - 1 + reshaped_activations = array_ops.reshape( + activations, [batch_size * padded_length, output_units]) + + last_activations = array_ops.gather(reshaped_activations, last_indices) + last_activations.set_shape([activations.shape[0], activations.shape[2]]) + return last_activations + + +def _rnn_logit_fn_builder(output_units, rnn_cell_fn, sequence_feature_columns, + context_feature_columns, input_layer_partitioner): + """Function builder for a rnn logit_fn. + + Args: + output_units: An int indicating the dimension of the logit layer. + rnn_cell_fn: A function with one argument, a `tf.estimator.ModeKeys`, and + returns an object of type `tf.nn.rnn_cell.RNNCell`. + sequence_feature_columns: An iterable containing the `FeatureColumn`s + that represent sequential input. + context_feature_columns: An iterable containing the `FeatureColumn`s + that represent contextual input. + input_layer_partitioner: Partitioner for input layer. + + Returns: + A logit_fn (see below). + + Raises: + ValueError: If output_units is not an int. + """ + if not isinstance(output_units, int): + raise ValueError('output_units must be an int. Given type: {}'.format( + type(output_units))) + + def rnn_logit_fn(features, mode): + """Recurrent Neural Network logit_fn. + + Args: + features: This is the first item returned from the `input_fn` + passed to `train`, `evaluate`, and `predict`. This should be a + single `Tensor` or `dict` of same. + mode: Optional. Specifies if this training, evaluation or prediction. See + `ModeKeys`. + + Returns: + A `Tensor` representing the logits. + """ + with variable_scope.variable_scope( + 'sequence_input_layer', + values=tuple(six.itervalues(features)), + partitioner=input_layer_partitioner): + sequence_input, sequence_length = seq_fc.sequence_input_layer( + features=features, feature_columns=sequence_feature_columns) + summary.histogram('sequence_length', sequence_length) + + if context_feature_columns: + context_input = feature_column_lib.input_layer( + features=features, + feature_columns=context_feature_columns) + sequence_input = _concatenate_context_input(sequence_input, + context_input) + + cell = rnn_cell_fn(mode) + # Ignore output state. + rnn_outputs, _ = rnn.dynamic_rnn( + cell=cell, + inputs=sequence_input, + dtype=dtypes.float32, + time_major=False) + last_activations = _select_last_activations(rnn_outputs, sequence_length) + + with variable_scope.variable_scope('logits', values=(rnn_outputs,)): + logits = core_layers.dense( + last_activations, + units=output_units, + activation=None, + kernel_initializer=init_ops.glorot_uniform_initializer()) + return logits + + return rnn_logit_fn + + +def _rnn_model_fn(features, + labels, + mode, + head, + rnn_cell_fn, + sequence_feature_columns, + context_feature_columns, + optimizer='Adagrad', + input_layer_partitioner=None, + config=None): + """Recurrent Neural Net model_fn. + + Args: + features: dict of `Tensor` and `SparseTensor` objects returned from + `input_fn`. + labels: `Tensor` of shape [batch_size, 1] or [batch_size] with labels. + mode: Defines whether this is training, evaluation or prediction. + See `ModeKeys`. + head: A `head_lib._Head` instance. + rnn_cell_fn: A function with one argument, a `tf.estimator.ModeKeys`, and + returns an object of type `tf.nn.rnn_cell.RNNCell`. + sequence_feature_columns: Iterable containing `FeatureColumn`s that + represent sequential model inputs. + context_feature_columns: Iterable containing `FeatureColumn`s that + represent model inputs not associated with a specific timestep. + optimizer: String, `tf.Optimizer` object, or callable that creates the + optimizer to use for training. If not specified, will use the Adagrad + optimizer with a default learning rate of 0.05 and gradient clip norm of + 5.0. + input_layer_partitioner: Partitioner for input layer. Defaults + to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. + config: `RunConfig` object to configure the runtime settings. + + Returns: + An `EstimatorSpec` instance. + + Raises: + ValueError: If mode or optimizer is invalid, or features has the wrong type. + """ + if not isinstance(features, dict): + raise ValueError('features should be a dictionary of `Tensor`s. ' + 'Given type: {}'.format(type(features))) + + # If user does not provide an optimizer instance, use the optimizer specified + # by the string with default learning rate and gradient clipping. + if not isinstance(optimizer, optimizer_lib.Optimizer): + optimizer = optimizers.get_optimizer_instance( + optimizer, learning_rate=_DEFAULT_LEARNING_RATE) + optimizer = extenders.clip_gradients_by_norm(optimizer, _DEFAULT_CLIP_NORM) + + num_ps_replicas = config.num_ps_replicas if config else 0 + partitioner = partitioned_variables.min_max_variable_partitioner( + max_partitions=num_ps_replicas) + with variable_scope.variable_scope( + 'rnn', + values=tuple(six.itervalues(features)), + partitioner=partitioner): + input_layer_partitioner = input_layer_partitioner or ( + partitioned_variables.min_max_variable_partitioner( + max_partitions=num_ps_replicas, + min_slice_size=64 << 20)) + + logit_fn = _rnn_logit_fn_builder( + output_units=head.logits_dimension, + rnn_cell_fn=rnn_cell_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + input_layer_partitioner=input_layer_partitioner) + logits = logit_fn(features=features, mode=mode) + + def _train_op_fn(loss): + """Returns the op to optimize the loss.""" + return optimizer.minimize( + loss, + global_step=training_util.get_global_step()) + + return head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + + +class RNNClassifier(estimator.Estimator): + """A classifier for TensorFlow RNN models. + + Trains a recurrent neural network model to classify instances into one of + multiple classes. + + Example: + + ```python + token_sequence = sequence_categorical_column_with_hash_bucket(...) + token_emb = embedding_column(categorical_column=token_sequence, ...) + + estimator = RNNClassifier( + num_units=[32, 16], cell_type='lstm', + sequence_feature_columns=[token_emb]) + + # Input builders + def input_fn_train: # returns x, y + pass + estimator.train(input_fn=input_fn_train, steps=100) + + def input_fn_eval: # returns x, y + pass + metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10) + def input_fn_predict: # returns x, None + pass + predictions = estimator.predict(input_fn=input_fn_predict) + ``` + + Input of `train` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * if `weight_column` is not `None`, a feature with + `key=weight_column` whose value is a `Tensor`. + * for each `column` in `sequence_feature_columns`: + - a feature with `key=column.name` whose `value` is a `SparseTensor`. + * for each `column` in `context_feature_columns`: + - if `column` is a `_CategoricalColumn`, a feature with `key=column.name` + whose `value` is a `SparseTensor`. + - if `column` is a `_WeightedCategoricalColumn`, two features: the first + with `key` the id column name, the second with `key` the weight column + name. Both features' `value` must be a `SparseTensor`. + - if `column` is a `_DenseColumn`, a feature with `key=column.name` + whose `value` is a `Tensor`. + + Loss is calculated by using softmax cross entropy. + + @compatibility(eager) + Estimators are not compatible with eager execution. + @end_compatibility + """ + + def __init__(self, + sequence_feature_columns, + context_feature_columns=None, + num_units=None, + cell_type=USE_DEFAULT, + rnn_cell_fn=None, + model_dir=None, + n_classes=2, + weight_column=None, + label_vocabulary=None, + optimizer='Adagrad', + input_layer_partitioner=None, + config=None): + """Initializes a `RNNClassifier` instance. + + Args: + sequence_feature_columns: An iterable containing the `FeatureColumn`s + that represent sequential input. All items in the set should either be + sequence columns (e.g. `sequence_numeric_column`) or constructed from + one (e.g. `embedding_column` with `sequence_categorical_column_*` as + input). + context_feature_columns: An iterable containing the `FeatureColumn`s + for contextual input. The data represented by these columns will be + replicated and given to the RNN at each timestep. These columns must be + instances of classes derived from `_DenseColumn` such as + `numeric_column`, not the sequential variants. + num_units: Iterable of integer number of hidden units per RNN layer. If + set, `cell_type` must also be specified and `rnn_cell_fn` must be + `None`. + cell_type: A subclass of `tf.nn.rnn_cell.RNNCell` or a string specifying + the cell type. Supported strings are: `'basic_rnn'`, `'lstm'`, and + `'gru'`. If set, `num_units` must also be specified and `rnn_cell_fn` + must be `None`. + rnn_cell_fn: A function with one argument, a `tf.estimator.ModeKeys`, and + returns an object of type `tf.nn.rnn_cell.RNNCell` that will be used to + construct the RNN. If set, `num_units` and `cell_type` cannot be set. + This is for advanced users who need additional customization beyond + `num_units` and `cell_type`. Note that `tf.nn.rnn_cell.MultiRNNCell` is + needed for stacked RNNs. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator to + continue training a previously saved model. + n_classes: Number of label classes. Defaults to 2, namely binary + classification. Must be > 1. + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to down weight or boost examples during training. It + will be multiplied by the loss of the example. If it is a string, it is + used as a key to fetch weight tensor from the `features`. If it is a + `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, + then weight_column.normalizer_fn is applied on it to get weight tensor. + label_vocabulary: A list of strings represents possible label values. If + given, labels must be string type and have any value in + `label_vocabulary`. If it is not given, that means labels are + already encoded as integer or float within [0, 1] for `n_classes=2` and + encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . + Also there will be errors if vocabulary is not provided and labels are + string. + optimizer: An instance of `tf.Optimizer` used to train the model. Defaults + to Adagrad optimizer. + input_layer_partitioner: Optional. Partitioner for input layer. Defaults + to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. + config: `RunConfig` object to configure the runtime settings. + + Raises: + ValueError: If `num_units`, `cell_type`, and `rnn_cell_fn` are not + compatible. + """ + if rnn_cell_fn and (num_units or cell_type != USE_DEFAULT): + raise ValueError( + 'num_units and cell_type must not be specified when using rnn_cell_fn' + ) + if not rnn_cell_fn: + if cell_type == USE_DEFAULT: + cell_type = 'basic_rnn' + rnn_cell_fn = _make_rnn_cell_fn(num_units, cell_type) + + if n_classes == 2: + head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access + weight_column=weight_column, + label_vocabulary=label_vocabulary) + else: + head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access + n_classes, weight_column=weight_column, + label_vocabulary=label_vocabulary) + def _model_fn(features, labels, mode, config): + return _rnn_model_fn( + features=features, + labels=labels, + mode=mode, + head=head, + rnn_cell_fn=rnn_cell_fn, + sequence_feature_columns=tuple(sequence_feature_columns or []), + context_feature_columns=tuple(context_feature_columns or []), + optimizer=optimizer, + input_layer_partitioner=input_layer_partitioner, + config=config) + super(RNNClassifier, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py new file mode 100644 index 0000000..393f94f --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py @@ -0,0 +1,1131 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for rnn.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import shutil +import tempfile + +import numpy as np +import six + +from tensorflow.contrib.estimator.python.estimator import rnn +from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as seq_fc +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.estimator import model_fn +from tensorflow.python.estimator.canned import metric_keys +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import input as input_lib +from tensorflow.python.training import monitored_session +from tensorflow.python.training import optimizer +from tensorflow.python.training import training_util + + +# Names of variables created by BasicRNNCell model. +TOKEN_EMBEDDING_NAME = 'rnn/sequence_input_layer/input_layer/tokens_sequential_embedding/embedding_weights' +CELL_WEIGHTS_NAME = 'rnn/rnn/basic_rnn_cell/kernel' +CELL_BIAS_NAME = 'rnn/rnn/basic_rnn_cell/bias' +MULTI_CELL_WEIGHTS_NAME_PATTERN = 'rnn/rnn/multi_rnn_cell/cell_%d/basic_rnn_cell/kernel' +MULTI_CELL_BIAS_NAME_PATTERN = 'rnn/rnn/multi_rnn_cell/cell_%d/basic_rnn_cell/bias' +LOGITS_WEIGHTS_NAME = 'rnn/logits/dense/kernel' +LOGITS_BIAS_NAME = 'rnn/logits/dense/bias' + + +def _assert_close(expected, actual, rtol=1e-04, name='assert_close'): + with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope: + expected = ops.convert_to_tensor(expected, name='expected') + actual = ops.convert_to_tensor(actual, name='actual') + rdiff = math_ops.abs(expected - actual, 'diff') / math_ops.abs(expected) + rtol = ops.convert_to_tensor(rtol, name='rtol') + return check_ops.assert_less( + rdiff, + rtol, + data=('Condition expected =~ actual did not hold element-wise:' + 'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff, + 'rtol = ', rtol,), + name=scope) + + +def create_checkpoint(rnn_weights, rnn_biases, logits_weights, logits_biases, + global_step, model_dir): + """Create checkpoint file with provided model weights. + + Args: + rnn_weights: Iterable of values of weights for the RNN cell. + rnn_biases: Iterable of values of biases for the RNN cell. + logits_weights: Iterable of values for matrix connecting RNN output to + logits. + logits_biases: Iterable of values for logits bias term. + global_step: Initial global step to save in checkpoint. + model_dir: Directory into which checkpoint is saved. + """ + model_weights = {} + model_weights[CELL_WEIGHTS_NAME] = rnn_weights + model_weights[CELL_BIAS_NAME] = rnn_biases + model_weights[LOGITS_WEIGHTS_NAME] = logits_weights + model_weights[LOGITS_BIAS_NAME] = logits_biases + + with ops.Graph().as_default(): + # Create model variables. + for k, v in six.iteritems(model_weights): + variables_lib.Variable(v, name=k, dtype=dtypes.float32) + + # Create non-model variables. + global_step_var = training_util.create_global_step() + assign_op = global_step_var.assign(global_step) + + # Initialize vars and save checkpoint. + with monitored_session.MonitoredTrainingSession( + checkpoint_dir=model_dir) as sess: + sess.run(assign_op) + + +class RNNLogitFnTest(test.TestCase): + """Tests correctness of logits calculated from _rnn_logit_fn_builder.""" + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_logits(self, mode, rnn_units, logits_dimension, features_fn, + sequence_feature_columns, context_feature_columns, + expected_logits): + """Tests that the expected logits are calculated.""" + with ops.Graph().as_default(): + # Global step needed for MonitoredSession, which is in turn used to + # explicitly set variable weights through a checkpoint. + training_util.create_global_step() + # Use a variable scope here with 'rnn', emulating the rnn model_fn, so + # the checkpoint naming is shared. + with variable_scope.variable_scope('rnn'): + input_layer_partitioner = ( + partitioned_variables.min_max_variable_partitioner( + max_partitions=0, min_slice_size=64 << 20)) + logit_fn = rnn._rnn_logit_fn_builder( + output_units=logits_dimension, + rnn_cell_fn=rnn._make_rnn_cell_fn(rnn_units), + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + input_layer_partitioner=input_layer_partitioner) + # Features are constructed within this function, otherwise the Tensors + # containing the features would be defined outside this graph. + logits = logit_fn(features=features_fn(), mode=mode) + with monitored_session.MonitoredTrainingSession( + checkpoint_dir=self._model_dir) as sess: + self.assertAllClose(expected_logits, sess.run(logits), atol=1e-4) + + def testOneDimLogits(self): + """Tests one-dimensional logits. + + Intermediate values are rounded for ease in reading. + input_layer = [[[10]], [[5]]] + initial_state = [0, 0] + rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2), + tanh(-.2*10 - .3*0 - .4*0 +.5)]] + = [[0.83, -0.91]] + rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2), + tanh(-.2*5 - .3*.83 + .4*.91 +.5)]] + = [[0.53, -0.37]] + logits = [[-1*0.53 - 1*0.37 + 0.3]] = [[-0.6033]] + """ + base_global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=base_global_step, + model_dir=self._model_dir) + + def features_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5.], + indices=[[0, 0], [0, 1]], + dense_shape=[1, 2]), + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + context_feature_columns = [] + for mode in [ + model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, + model_fn.ModeKeys.PREDICT + ]: + self._test_logits( + mode, + rnn_units=[2], + logits_dimension=1, + features_fn=features_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + expected_logits=[[-0.6033]]) + + def testMultiDimLogits(self): + """Tests multi-dimensional logits. + + Intermediate values are rounded for ease in reading. + input_layer = [[[10]], [[5]]] + initial_state = [0, 0] + rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2), + tanh(-.2*10 - .3*0 - .4*0 +.5)]] + = [[0.83, -0.91]] + rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2), + tanh(-.2*5 - .3*.83 + .4*.91 +.5)]] + = [[0.53, -0.37]] + logits = [[-1*0.53 - 1*0.37 + 0.3], + [0.5*0.53 + 0.3*0.37 + 0.4], + [0.2*0.53 - 0.1*0.37 + 0.5] + = [[-0.6033, 0.7777, 0.5698]] + """ + base_global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]], + logits_biases=[0.3, 0.4, 0.5], + global_step=base_global_step, + model_dir=self._model_dir) + + def features_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5.], + indices=[[0, 0], [0, 1]], + dense_shape=[1, 2]), + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + context_feature_columns = [] + + for mode in [ + model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, + model_fn.ModeKeys.PREDICT + ]: + self._test_logits( + mode, + rnn_units=[2], + logits_dimension=3, + features_fn=features_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + expected_logits=[[-0.6033, 0.7777, 0.5698]]) + + def testMultiExampleMultiDim(self): + """Tests multiple examples and multi-dimensional logits. + + Intermediate values are rounded for ease in reading. + input_layer = [[[10], [5]], [[2], [7]]] + initial_state = [[0, 0], [0, 0]] + rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2), + tanh(-.2*10 - .3*0 - .4*0 +.5)], + [tanh(.1*2 + .2*0 + .3*0 +.2), + tanh(-.2*2 - .3*0 - .4*0 +.5)]] + = [[0.83, -0.91], [0.38, 0.10]] + rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2), + tanh(-.2*5 - .3*.83 + .4*.91 +.5)], + [tanh(.1*7 + .2*.38 + .3*.10 +.2), + tanh(-.2*7 - .3*.38 - .4*.10 +.5)]] + = [[0.53, -0.37], [0.76, -0.78] + logits = [[-1*0.53 - 1*0.37 + 0.3, + 0.5*0.53 + 0.3*0.37 + 0.4, + 0.2*0.53 - 0.1*0.37 + 0.5], + [-1*0.76 - 1*0.78 + 0.3, + 0.5*0.76 +0.3*0.78 + 0.4, + 0.2*0.76 -0.1*0.78 + 0.5]] + = [[-0.6033, 0.7777, 0.5698], [-1.2473, 1.0170, 0.5745]] + """ + base_global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]], + logits_biases=[0.3, 0.4, 0.5], + global_step=base_global_step, + model_dir=self._model_dir) + + def features_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2., 7.], + indices=[[0, 0], [0, 1], [1, 0], [1, 1]], + dense_shape=[2, 2]), + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,)) + ] + context_feature_columns = [] + + for mode in [ + model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, + model_fn.ModeKeys.PREDICT + ]: + self._test_logits( + mode, + rnn_units=[2], + logits_dimension=3, + features_fn=features_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + expected_logits=[[-0.6033, 0.7777, 0.5698], + [-1.2473, 1.0170, 0.5745]]) + + def testMultiExamplesDifferentLength(self): + """Tests multiple examples with different lengths. + + Intermediate values are rounded for ease in reading. + input_layer = [[[10], [5]], [[2], [0]]] + initial_state = [[0, 0], [0, 0]] + rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2), + tanh(-.2*10 - .3*0 - .4*0 +.5)], + [tanh(.1*2 + .2*0 + .3*0 +.2), + tanh(-.2*2 - .3*0 - .4*0 +.5)]] + = [[0.83, -0.91], [0.38, 0.10]] + rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2), + tanh(-.2*5 - .3*.83 + .4*.91 +.5)], + []] + = [[0.53, -0.37], []] + logits = [[-1*0.53 - 1*0.37 + 0.3], + [-1*0.38 + 1*0.10 + 0.3]] + = [[-0.6033], [0.0197]] + """ + base_global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=base_global_step, + model_dir=self._model_dir) + + def features_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2.], + indices=[[0, 0], [0, 1], [1, 0]], + dense_shape=[2, 2]), + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + context_feature_columns = [] + + for mode in [ + model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, + model_fn.ModeKeys.PREDICT + ]: + self._test_logits( + mode, + rnn_units=[2], + logits_dimension=1, + features_fn=features_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + expected_logits=[[-0.6033], [0.0197]]) + + def testMultiExamplesWithContext(self): + """Tests multiple examples with context features. + + Intermediate values are rounded for ease in reading. + input_layer = [[[10, -0.5], [5, -0.5]], [[2, 0.8], [0, 0]]] + initial_state = [[0, 0], [0, 0]] + rnn_output_timestep_1 = [[tanh(.1*10 - 1*.5 + .2*0 + .3*0 +.2), + tanh(-.2*10 - 0.9*.5 - .3*0 - .4*0 +.5)], + [tanh(.1*2 + 1*.8 + .2*0 + .3*0 +.2), + tanh(-.2*2 + .9*.8 - .3*0 - .4*0 +.5)]] + = [[0.60, -0.96], [0.83, 0.68]] + rnn_output_timestep_2 = [[tanh(.1*5 - 1*.5 + .2*.60 - .3*.96 +.2), + tanh(-.2*5 - .9*.5 - .3*.60 + .4*.96 +.5)], + []] + = [[0.03, -0.63], []] + logits = [[-1*0.03 - 1*0.63 + 0.3], + [-1*0.83 + 1*0.68 + 0.3]] + = [[-0.3662], [0.1414]] + """ + base_global_step = 100 + create_checkpoint( + # Context features weights are inserted between input and state weights. + rnn_weights=[[.1, -.2], [1., 0.9], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=base_global_step, + model_dir=self._model_dir) + + def features_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2.], + indices=[[0, 0], [0, 1], [1, 0]], + dense_shape=[2, 2]), + 'context': [[-0.5], [0.8]], + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + context_feature_columns = [fc.numeric_column('context', shape=(1,))] + + for mode in [ + model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, + model_fn.ModeKeys.PREDICT + ]: + self._test_logits( + mode, + rnn_units=[2], + logits_dimension=1, + features_fn=features_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + expected_logits=[[-0.3662], [0.1414]]) + + def testMultiExamplesMultiFeatures(self): + """Tests examples with multiple sequential feature columns. + + Intermediate values are rounded for ease in reading. + input_layer = [[[1, 0, 10], [0, 1, 5]], [[1, 0, 2], [0, 0, 0]]] + initial_state = [[0, 0], [0, 0]] + rnn_output_timestep_1 = [[tanh(.5*1 + 1*0 + .1*10 + .2*0 + .3*0 +.2), + tanh(-.5*1 - 1*0 - .2*10 - .3*0 - .4*0 +.5)], + [tanh(.5*1 + 1*0 + .1*2 + .2*0 + .3*0 +.2), + tanh(-.5*1 - 1*0 - .2*2 - .3*0 - .4*0 +.5)]] + = [[0.94, -0.96], [0.72, -0.38]] + rnn_output_timestep_2 = [[tanh(.5*0 + 1*1 + .1*5 + .2*.94 - .3*.96 +.2), + tanh(-.5*0 - 1*1 - .2*5 - .3*.94 + .4*.96 +.5)], + []] + = [[0.92, -0.88], []] + logits = [[-1*0.92 - 1*0.88 + 0.3], + [-1*0.72 - 1*0.38 + 0.3]] + = [[-1.5056], [-0.7962]] + """ + base_global_step = 100 + create_checkpoint( + # FeatureColumns are sorted alphabetically, so on_sale weights are + # inserted before price. + rnn_weights=[[.5, -.5], [1., -1.], [.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=base_global_step, + model_dir=self._model_dir) + + def features_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2.], + indices=[[0, 0], [0, 1], [1, 0]], + dense_shape=[2, 2]), + 'on_sale': + sparse_tensor.SparseTensor( + values=[0, 1, 0], + indices=[[0, 0], [0, 1], [1, 0]], + dense_shape=[2, 2]), + } + + price_column = seq_fc.sequence_numeric_column('price', shape=(1,)) + on_sale_column = fc.indicator_column( + seq_fc.sequence_categorical_column_with_identity( + 'on_sale', num_buckets=2)) + sequence_feature_columns = [price_column, on_sale_column] + context_feature_columns = [] + + for mode in [ + model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, + model_fn.ModeKeys.PREDICT + ]: + self._test_logits( + mode, + rnn_units=[2], + logits_dimension=1, + features_fn=features_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + expected_logits=[[-1.5056], [-0.7962]]) + + +class RNNClassifierTrainingTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _assert_checkpoint( + self, n_classes, input_units, cell_units, expected_global_step): + + shapes = { + name: shape for (name, shape) in + checkpoint_utils.list_variables(self._model_dir) + } + + self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP]) + self.assertEqual( + expected_global_step, + checkpoint_utils.load_variable( + self._model_dir, ops.GraphKeys.GLOBAL_STEP)) + + # RNN Cell variables. + if len(cell_units) > 1: + for i, cell_unit in enumerate(cell_units): + self.assertEqual([input_units + cell_unit, cell_unit], + shapes[MULTI_CELL_WEIGHTS_NAME_PATTERN % i]) + self.assertEqual([cell_unit], + shapes[MULTI_CELL_BIAS_NAME_PATTERN % i]) + input_units = cell_unit + elif len(cell_units) == 1: + self.assertEqual([input_units + cell_unit, cell_unit], + shapes[CELL_WEIGHTS_NAME]) + self.assertEqual([cell_unit], shapes[CELL_BIAS_NAME]) + + # Logits variables. + logits_dimension = n_classes if n_classes > 2 else 1 + self.assertEqual([cell_units[-1], logits_dimension], + shapes[LOGITS_WEIGHTS_NAME]) + self.assertEqual([logits_dimension], shapes[LOGITS_BIAS_NAME]) + + def _mock_optimizer(self, expected_loss=None): + expected_var_names = [ + '%s/part_0:0' % CELL_BIAS_NAME, + '%s/part_0:0' % CELL_WEIGHTS_NAME, + '%s/part_0:0' % LOGITS_BIAS_NAME, + '%s/part_0:0' % LOGITS_WEIGHTS_NAME, + ] + + def _minimize(loss, global_step): + trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertItemsEqual( + expected_var_names, + [var.name for var in trainable_vars]) + + # Verify loss. We can't check the value directly, so we add an assert op. + self.assertEquals(0, loss.shape.ndims) + if expected_loss is None: + return state_ops.assign_add(global_step, 1).op + assert_loss = _assert_close( + math_ops.to_float(expected_loss, name='expected'), + loss, + name='assert_loss') + with ops.control_dependencies((assert_loss,)): + return state_ops.assign_add(global_step, 1).op + + mock_optimizer = test.mock.NonCallableMock( + spec=optimizer.Optimizer, + wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer')) + mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize) + + # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks. + # So, return mock_optimizer itself for deepcopy. + mock_optimizer.__deepcopy__ = lambda _: mock_optimizer + return mock_optimizer + + def testConflictingRNNCellFn(self): + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + cell_units = [4, 2] + + with self.assertRaisesRegexp( + ValueError, + 'num_units and cell_type must not be specified when using rnn_cell_fn'): + rnn.RNNClassifier( + sequence_feature_columns=[embed], + rnn_cell_fn=lambda x: x, + num_units=cell_units) + + with self.assertRaisesRegexp( + ValueError, + 'num_units and cell_type must not be specified when using rnn_cell_fn'): + rnn.RNNClassifier( + sequence_feature_columns=[embed], + rnn_cell_fn=lambda x: x, + cell_type='lstm') + + def _testFromScratchWithDefaultOptimizer(self, n_classes): + def train_input_fn(): + return { + 'tokens': + sparse_tensor.SparseTensor( + values=['the', 'cat', 'sat'], + indices=[[0, 0], [0, 1], [0, 2]], + dense_shape=[1, 3]), + }, [[1]] + + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + input_units = 2 + + cell_units = [4, 2] + est = rnn.RNNClassifier( + sequence_feature_columns=[embed], + num_units=cell_units, + n_classes=n_classes, + model_dir=self._model_dir) + + # Train for a few steps, and validate final checkpoint. + num_steps = 10 + est.train(input_fn=train_input_fn, steps=num_steps) + self._assert_checkpoint(n_classes, input_units, cell_units, num_steps) + + def testBinaryClassFromScratchWithDefaultOptimizer(self): + self._testFromScratchWithDefaultOptimizer(n_classes=2) + + def testMultiClassFromScratchWithDefaultOptimizer(self): + self._testFromScratchWithDefaultOptimizer(n_classes=4) + + def testFromScratchWithCustomRNNCellFn(self): + def train_input_fn(): + return { + 'tokens': + sparse_tensor.SparseTensor( + values=['the', 'cat', 'sat'], + indices=[[0, 0], [0, 1], [0, 2]], + dense_shape=[1, 3]), + }, [[1]] + + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + input_units = 2 + cell_units = [4, 2] + n_classes = 2 + + def rnn_cell_fn(mode): + del mode # unused + cells = [rnn_cell.BasicRNNCell(num_units=n) for n in cell_units] + return rnn_cell.MultiRNNCell(cells) + + est = rnn.RNNClassifier( + sequence_feature_columns=[embed], + rnn_cell_fn=rnn_cell_fn, + n_classes=n_classes, + model_dir=self._model_dir) + + # Train for a few steps, and validate final checkpoint. + num_steps = 10 + est.train(input_fn=train_input_fn, steps=num_steps) + self._assert_checkpoint(n_classes, input_units, cell_units, num_steps) + + def _testExampleWeight(self, n_classes): + def train_input_fn(): + return { + 'tokens': + sparse_tensor.SparseTensor( + values=['the', 'cat', 'sat', 'dog', 'barked'], + indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], + dense_shape=[2, 3]), + 'w': [[1], [2]], + }, [[1], [0]] + + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + input_units = 2 + + cell_units = [4, 2] + est = rnn.RNNClassifier( + num_units=cell_units, + sequence_feature_columns=[embed], + n_classes=n_classes, + weight_column='w', + model_dir=self._model_dir) + + # Train for a few steps, and validate final checkpoint. + num_steps = 10 + est.train(input_fn=train_input_fn, steps=num_steps) + self._assert_checkpoint(n_classes, input_units, cell_units, num_steps) + + def testBinaryClassWithExampleWeight(self): + self._testExampleWeight(n_classes=2) + + def testMultiClassWithExampleWeight(self): + self._testExampleWeight(n_classes=4) + + def testBinaryClassFromCheckpoint(self): + initial_global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=initial_global_step, + model_dir=self._model_dir) + + def train_input_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2.], + indices=[[0, 0], [0, 1], [1, 0]], + dense_shape=[2, 2]), + }, [[0], [1]] + + # Uses same checkpoint and examples as testBinaryClassEvaluationMetrics. + # See that test for loss calculation. + mock_optimizer = self._mock_optimizer(expected_loss=1.119661) + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + est = rnn.RNNClassifier( + num_units=[2], + sequence_feature_columns=sequence_feature_columns, + n_classes=2, + optimizer=mock_optimizer, + model_dir=self._model_dir) + self.assertEqual(0, mock_optimizer.minimize.call_count) + est.train(input_fn=train_input_fn, steps=10) + self.assertEqual(1, mock_optimizer.minimize.call_count) + + def testMultiClassFromCheckpoint(self): + initial_global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]], + logits_biases=[0.3, 0.4, 0.5], + global_step=initial_global_step, + model_dir=self._model_dir) + + def train_input_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2., 7.], + indices=[[0, 0], [0, 1], [1, 0], [1, 1]], + dense_shape=[2, 2]), + }, [[0], [1]] + + # Uses same checkpoint and examples as testMultiClassEvaluationMetrics. + # See that test for loss calculation. + mock_optimizer = self._mock_optimizer(expected_loss=2.662932) + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + est = rnn.RNNClassifier( + num_units=[2], + sequence_feature_columns=sequence_feature_columns, + n_classes=3, + optimizer=mock_optimizer, + model_dir=self._model_dir) + self.assertEqual(0, mock_optimizer.minimize.call_count) + est.train(input_fn=train_input_fn, steps=10) + self.assertEqual(1, mock_optimizer.minimize.call_count) + + +def sorted_key_dict(unsorted_dict): + return {k: unsorted_dict[k] for k in sorted(unsorted_dict)} + + +class RNNClassifierEvaluationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def testBinaryClassEvaluationMetrics(self): + global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=global_step, + model_dir=self._model_dir) + + def eval_input_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2.], + indices=[[0, 0], [0, 1], [1, 0]], + dense_shape=[2, 2]), + }, [[0], [1]] + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + + est = rnn.RNNClassifier( + num_units=[2], + sequence_feature_columns=sequence_feature_columns, + n_classes=2, + model_dir=self._model_dir) + eval_metrics = est.evaluate(eval_input_fn, steps=1) + + # Uses identical numbers to testMultiExamplesWithDifferentLength. + # See that test for logits calculation. + # logits = [[-0.603282], [0.019719]] + # probability = exp(logits) / (1 + exp(logits)) = [[0.353593], [0.504930]] + # loss = -label * ln(p) - (1 - label) * ln(1 - p) + # = [[0.436326], [0.683335]] + expected_metrics = { + ops.GraphKeys.GLOBAL_STEP: global_step, + metric_keys.MetricKeys.LOSS: 1.119661, + metric_keys.MetricKeys.LOSS_MEAN: 0.559831, + metric_keys.MetricKeys.ACCURACY: 1.0, + metric_keys.MetricKeys.PREDICTION_MEAN: 0.429262, + metric_keys.MetricKeys.LABEL_MEAN: 0.5, + metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5, + # With default threshold of 0.5, the model is a perfect classifier. + metric_keys.MetricKeys.RECALL: 1.0, + metric_keys.MetricKeys.PRECISION: 1.0, + # Positive example is scored above negative, so AUC = 1.0. + metric_keys.MetricKeys.AUC: 1.0, + metric_keys.MetricKeys.AUC_PR: 1.0, + } + self.assertAllClose( + sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics)) + + def testMultiClassEvaluationMetrics(self): + global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]], + logits_biases=[0.3, 0.4, 0.5], + global_step=global_step, + model_dir=self._model_dir) + + def eval_input_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2., 7.], + indices=[[0, 0], [0, 1], [1, 0], [1, 1]], + dense_shape=[2, 2]), + }, [[0], [1]] + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + + est = rnn.RNNClassifier( + num_units=[2], + sequence_feature_columns=sequence_feature_columns, + n_classes=3, + model_dir=self._model_dir) + eval_metrics = est.evaluate(eval_input_fn, steps=1) + + # Uses identical numbers to testMultiExampleMultiDim. + # See that test for logits calculation. + # logits = [[-0.603282, 0.777708, 0.569756], + # [-1.247356, 1.017018, 0.574481]] + # logits_exp = exp(logits) / (1 + exp(logits)) + # = [[0.547013, 2.176468, 1.767836], + # [0.287263, 2.764937, 1.776208]] + # softmax_probabilities = logits_exp / logits_exp.sum() + # = [[0.121793, 0.484596, 0.393611], + # [0.059494, 0.572639, 0.367866]] + # loss = -1. * log(softmax[label]) + # = [[2.105432], [0.557500]] + expected_metrics = { + ops.GraphKeys.GLOBAL_STEP: global_step, + metric_keys.MetricKeys.LOSS: 2.662932, + metric_keys.MetricKeys.LOSS_MEAN: 1.331466, + metric_keys.MetricKeys.ACCURACY: 0.5, + } + + self.assertAllClose( + sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics)) + + +class RNNClassifierPredictionTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def testBinaryClassPredictions(self): + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=0, + model_dir=self._model_dir) + + def predict_input_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5.], + indices=[[0, 0], [0, 1]], + dense_shape=[1, 2]), + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + label_vocabulary = ['class_0', 'class_1'] + + est = rnn.RNNClassifier( + num_units=[2], + sequence_feature_columns=sequence_feature_columns, + n_classes=2, + label_vocabulary=label_vocabulary, + model_dir=self._model_dir) + # Uses identical numbers to testOneDimLogits. + # See that test for logits calculation. + # logits = [-0.603282] + # logistic = exp(-0.6033) / (1 + exp(-0.6033)) = [0.353593] + # probabilities = [0.646407, 0.353593] + # class_ids = argmax(probabilities) = [0] + predictions = next(est.predict(predict_input_fn)) + self.assertAllClose([-0.603282], + predictions[prediction_keys.PredictionKeys.LOGITS]) + self.assertAllClose([0.353593], + predictions[prediction_keys.PredictionKeys.LOGISTIC]) + self.assertAllClose( + [0.646407, 0.353593], + predictions[prediction_keys.PredictionKeys.PROBABILITIES]) + self.assertAllClose([0], + predictions[prediction_keys.PredictionKeys.CLASS_IDS]) + self.assertEqual([b'class_0'], + predictions[prediction_keys.PredictionKeys.CLASSES]) + + def testMultiClassPredictions(self): + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]], + logits_biases=[0.3, 0.4, 0.5], + global_step=0, + model_dir=self._model_dir) + + def predict_input_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5.], + indices=[[0, 0], [0, 1]], + dense_shape=[1, 2]), + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + label_vocabulary = ['class_0', 'class_1', 'class_2'] + + est = rnn.RNNClassifier( + num_units=[2], + sequence_feature_columns=sequence_feature_columns, + n_classes=3, + label_vocabulary=label_vocabulary, + model_dir=self._model_dir) + # Uses identical numbers to testMultiDimLogits. + # See that test for logits calculation. + # logits = [-0.603282, 0.777708, 0.569756] + # logits_exp = exp(logits) = [0.547013, 2.176468, 1.767836] + # softmax_probabilities = logits_exp / logits_exp.sum() + # = [0.121793, 0.484596, 0.393611] + # class_ids = argmax(probabilities) = [1] + predictions = next(est.predict(predict_input_fn)) + self.assertAllClose([-0.603282, 0.777708, 0.569756], + predictions[prediction_keys.PredictionKeys.LOGITS]) + self.assertAllClose( + [0.121793, 0.484596, 0.393611], + predictions[prediction_keys.PredictionKeys.PROBABILITIES]) + self.assertAllClose([1], + predictions[prediction_keys.PredictionKeys.CLASS_IDS]) + self.assertEqual([b'class_1'], + predictions[prediction_keys.PredictionKeys.CLASSES]) + + +class RNNClassifierIntegrationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow( + self, train_input_fn, eval_input_fn, predict_input_fn, n_classes, + batch_size): + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + feature_columns = [embed] + + cell_units = [4, 2] + est = rnn.RNNClassifier( + num_units=cell_units, + sequence_feature_columns=feature_columns, + n_classes=n_classes, + model_dir=self._model_dir) + + # TRAIN + num_steps = 10 + est.train(train_input_fn, steps=num_steps) + + # EVALUATE + scores = est.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + + # PREDICT + predicted_proba = np.array([ + x[prediction_keys.PredictionKeys.PROBABILITIES] + for x in est.predict(predict_input_fn) + ]) + self.assertAllEqual((batch_size, n_classes), predicted_proba.shape) + + # EXPORT + feature_spec = { + 'tokens': parsing_ops.VarLenFeature(dtypes.string), + 'label': parsing_ops.FixedLenFeature([1], dtypes.int64), + } + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = est.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def testNumpyInputFn(self): + """Tests complete flow with numpy_input_fn.""" + n_classes = 3 + batch_size = 10 + words = ['dog', 'cat', 'bird', 'the', 'a', 'sat', 'flew', 'slept'] + # Numpy only supports dense input, so all examples will have same length. + # TODO(b/73160931): Update test when support for prepadded data exists. + sequence_length = 3 + + features = [] + for _ in range(batch_size): + sentence = random.sample(words, sequence_length) + features.append(sentence) + + x_data = np.array(features) + y_data = np.random.randint(n_classes, size=batch_size) + + train_input_fn = numpy_io.numpy_input_fn( + x={'tokens': x_data}, + y=y_data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'tokens': x_data}, + y=y_data, + batch_size=batch_size, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'tokens': x_data}, + batch_size=batch_size, + shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + n_classes=n_classes, + batch_size=batch_size) + + def testParseExampleInputFn(self): + """Tests complete flow with input_fn constructed from parse_example.""" + n_classes = 3 + batch_size = 10 + words = [b'dog', b'cat', b'bird', b'the', b'a', b'sat', b'flew', b'slept'] + + serialized_examples = [] + for _ in range(batch_size): + sequence_length = random.randint(1, len(words)) + sentence = random.sample(words, sequence_length) + label = random.randint(0, n_classes - 1) + example = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'tokens': + feature_pb2.Feature(bytes_list=feature_pb2.BytesList( + value=sentence)), + 'label': + feature_pb2.Feature(int64_list=feature_pb2.Int64List( + value=[label])), + })) + serialized_examples.append(example.SerializeToString()) + + feature_spec = { + 'tokens': parsing_ops.VarLenFeature(dtypes.string), + 'label': parsing_ops.FixedLenFeature([1], dtypes.int64), + } + def _train_input_fn(): + features = parsing_ops.parse_example(serialized_examples, feature_spec) + labels = features.pop('label') + return features, labels + def _eval_input_fn(): + features = parsing_ops.parse_example( + input_lib.limit_epochs(serialized_examples, num_epochs=1), + feature_spec) + labels = features.pop('label') + return features, labels + def _predict_input_fn(): + features = parsing_ops.parse_example( + input_lib.limit_epochs(serialized_examples, num_epochs=1), + feature_spec) + features.pop('label') + return features, None + + self._test_complete_flow( + train_input_fn=_train_input_fn, + eval_input_fn=_eval_input_fn, + predict_input_fn=_predict_input_fn, + n_classes=n_classes, + batch_size=batch_size) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index fe380c4..cbc2dcf 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -1206,7 +1206,16 @@ class DeviceWrapper(RNNCell): @tf_export("nn.rnn_cell.MultiRNNCell") class MultiRNNCell(RNNCell): - """RNN cell composed sequentially of multiple simple cells.""" + """RNN cell composed sequentially of multiple simple cells. + + Example: + + ```python + num_units = [128, 64] + cells = [BasicLSTMCell(num_units=n) for n in num_units] + stacked_rnn_cell = MultiRNNCell(cells) + ``` + """ def __init__(self, cells, state_is_tuple=True): """Create a RNN cell composed sequentially of a number of RNNCells. -- 2.7.4