[TF contrib RNN] Expose some rnn classes and functionality in contrib.
authorEugene Brevdo <ebrevdo@google.com>
Thu, 8 Feb 2018 23:01:49 +0000 (15:01 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 8 Feb 2018 23:11:31 +0000 (15:11 -0800)
PiperOrigin-RevId: 185057994

tensorflow/contrib/rnn/__init__.py
tensorflow/contrib/rnn/python/ops/gru_ops.py
tensorflow/contrib/rnn/python/ops/lstm_ops.py
tensorflow/contrib/rnn/python/ops/rnn_cell.py
tensorflow/python/ops/rnn.py
tensorflow/python/ops/rnn_cell_impl.py
tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt

index c568c67..67f3178 100644 (file)
@@ -18,6 +18,7 @@ See @{$python/contrib.rnn} guide.
 
 <!--From core-->
 @@RNNCell
+@@LayerRNNCell
 @@BasicRNNCell
 @@BasicLSTMCell
 @@GRUCell
@@ -68,6 +69,10 @@ See @{$python/contrib.rnn} guide.
 @@static_bidirectional_rnn
 @@stack_bidirectional_dynamic_rnn
 @@stack_bidirectional_rnn
+
+<!--RNN utilities-->
+@@transpose_batch_time
+@@best_effort_input_batch_size
 """
 
 from __future__ import absolute_import
@@ -85,6 +90,8 @@ from tensorflow.contrib.rnn.python.ops.lstm_ops import *
 from tensorflow.contrib.rnn.python.ops.rnn import *
 from tensorflow.contrib.rnn.python.ops.rnn_cell import *
 
+from tensorflow.python.ops.rnn import _best_effort_input_batch_size as best_effort_input_batch_size
+from tensorflow.python.ops.rnn import _transpose_batch_time as transpose_batch_time
 from tensorflow.python.ops.rnn import static_bidirectional_rnn
 from tensorflow.python.ops.rnn import static_rnn
 from tensorflow.python.ops.rnn import static_state_saving_rnn
index 4c964ec..81ca123 100644 (file)
@@ -32,7 +32,7 @@ from tensorflow.python.util.deprecation import deprecated_args
 _gru_ops_so = loader.load_op_library(
     resource_loader.get_path_to_datafile("_gru_ops.so"))
 
-LayerRNNCell = rnn_cell_impl._LayerRNNCell  # pylint: disable=invalid-name,protected-access
+LayerRNNCell = rnn_cell_impl.LayerRNNCell  # pylint: disable=invalid-name
 
 
 @ops.RegisterGradient("GRUBlockCell")
index 04f342c..f700717 100644 (file)
@@ -34,7 +34,7 @@ from tensorflow.python.platform import resource_loader
 _lstm_ops_so = loader.load_op_library(
     resource_loader.get_path_to_datafile("_lstm_ops.so"))
 
-LayerRNNCell = rnn_cell_impl._LayerRNNCell  # pylint: disable=invalid-name,protected-access
+LayerRNNCell = rnn_cell_impl.LayerRNNCell  # pylint: disable=invalid-name
 
 
 # pylint: disable=invalid-name
index fe07493..dce71c3 100644 (file)
@@ -2682,7 +2682,7 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
     return m, new_state
 
 
-class SRUCell(rnn_cell_impl._LayerRNNCell):
+class SRUCell(rnn_cell_impl.LayerRNNCell):
   """SRU, Simple Recurrent Unit
 
      Implementation based on
index da80e72..aa8d432 100644 (file)
@@ -83,8 +83,9 @@ def _best_effort_input_batch_size(flat_input):
   """Get static input batch size if available, with fallback to the dynamic one.
 
   Args:
-    flat_input: An iterable of time major input Tensors of shape [max_time,
-      batch_size, ...]. All inputs should have compatible batch sizes.
+    flat_input: An iterable of time major input Tensors of shape
+      `[max_time, batch_size, ...]`.
+    All inputs should have compatible batch sizes.
 
   Returns:
     The batch size in Python integer if available, or a scalar Tensor otherwise.
index f1ac3e9..923348e 100644 (file)
@@ -255,7 +255,7 @@ class RNNCell(base_layer.Layer):
     return output
 
 
-class _LayerRNNCell(RNNCell):
+class LayerRNNCell(RNNCell):
   """Subclass of RNNCells that act like proper `tf.Layer` objects.
 
   For backwards compatibility purposes, most `RNNCell` instances allow their
@@ -297,7 +297,7 @@ class _LayerRNNCell(RNNCell):
 
 
 @tf_export("nn.rnn_cell.BasicRNNCell")
-class BasicRNNCell(_LayerRNNCell):
+class BasicRNNCell(LayerRNNCell):
   """The most basic RNN cell.
 
   Args:
@@ -355,7 +355,7 @@ class BasicRNNCell(_LayerRNNCell):
 
 
 @tf_export("nn.rnn_cell.GRUCell")
-class GRUCell(_LayerRNNCell):
+class GRUCell(LayerRNNCell):
   """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).
 
   Args:
@@ -473,7 +473,7 @@ class LSTMStateTuple(_LSTMStateTuple):
 
 
 @tf_export("nn.rnn_cell.BasicLSTMCell")
-class BasicLSTMCell(_LayerRNNCell):
+class BasicLSTMCell(LayerRNNCell):
   """Basic LSTM recurrent network cell.
 
   The implementation is based on: http://arxiv.org/abs/1409.2329.
@@ -598,7 +598,7 @@ class BasicLSTMCell(_LayerRNNCell):
 
 
 @tf_export("nn.rnn_cell.LSTMCell")
-class LSTMCell(_LayerRNNCell):
+class LSTMCell(LayerRNNCell):
   """Long short-term memory unit (LSTM) recurrent network cell.
 
   The default non-peephole implementation is based on:
index a2e728f..4453678 100644 (file)
@@ -1,7 +1,7 @@
 path: "tensorflow.nn.rnn_cell.BasicLSTMCell"
 tf_class {
   is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.BasicLSTMCell\'>"
-  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._LayerRNNCell\'>"
+  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.LayerRNNCell\'>"
   is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
   is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
   is_instance: "<type \'object\'>"
index 4211faa..768565d 100644 (file)
@@ -1,7 +1,7 @@
 path: "tensorflow.nn.rnn_cell.BasicRNNCell"
 tf_class {
   is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.BasicRNNCell\'>"
-  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._LayerRNNCell\'>"
+  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.LayerRNNCell\'>"
   is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
   is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
   is_instance: "<type \'object\'>"
index 06fdc63..6ecc134 100644 (file)
@@ -1,7 +1,7 @@
 path: "tensorflow.nn.rnn_cell.GRUCell"
 tf_class {
   is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.GRUCell\'>"
-  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._LayerRNNCell\'>"
+  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.LayerRNNCell\'>"
   is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
   is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
   is_instance: "<type \'object\'>"
index ef48cff..4b3ca15 100644 (file)
@@ -1,7 +1,7 @@
 path: "tensorflow.nn.rnn_cell.LSTMCell"
 tf_class {
   is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.LSTMCell\'>"
-  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._LayerRNNCell\'>"
+  is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.LayerRNNCell\'>"
   is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
   is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
   is_instance: "<type \'object\'>"