Moving gradient registration for CudnnRNN op from contrib to core.
author Pavithra Vijay <psv@google.com>
Tue, 17 Apr 2018 19:06:50 +0000 (12:06 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Tue, 17 Apr 2018 19:09:34 +0000 (12:09 -0700)
PiperOrigin-RevId: 193234663

tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
tensorflow/python/BUILD
tensorflow/python/ops/cudnn_rnn_grad.py [new file with mode: 0644]
tensorflow/python/ops/standard_ops.py
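Context for the change: TensorFlow looks up gradient functions in a global registry keyed by the op type string, and a function only lands in that registry when the module containing its @ops.RegisterGradient decorator is imported. While the registration lived in contrib, differentiating a CudnnRNN op required importing the contrib cudnn_rnn_ops module first; moving the registration into tensorflow/python/ops and importing it from standard_ops.py makes it available whenever core TensorFlow is imported. The sketch below is not part of the commit and uses a hypothetical op type name purely to illustrate the registration mechanism:

from tensorflow.python.framework import ops

@ops.RegisterGradient("MyToyOp")  # hypothetical op type, for illustration only
def _my_toy_op_grad(op, grad):
  # A gradient function receives the forward op and the gradients of its
  # outputs, and returns one gradient tensor per op input.
  return grad

# The decorator runs at import time, so the "CudnnRNN" gradient defined in
# cudnn_rnn_grad.py becomes visible to tf.gradients() only once that module is
# imported -- which is what the new import in standard_ops.py guarantees.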

diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index c28c3a1..b615824 100644 (file)
@@ -1640,31 +1640,6 @@ class CudnnRNNRelu(_CudnnRNNNoInputC):
   _NUM_PARAMS_PER_LAYER = CUDNN_RNN_RELU_PARAMS_PER_LAYER
 
 
-@ops.RegisterGradient("CudnnRNN")
-def _cudnn_rnn_backward(op, *grad):
-  if not op.get_attr("is_training"):
-    raise ValueError(
-        "CudnnRNN must set is_training to True to be used in gradients")
-  return gen_cudnn_rnn_ops.cudnn_rnn_backprop(
-      input=op.inputs[0],
-      input_h=op.inputs[1],
-      input_c=op.inputs[2],
-      params=op.inputs[3],
-      output=op.outputs[0],
-      output_h=op.outputs[1],
-      output_c=op.outputs[2],
-      output_backprop=grad[0],
-      output_h_backprop=grad[1],
-      output_c_backprop=grad[2],
-      reserve_space=op.outputs[3],
-      dropout=op.get_attr("dropout"),
-      seed=op.get_attr("seed"),
-      seed2=op.get_attr("seed2"),
-      rnn_mode=op.get_attr("rnn_mode"),
-      input_mode=op.get_attr("input_mode"),
-      direction=op.get_attr("direction"))
-
-
 ops.RegisterShape("CudnnRNNParamsSize")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("CudnnRNNParamsToCanonical")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("CudnnRNNCanonicalToParams")(common_shapes.call_cpp_shape_fn)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 14ce8a5..569d3eb 100644 (file)
@@ -1793,6 +1793,16 @@ py_library(
 )
 
 py_library(
+    name = "cudnn_rnn_grad",
+    srcs = ["ops/cudnn_rnn_grad.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":framework_for_generated_wrappers",
+        "//tensorflow/python:cudnn_rnn_ops_gen",
+    ],
+)
+
+py_library(
     name = "data_flow_grad",
     srcs = ["ops/data_flow_grad.py"],
     srcs_version = "PY2AND3",
@@ -2465,6 +2475,7 @@ py_library(
         ":clip_ops",
         ":confusion_matrix",
         ":control_flow_ops",
+        ":cudnn_rnn_grad",
         ":data_flow_grad",
         ":data_flow_ops",
         ":framework_for_generated_wrappers",
diff --git a/tensorflow/python/ops/cudnn_rnn_grad.py b/tensorflow/python/ops/cudnn_rnn_grad.py
new file mode 100644 (file)
index 0000000..97331bb
--- /dev/null
@@ -0,0 +1,47 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Gradients for CuudnnRNN operators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_cudnn_rnn_ops
+
+
+@ops.RegisterGradient("CudnnRNN")
+def _cudnn_rnn_backward(op, *grads):
+  """Gradients for the CudnnRNN op."""
+  if not op.get_attr("is_training"):
+    raise ValueError(
+        "CudnnRNN must set is_training to True to be used in gradients")
+  return gen_cudnn_rnn_ops.cudnn_rnn_backprop(
+      input=op.inputs[0],
+      input_h=op.inputs[1],
+      input_c=op.inputs[2],
+      params=op.inputs[3],
+      output=op.outputs[0],
+      output_h=op.outputs[1],
+      output_c=op.outputs[2],
+      output_backprop=grads[0],
+      output_h_backprop=grads[1],
+      output_c_backprop=grads[2],
+      reserve_space=op.outputs[3],
+      dropout=op.get_attr("dropout"),
+      seed=op.get_attr("seed"),
+      seed2=op.get_attr("seed2"),
+      rnn_mode=op.get_attr("rnn_mode"),
+      input_mode=op.get_attr("input_mode"),
+      direction=op.get_attr("direction"))
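With the registration in core, building gradients for a CudnnRNN op no longer requires importing anything from tf.contrib.cudnn_rnn. The following is a minimal graph-construction sketch, not part of the commit; the toy shapes and the placeholder standing in for the opaque params blob are illustrative assumptions, and actually running the graph still requires a CUDA-enabled build with a GPU:

import tensorflow as tf
from tensorflow.python.ops import gen_cudnn_rnn_ops

seq_len, batch, input_size, num_units = 5, 2, 8, 8   # assumed toy sizes

x = tf.placeholder(tf.float32, [seq_len, batch, input_size])
h = tf.zeros([1, batch, num_units])          # single layer, unidirectional
c = tf.zeros([1, batch, num_units])
params = tf.placeholder(tf.float32, [None])  # opaque 1-D parameter blob

outputs, _, _, _ = gen_cudnn_rnn_ops.cudnn_rnn(
    input=x, input_h=h, input_c=c, params=params,
    rnn_mode="lstm", input_mode="linear_input", direction="unidirectional",
    is_training=True)  # must be True, or _cudnn_rnn_backward raises ValueError

# Works with only core TensorFlow imported, since standard_ops.py now pulls in
# cudnn_rnn_grad and its @RegisterGradient("CudnnRNN") registration.
dx, dparams = tf.gradients(tf.reduce_sum(outputs), [x, params])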
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index e90ff07..f71f98a 100644 (file)
@@ -22,12 +22,13 @@ from __future__ import print_function
 
 import sys as _sys
 
+# pylint: disable=g-bad-import-order
 # Imports the following modules so that @RegisterGradient get executed.
 from tensorflow.python.ops import array_grad
+from tensorflow.python.ops import cudnn_rnn_grad
 from tensorflow.python.ops import data_flow_grad
 from tensorflow.python.ops import manip_grad
 from tensorflow.python.ops import math_grad
-from tensorflow.python.ops import manip_grad
 from tensorflow.python.ops import sparse_grad
 from tensorflow.python.ops import spectral_grad
 from tensorflow.python.ops import state_grad
@@ -96,6 +97,7 @@ from tensorflow.python.ops.tensor_array_ops import *
 from tensorflow.python.ops.variable_scope import *
 from tensorflow.python.ops.variables import *
 # pylint: enable=wildcard-import
+# pylint: enable=g-bad-import-order
 
 #### For use in remove_undocumented below:
 from tensorflow.python.framework import constant_op as _constant_op