Improve errors raised when an object does not match the RNNCell interface.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Sun, 11 Mar 2018 17:00:02 +0000 (10:00 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sun, 11 Mar 2018 17:04:13 +0000 (10:04 -0700)
PiperOrigin-RevId: 188651070

tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
tensorflow/contrib/rnn/python/ops/core_rnn_cell.py
tensorflow/contrib/rnn/python/ops/rnn_cell.py
tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
tensorflow/contrib/seq2seq/python/ops/basic_decoder.py
tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
tensorflow/python/ops/rnn.py
tensorflow/python/ops/rnn_cell_impl.py

index 7de55a0..69f7b8e 100644 (file)
@@ -455,8 +455,8 @@ class RNNCellTest(test.TestCase):
         self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
 
   def testAttentionCellWrapperFailures(self):
-    with self.assertRaisesRegexp(TypeError,
-                                 "The parameter cell is not RNNCell."):
+    with self.assertRaisesRegexp(
+        TypeError, rnn_cell_impl.ASSERT_LIKE_RNNCELL_ERROR_REGEXP):
       contrib_rnn_cell.AttentionCellWrapper(None, 0)
 
     num_units = 8
@@ -1203,7 +1203,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
         h1 = array_ops.zeros([1, 2])
         state1 = rnn_cell.LSTMStateTuple(c1, h1)
         state = (state0, state1)
-        single_cell = lambda: contrib_rnn_cell.LayerNormBasicLSTMCell(2, layer_norm=False)
+        single_cell = lambda: contrib_rnn_cell.LayerNormBasicLSTMCell(2, layer_norm=False)  # pylint: disable=line-too-long
         cell = rnn_cell.MultiRNNCell([single_cell() for _ in range(2)])
         g, out_m = cell(x, state)
         sess.run([variables.global_variables_initializer()])
@@ -1235,7 +1235,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
         self.assertAllClose(expected_state1_h, actual_state1_h, 1e-5)
 
       with variable_scope.variable_scope(
-          "other", initializer=init_ops.constant_initializer(0.5)) as vs:
+          "other", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros(
             [1, 3])  # Test BasicLSTMCell with input_size != num_units.
         c = array_ops.zeros([1, 2])
index 8109ebc..645f826 100644 (file)
@@ -40,7 +40,6 @@ from tensorflow.python.util import nest
 
 # pylint: disable=protected-access,invalid-name
 RNNCell = rnn_cell_impl.RNNCell
-_like_rnncell = rnn_cell_impl._like_rnncell
 _WEIGHTS_VARIABLE_NAME = rnn_cell_impl._WEIGHTS_VARIABLE_NAME
 _BIAS_VARIABLE_NAME = rnn_cell_impl._BIAS_VARIABLE_NAME
 # pylint: enable=protected-access,invalid-name
@@ -221,8 +220,7 @@ class EmbeddingWrapper(RNNCell):
       ValueError: if embedding_classes is not positive.
     """
     super(EmbeddingWrapper, self).__init__(_reuse=reuse)
-    if not _like_rnncell(cell):
-      raise TypeError("The parameter cell is not RNNCell.")
+    rnn_cell_impl.assert_like_rnncell("cell", cell)
     if embedding_classes <= 0 or embedding_size <= 0:
       raise ValueError("Both embedding_classes and embedding_size must be > 0: "
                        "%d, %d." % (embedding_classes, embedding_size))
@@ -301,8 +299,7 @@ class InputProjectionWrapper(RNNCell):
     super(InputProjectionWrapper, self).__init__(_reuse=reuse)
     if input_size is not None:
       logging.warn("%s: The input_size parameter is deprecated.", self)
-    if not _like_rnncell(cell):
-      raise TypeError("The parameter cell is not RNNCell.")
+    rnn_cell_impl.assert_like_rnncell("cell", cell)
     self._cell = cell
     self._num_proj = num_proj
     self._activation = activation
@@ -356,8 +353,7 @@ class OutputProjectionWrapper(RNNCell):
       ValueError: if output_size is not positive.
     """
     super(OutputProjectionWrapper, self).__init__(_reuse=reuse)
-    if not _like_rnncell(cell):
-      raise TypeError("The parameter cell is not RNNCell.")
+    rnn_cell_impl.assert_like_rnncell("cell", cell)
     if output_size < 1:
       raise ValueError("Parameter output_size must be > 0: %d." % output_size)
     self._cell = cell
index 6bea8d4..3028eda 100644 (file)
@@ -1143,8 +1143,7 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell):
           `state_is_tuple` is `False` or if attn_length is zero or less.
     """
     super(AttentionCellWrapper, self).__init__(_reuse=reuse)
-    if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
-      raise TypeError("The parameter cell is not RNNCell.")
+    rnn_cell_impl.assert_like_rnncell("cell", cell)
     if nest.is_sequence(cell.state_size) and not state_is_tuple:
       raise ValueError(
           "Cell returns tuple of states, but the flag "
index 0a53fd6..f8da5a3 100644 (file)
@@ -1152,9 +1152,7 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
         is a list, and its length does not match that of `attention_layer_size`.
     """
     super(AttentionWrapper, self).__init__(name=name)
-    if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
-      raise TypeError(
-          "cell must be an RNNCell, saw type: %s" % type(cell).__name__)
+    rnn_cell_impl.assert_like_rnncell("cell", cell)
     if isinstance(attention_mechanism, (list, tuple)):
       self._is_multi = True
       attention_mechanisms = attention_mechanism
index ed22623..7eb95e5 100644 (file)
@@ -59,8 +59,7 @@ class BasicDecoder(decoder.Decoder):
     Raises:
       TypeError: if `cell`, `helper` or `output_layer` have an incorrect type.
     """
-    if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
-      raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
+    rnn_cell_impl.assert_like_rnncell("cell", cell)
     if not isinstance(helper, helper_py.Helper):
       raise TypeError("helper must be a Helper, received: %s" % type(helper))
     if (output_layer is not None
index d6184d6..22dc7f2 100644 (file)
@@ -195,8 +195,7 @@ class BeamSearchDecoder(decoder.Decoder):
       ValueError: If `start_tokens` is not a vector or
         `end_token` is not a scalar.
     """
-    if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
-      raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
+    rnn_cell_impl.assert_like_rnncell("cell", cell)  # pylint: disable=protected-access
     if (output_layer is not None and
         not isinstance(output_layer, layers_base.Layer)):
       raise TypeError(
index 625d433..c59eccc 100644 (file)
@@ -45,7 +45,6 @@ from tensorflow.python.util.tf_export import tf_export
 
 # pylint: disable=protected-access
 _concat = rnn_cell_impl._concat
-_like_rnncell = rnn_cell_impl._like_rnncell
 # pylint: enable=protected-access
 
 
@@ -403,11 +402,8 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
   Raises:
     TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
   """
-
-  if not _like_rnncell(cell_fw):
-    raise TypeError("cell_fw must be an instance of RNNCell")
-  if not _like_rnncell(cell_bw):
-    raise TypeError("cell_bw must be an instance of RNNCell")
+  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
+  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)
 
   with vs.variable_scope(scope or "bidirectional_rnn"):
     # Forward direction
@@ -568,8 +564,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
     TypeError: If `cell` is not an instance of RNNCell.
     ValueError: If inputs is None or an empty list.
   """
-  if not _like_rnncell(cell):
-    raise TypeError("cell must be an instance of RNNCell")
+  rnn_cell_impl.assert_like_rnncell("cell", cell)
 
   with vs.variable_scope(scope or "rnn") as varscope:
     # Create a new scope in which the caching device is either
@@ -1015,9 +1010,8 @@ def raw_rnn(cell, loop_fn,
     TypeError: If `cell` is not an instance of RNNCell, or `loop_fn` is not
       a `callable`.
   """
+  rnn_cell_impl.assert_like_rnncell("cell", cell)
 
-  if not _like_rnncell(cell):
-    raise TypeError("cell must be an instance of RNNCell")
   if not callable(loop_fn):
     raise TypeError("loop_fn must be a callable")
 
@@ -1229,9 +1223,7 @@ def static_rnn(cell,
     ValueError: If `inputs` is `None` or an empty list, or if the input depth
       (column size) cannot be inferred from inputs via shape inference.
   """
-
-  if not _like_rnncell(cell):
-    raise TypeError("cell must be an instance of RNNCell")
+  rnn_cell_impl.assert_like_rnncell("cell", cell)
   if not nest.is_sequence(inputs):
     raise TypeError("inputs must be a sequence")
   if not inputs:
@@ -1469,11 +1461,8 @@ def static_bidirectional_rnn(cell_fw,
     TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
     ValueError: If inputs is None or an empty list.
   """
-
-  if not _like_rnncell(cell_fw):
-    raise TypeError("cell_fw must be an instance of RNNCell")
-  if not _like_rnncell(cell_bw):
-    raise TypeError("cell_bw must be an instance of RNNCell")
+  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
+  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)
   if not nest.is_sequence(inputs):
     raise TypeError("inputs must be a sequence")
   if not inputs:
index e61d108..fe380c4 100644 (file)
@@ -55,6 +55,8 @@ _BIAS_VARIABLE_NAME = "bias"
 _WEIGHTS_VARIABLE_NAME = "kernel"
 
 
+# TODO(jblespiau): Remove this function when we are sure there are no longer
+# any usage (even if protected, it is being used). Prefer assert_like_rnncell.
 def _like_rnncell(cell):
   """Checks that a given object is an RNNCell by using duck typing."""
   conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"),
@@ -62,6 +64,45 @@ def _like_rnncell(cell):
   return all(conditions)
 
 
+# This can be used with self.assertRaisesRegexp for assert_like_rnncell.
+ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell"
+
+
+def assert_like_rnncell(cell_name, cell):
+  """Raises a TypeError if cell is not like an RNNCell.
+
+  NOTE: Do not rely on the error message (in particular in tests) which can be
+  subject to change to increase readability. Use
+  ASSERT_LIKE_RNNCELL_ERROR_REGEXP.
+
+  Args:
+    cell_name: A string to give a meaningful error referencing to the name
+      of the functionargument.
+    cell: The object which should behave like an RNNCell.
+
+  Raises:
+    TypeError: A human-friendly exception.
+  """
+  conditions = [
+      hasattr(cell, "output_size"),
+      hasattr(cell, "state_size"),
+      hasattr(cell, "zero_state"),
+      callable(cell),
+  ]
+  errors = [
+      "'output_size' property is missing",
+      "'state_size' property is missing",
+      "'zero_state' method is missing",
+      "is not callable"
+  ]
+
+  if not all(conditions):
+
+    errors = [error for error, cond in zip(errors, conditions) if not cond]
+    raise TypeError("The argument {!r} ({}) is not an RNNCell: {}.".format(
+        cell_name, cell, ", ".join(errors)))
+
+
 def _concat(prefix, suffix, static=False):
   """Concat that enables int, Tensor, or TensorShape values.
 
@@ -914,8 +955,8 @@ class DropoutWrapper(RNNCell):
         but not `callable`.
       ValueError: if any of the keep_probs are not between 0 and 1.
     """
-    if not _like_rnncell(cell):
-      raise TypeError("The parameter cell is not a RNNCell.")
+    assert_like_rnncell("cell", cell)
+
     if (dropout_state_filter_visitor is not None
         and not callable(dropout_state_filter_visitor)):
       raise TypeError("dropout_state_filter_visitor must be callable")