_NUM_PARAMS_PER_LAYER = CUDNN_RNN_RELU_PARAMS_PER_LAYER
-@ops.RegisterGradient("CudnnRNN")
-def _cudnn_rnn_backward(op, *grad):
- if not op.get_attr("is_training"):
- raise ValueError(
- "CudnnRNN must set is_training to True to be used in gradients")
- return gen_cudnn_rnn_ops.cudnn_rnn_backprop(
- input=op.inputs[0],
- input_h=op.inputs[1],
- input_c=op.inputs[2],
- params=op.inputs[3],
- output=op.outputs[0],
- output_h=op.outputs[1],
- output_c=op.outputs[2],
- output_backprop=grad[0],
- output_h_backprop=grad[1],
- output_c_backprop=grad[2],
- reserve_space=op.outputs[3],
- dropout=op.get_attr("dropout"),
- seed=op.get_attr("seed"),
- seed2=op.get_attr("seed2"),
- rnn_mode=op.get_attr("rnn_mode"),
- input_mode=op.get_attr("input_mode"),
- direction=op.get_attr("direction"))
-
-
ops.RegisterShape("CudnnRNNParamsSize")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("CudnnRNNParamsToCanonical")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("CudnnRNNCanonicalToParams")(common_shapes.call_cpp_shape_fn)
)
py_library(
+ name = "cudnn_rnn_grad",
+ srcs = ["ops/cudnn_rnn_grad.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":framework_for_generated_wrappers",
+ "//tensorflow/python:cudnn_rnn_ops_gen",
+ ],
+)
+
+py_library(
name = "data_flow_grad",
srcs = ["ops/data_flow_grad.py"],
srcs_version = "PY2AND3",
":clip_ops",
":confusion_matrix",
":control_flow_ops",
+ ":cudnn_rnn_grad",
":data_flow_grad",
":data_flow_ops",
":framework_for_generated_wrappers",
--- /dev/null
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Gradients for CuudnnRNN operators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_cudnn_rnn_ops
+
+
+@ops.RegisterGradient("CudnnRNN")
+def _cudnn_rnn_backward(op, *grads):
+ """Gradients for the CudnnRNN op."""
+ if not op.get_attr("is_training"):
+ raise ValueError(
+ "CudnnRNN must set is_training to True to be used in gradients")
+ return gen_cudnn_rnn_ops.cudnn_rnn_backprop(
+ input=op.inputs[0],
+ input_h=op.inputs[1],
+ input_c=op.inputs[2],
+ params=op.inputs[3],
+ output=op.outputs[0],
+ output_h=op.outputs[1],
+ output_c=op.outputs[2],
+ output_backprop=grads[0],
+ output_h_backprop=grads[1],
+ output_c_backprop=grads[2],
+ reserve_space=op.outputs[3],
+ dropout=op.get_attr("dropout"),
+ seed=op.get_attr("seed"),
+ seed2=op.get_attr("seed2"),
+ rnn_mode=op.get_attr("rnn_mode"),
+ input_mode=op.get_attr("input_mode"),
+ direction=op.get_attr("direction"))
import sys as _sys
+# pylint: disable=g-bad-import-order
# Imports the following modules so that @RegisterGradient get executed.
from tensorflow.python.ops import array_grad
+from tensorflow.python.ops import cudnn_rnn_grad
from tensorflow.python.ops import data_flow_grad
from tensorflow.python.ops import manip_grad
from tensorflow.python.ops import math_grad
-from tensorflow.python.ops import manip_grad
from tensorflow.python.ops import sparse_grad
from tensorflow.python.ops import spectral_grad
from tensorflow.python.ops import state_grad
from tensorflow.python.ops.variable_scope import *
from tensorflow.python.ops.variables import *
# pylint: enable=wildcard-import
+# pylint: enable=g-bad-import-order
#### For use in remove_undocumented below:
from tensorflow.python.framework import constant_op as _constant_op