From 26e36ec2c9fb061e7349b2259bc69b2140d18819 Mon Sep 17 00:00:00 2001 From: Patrick Nguyen Date: Mon, 9 Apr 2018 16:55:06 -0700 Subject: [PATCH] Export recurrent and its RNN implementation in tf.contrib. PiperOrigin-RevId: 192210794 --- tensorflow/contrib/BUILD | 1 + tensorflow/contrib/__init__.py | 1 + tensorflow/contrib/cmake/python_modules.txt | 4 + tensorflow/contrib/recurrent/BUILD | 106 +++ tensorflow/contrib/recurrent/README.md | 13 + .../python/kernel_tests/functional_rnn_test.py | 163 +++++ .../python/kernel_tests/recurrent_test.py | 192 ++++++ .../contrib/recurrent/python/ops/functional_rnn.py | 396 ++++++++++++ .../contrib/recurrent/python/ops/recurrent.py | 720 +++++++++++++++++++++ .../contrib/recurrent/python/recurrent_api.py | 29 + 10 files changed, 1625 insertions(+) create mode 100644 tensorflow/contrib/recurrent/BUILD create mode 100644 tensorflow/contrib/recurrent/README.md create mode 100644 tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py create mode 100644 tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py create mode 100644 tensorflow/contrib/recurrent/python/ops/functional_rnn.py create mode 100644 tensorflow/contrib/recurrent/python/ops/recurrent.py create mode 100644 tensorflow/contrib/recurrent/python/recurrent_api.py diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index bf69144..9bef0d8 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -81,6 +81,7 @@ py_library( "//tensorflow/contrib/quantize:quantize_graph", "//tensorflow/contrib/autograph", "//tensorflow/contrib/receptive_field:receptive_field_py", + "//tensorflow/contrib/recurrent:recurrent_py", "//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py", "//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py", "//tensorflow/contrib/resampler:resampler_py", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 1c5b00f..aaddb06 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -66,6 +66,7 @@ from tensorflow.contrib import periodic_resample from tensorflow.contrib import predictor from tensorflow.contrib import quantization from tensorflow.contrib import quantize +from tensorflow.contrib import recurrent from tensorflow.contrib import reduce_slice_ops from tensorflow.contrib import resampler from tensorflow.contrib import rnn diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index b786c6d..340be61 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -367,6 +367,10 @@ tensorflow/contrib/receptive_field tensorflow/contrib/receptive_field/python tensorflow/contrib/receptive_field/python/util tensorflow/contrib/receptive_field/python/util/examples +tensorflow/contrib/recurrent +tensorflow/contrib/recurrent/python +tensorflow/contrib/recurrent/python/ops +tensorflow/contrib/recurrent/python/kernel_tests tensorflow/contrib/reduce_slice_ops tensorflow/contrib/reduce_slice_ops/kernels tensorflow/contrib/reduce_slice_ops/ops diff --git a/tensorflow/contrib/recurrent/BUILD b/tensorflow/contrib/recurrent/BUILD new file mode 100644 index 0000000..b3cb04c --- /dev/null +++ b/tensorflow/contrib/recurrent/BUILD @@ -0,0 +1,106 @@ +# Recurrent library. 
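# recurrent_py bundles the two op libraries defined below: recurrent_ops_py
# (the core Recurrent loop in python/ops/recurrent.py) and functional_rnn_ops_py
# (the dynamic_rnn-style wrappers in python/ops/functional_rnn.py).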
+ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_tests") + +py_library( + name = "recurrent_py", + srcs = ["python/recurrent_api.py"], + srcs_version = "PY2AND3", + deps = [ + ":functional_rnn_ops_py", + ":recurrent_ops_py", + ], +) + +py_library( + name = "recurrent_ops_py", + srcs = ["python/ops/recurrent.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python:functional_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + ], +) + +py_library( + name = "functional_rnn_ops_py", + srcs = ["python/ops/functional_rnn.py"], + srcs_version = "PY2AND3", + deps = [ + ":recurrent_ops_py", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:standard_ops", + ], +) + +cuda_py_tests( + name = "recurrent_ops_test", + size = "small", + srcs = ["python/kernel_tests/recurrent_test.py"], + additional_deps = [ + ":recurrent_ops_py", + "//third_party/py/numpy", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:random_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:script_ops", + "//tensorflow/python:variables", + ], + tags = ["nopip"], +) + +cuda_py_tests( + name = "functional_rnn_ops_test", + size = "small", + srcs = ["python/kernel_tests/functional_rnn_test.py"], + additional_deps = [ + ":functional_rnn_ops_py", + "//third_party/py/numpy", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/tpu:tpu", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:init_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:rnn", + "//tensorflow/python:rnn_cell", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], + tags = ["nopip"], +) diff --git a/tensorflow/contrib/recurrent/README.md b/tensorflow/contrib/recurrent/README.md new file mode 100644 index 0000000..86e10ee --- /dev/null +++ b/tensorflow/contrib/recurrent/README.md @@ -0,0 +1,13 @@ +# Recurrent computation library + +The recurrent computation library contains code to perform recurrent +computations. + +Its chief application is to implement recurrent neural networks (RNNs, LSTMs, +etc), which is implemented in `functional_rnn.py`. Similar techniques may be +used to implement deep networks. 
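As a rough usage sketch (illustrative only: the placeholder shapes and the
LSTM cell below are not part of this library), `functional_rnn` is invoked
the same way as `tf.nn.dynamic_rnn`:

```python
import tensorflow as tf
from tensorflow.contrib.recurrent.python.ops import functional_rnn

# Batch of 3 sequences, 5 time steps, 11 features per step (arbitrary sizes).
inputs = tf.placeholder(tf.float32, [3, 5, 11])
lengths = tf.placeholder(tf.int32, [3])
cell = tf.contrib.rnn.LSTMCell(7)

# Same signature and return values as tf.nn.dynamic_rnn: `outputs` is
# [batch, time, output_size]; `state` is the final cell state.
outputs, state = functional_rnn.functional_rnn(
    cell=cell, inputs=inputs, sequence_length=lengths, dtype=tf.float32)
```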
+ +The computation saves the activations in the forward pass, and computes the +gradients in the backward pass using a single accumulator. + +The `functional_rnn` interface is compatible with the `dynamic_rnn` API. diff --git a/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py b/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py new file mode 100644 index 0000000..0f19ac7 --- /dev/null +++ b/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py @@ -0,0 +1,163 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Functional RNN.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + + +from tensorflow.contrib.recurrent.python.ops import functional_rnn +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import rnn as rnn_lib +from tensorflow.python.ops import rnn_cell_impl +from tensorflow.python.ops import variables +import tensorflow.python.ops.nn_grad # pylint: disable=unused-import +import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import +from tensorflow.python.platform import test as test_lib +from tensorflow.python.platform import tf_logging as logging + + +def _CreateStackedLstmCell(*cell_sizes): + subcells = [rnn_cell_impl.LSTMCell(cell_size) for cell_size in cell_sizes] + return rnn_cell_impl.MultiRNNCell(subcells) + + +class FunctionalRnnTest(test_util.TensorFlowTestCase): + + _BATCH_SIZE = 3 + _TOTAL_TIME = 5 + _INPUT_SIZE = 11 + _NUM_UNITS = 7 + + # Set this to some output if you want to use it. + _LSTM_GRAPH_DEF_FILEPATH = None + + _CELLDEFS = { + 'gru': (rnn_cell_impl.GRUCell, [_NUM_UNITS]), + 'lstm': (rnn_cell_impl.LSTMCell, [_NUM_UNITS]), + 'stacked_lstm': (_CreateStackedLstmCell, [_NUM_UNITS] * 3) + } + + def _CreateCell(self, celldef_name): + func, args = self._CELLDEFS[celldef_name] + return func(*args) + + def _CreateInputs(self): + inputs = np.random.random([FunctionalRnnTest._BATCH_SIZE, + FunctionalRnnTest._TOTAL_TIME, + FunctionalRnnTest._INPUT_SIZE]) + # Always leave one time slot empty, to check max_length behavior. 
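    # (np.random.randint's `high` bound is exclusive, so sampled lengths are
    # at most _TOTAL_TIME - 2 and the final slot is never filled.)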
+ sequence_length = np.random.randint( + 0, high=FunctionalRnnTest._TOTAL_TIME - 1, + size=FunctionalRnnTest._BATCH_SIZE, + dtype=np.int) + return (inputs, sequence_length) + + def _CreateRnnGraph(self, create_rnn_computation_func, cell, tf_inputs, + tf_sequence_length, initial_state=None, + time_major=None, scope=None): + tf_result = create_rnn_computation_func(cell=cell, inputs=tf_inputs, + sequence_length=tf_sequence_length, + initial_state=initial_state, + dtype=dtypes.float32, + time_major=time_major, + scope=scope) + grad = gradients_impl.gradients(tf_result, variables.trainable_variables()) + return {'inference': tf_result, 'grad': grad} + + def _MaybeResetVariables(self, variable_cache, sess, var_list): + """Possibly resets the variables to a previously seen value.""" + reset_ops = [] + fetches = [] + for var in var_list: + if var.name in variable_cache: + reset_ops += [var.assign(variable_cache[var.name])] + else: + fetches += [(var.name, var)] + if reset_ops: + sess.run(reset_ops) + if fetches: + val = sess.run(dict(fetches)) + for n, v in val.items(): + assert n not in variable_cache + variable_cache[n] = v + + def _RunRnn(self, numpy_inputs, numpy_slen, cell_name, variable_cache, + is_dynamic): + with ops.Graph().as_default() as graph: + tf_inputs = array_ops.placeholder( + dtypes.float32, shape=numpy_inputs.shape) + tf_slen = array_ops.placeholder(dtypes.int32) + feeds = {tf_inputs: numpy_inputs, tf_slen: numpy_slen} + cell = self._CreateCell(cell_name) + fn = rnn_lib.dynamic_rnn if is_dynamic else functional_rnn.functional_rnn + fetches = self._CreateRnnGraph(fn, cell, tf_inputs, tf_slen) + with self.test_session(graph=graph) as sess: + sess.run(variables.global_variables_initializer()) + # Note that cell.trainable_variables it not always set. + self._MaybeResetVariables(variable_cache, sess, + variables.trainable_variables()) + val = sess.run(fetches, feed_dict=feeds) + graph_def = graph.as_graph_def() + return graph_def, val + + def testRunLstm(self): + """Runs a simple LSTM. 
Does not check output.""" + np_inputs, np_slen = self._CreateInputs() + var_cache = {} + graphdef, _ = self._RunRnn(np_inputs, np_slen, 'lstm', var_cache, False) + logging.info('graphdef: %s', graphdef) + if self._LSTM_GRAPH_DEF_FILEPATH: + with open(self._LSTM_GRAPH_DEF_FILEPATH, 'w') as f: + f.write(str(graphdef)) + + def testLstm(self): + """Checks an LSTM against the reference implementation.""" + np_inputs, np_slen = self._CreateInputs() + var_cache = {} + _, func_rnn = self._RunRnn(np_inputs, np_slen, 'lstm', var_cache, False) + _, dyn_rnn = self._RunRnn(np_inputs, np_slen, 'lstm', var_cache, True) + self.assertAllClose(dyn_rnn['inference'], func_rnn['inference']) + self.assertAllClose(dyn_rnn['grad'], func_rnn['grad']) + + def testGru(self): + """Checks a GRU cell against the reference implementation.""" + np_inputs, np_slen = self._CreateInputs() + var_cache = {} + _, func_rnn = self._RunRnn(np_inputs, np_slen, 'gru', var_cache, False) + _, dyn_rnn = self._RunRnn(np_inputs, np_slen, 'gru', var_cache, True) + self.assertAllClose(dyn_rnn['inference'], func_rnn['inference']) + self.assertAllClose(dyn_rnn['grad'], func_rnn['grad']) + + def testStackedLstm(self): + """Checks a stacked LSTM cell against the reference implementation.""" + np_inputs, np_slen = self._CreateInputs() + var_cache = {} + args = [np_inputs, np_slen, 'stacked_lstm', var_cache] + _, func_rnn = self._RunRnn(*(args + [False])) + _, dyn_rnn = self._RunRnn(*(args + [True])) + self.assertAllClose(dyn_rnn['inference'], func_rnn['inference']) + self.assertAllClose(dyn_rnn['grad'], func_rnn['grad']) + + +if __name__ == '__main__': + test_lib.main() diff --git a/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py b/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py new file mode 100644 index 0000000..00fbd4f --- /dev/null +++ b/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py @@ -0,0 +1,192 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for Recurrent ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.contrib.recurrent.python.ops import recurrent +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test as test_lib +from tensorflow.python.platform import tf_logging as logging + + +_ElmanState = collections.namedtuple('ElmanState', ('h')) +_ElmanTheta = collections.namedtuple('ElmanTheta', ('w', 'b')) +_ElmanInputs = collections.namedtuple('ElmanInputs', ('x')) + + +# TODO(drpng): add test for max length computation. +class RecurrentTest(test_util.TensorFlowTestCase): + + def testBasic(self): + # pylint:disable=invalid-name + _PolyState = collections.namedtuple('PolyState', ('value', 'x_power')) + _PolyTheta = collections.namedtuple('PolyTheta', ('x')) + _PolyInputs = collections.namedtuple('PolyInputs', ('coeff')) + # pylint:enable=invalid-name + + def Poly(theta, state, inputs): + next_state = _PolyState( + value=state.value + inputs.coeff * state.x_power, + x_power=state.x_power * theta.x) + return next_state, [] + + with self.test_session() as sess: + theta = _PolyTheta(x=array_ops.constant(2.0)) + state = _PolyState( + value=array_ops.constant(0.0), + x_power=array_ops.constant(1.0)) + inputs = _PolyInputs(coeff=array_ops.constant([1., 2., 3.])) + + # x = 2 + # 1 + 2*x + 3*x^2 + ret = recurrent.Recurrent(theta, state, inputs, Poly) + + acc, state = sess.run(ret) + self.assertAllClose(acc.value, [1., 5., 17.]) + self.assertAllClose(acc.x_power, [2., 4., 8.]) + self.assertAllClose(state.value, 17.) + self.assertAllClose(state.x_power, 8.) + + y = ret[1].value + dx, d_coeff = gradients_impl.gradients(ys=[y], xs=[theta.x, inputs.coeff]) + dx_val, d_coeff_val = sess.run([dx, d_coeff]) + + # 2 + 6*x + self.assertAllClose(dx_val, 14.) + self.assertAllClose(d_coeff_val, [1., 2., 4.]) + + # acc = [1, 1+2x, 1+2x+3x^2] + # sum(acc) = 3 + 4x + 3x^2 + acc = ret[0].value + dx, d_coeff = gradients_impl.gradients( + ys=[math_ops.reduce_sum(acc)], xs=[theta.x, inputs.coeff]) + dx_val, d_coeff_val = sess.run([dx, d_coeff]) + # 4 + 6*x + self.assertAllClose(dx_val, 16.) + self.assertAllClose(d_coeff_val, [3., 4., 4.]) + + @staticmethod + def Rand(shape): + return random_ops.random_uniform( + shape, minval=-0.2, maxval=0.2, dtype=dtypes.float64) + + @staticmethod + def Elman(theta, state0, inputs): + h0, w, b, x = state0.h, theta.w, theta.b, inputs.x + xw = math_ops.matmul(array_ops.concat([x, h0], axis=1), w) + h1 = math_ops.sigmoid(xw + b) + state1 = _ElmanState(h=h1) + return (state1, state1) + + @staticmethod + def ElmanGrad(theta, state0, inputs, extras, dstate1): + + @function.Defun() + def Grad(h0, w, b, x, h1, dh1): + del b + # We hand-roll the gradient for the 2nd half of the cell as a demo. + dxwb = (dh1 * (1 - h1) * h1) + dxw, db = dxwb, math_ops.reduce_sum(dxwb, axis=0) + + # Uses tf.gradient for the 1nd half of the cell as a demo. 
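      # (The matmul below is recomputed rather than passed through `extras`;
      # gradients() with grad_ys=[dxw] then backprops the hand-rolled dxw
      # through the concat/matmul to produce dh0, dx and dw.)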
+ xw = math_ops.matmul(array_ops.concat([x, h0], axis=1), w) + dh0, dx, dw = gradients_impl.gradients( + ys=[xw], xs=[h0, x, w], grad_ys=[dxw]) + + return dh0, dx, dw, db + + dh0, dx, dw, db = Grad(state0.h, theta.w, theta.b, inputs.x, + extras.h, dstate1.h) + dstate0 = _ElmanState(h=dh0) + dinputs = _ElmanInputs(x=dx) + return (_ElmanTheta(w=dw, b=db), dstate0, dinputs) + + @staticmethod + def ElmanOut(state1): + return _ElmanState(x=state1.h) + + @staticmethod + def ElmanOutGrad(dout): + return _ElmanState(h=dout.x) + + def testElman(self): + for seqlen, use_grad in [(1, False), (1, True), (7, False), (7, True)]: + logging.info('== Elman: seqlen=%s, use_grad=%s', seqlen, use_grad) + self._ParameterizedTestElman(seqlen, use_grad) + + def _ParameterizedTestElman(self, seqlen, use_grad): + + with self.test_session() as sess: + random_seed.set_random_seed(342462) + + batch = 3 + dims = 4 + theta = _ElmanTheta(w=RecurrentTest.Rand([2 * dims, dims]), + b=RecurrentTest.Rand([dims])) + state0 = _ElmanState(h=RecurrentTest.Rand([batch, dims])) + inputs = _ElmanInputs(x=RecurrentTest.Rand([seqlen, batch, dims])) + + # Statically unrolled. + s = state0 + out = [] + for i in xrange(seqlen): + inp = _ElmanInputs(x=inputs.x[i, :]) + s, _ = RecurrentTest.Elman(theta, s, inp) + out += [s.h] + acc0, final0 = array_ops.stack(out), s.h + loss0 = math_ops.reduce_sum(acc0) + math_ops.reduce_sum(final0) + (dw0, db0, dh0, di0) = gradients_impl.gradients( + loss0, [theta.w, theta.b, state0.h, inputs.x]) + + acc1, final1 = recurrent.Recurrent( + theta=theta, + state0=state0, + inputs=inputs, + cell_fn=RecurrentTest.Elman, + cell_grad=RecurrentTest.ElmanGrad if use_grad else None) + assert isinstance(acc1, _ElmanState) + assert isinstance(final1, _ElmanState) + acc1, final1 = acc1.h, final1.h + loss1 = math_ops.reduce_sum(acc1) + math_ops.reduce_sum(final1) + (dw1, db1, dh1, di1) = gradients_impl.gradients( + loss1, [theta.w, theta.b, state0.h, inputs.x]) + + # Fetches a few values and compare them. + (acc0, acc1, final0, final1, dw0, dw1, db0, db1, dh0, dh1, di0, + di1) = sess.run( + [acc0, acc1, final0, final1, dw0, dw1, db0, db1, dh0, dh1, di0, di1]) + self.assertAllClose(acc0, acc1) + self.assertAllClose(final0, final1) + self.assertAllClose(dw0, dw1) + self.assertAllClose(db0, db1) + self.assertAllClose(dh0, dh1) + self.assertAllClose(di0, di1) + +if __name__ == '__main__': + test_lib.main() diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py new file mode 100644 index 0000000..a085474 --- /dev/null +++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py @@ -0,0 +1,396 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A tf.nn.dynamic_rnn variant, built on the Recurrent class. 
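
The interface mirrors `tf.nn.dynamic_rnn`, but the time loop is expressed
through `Recurrent` (functional control flow via `functional_ops.For`) rather
than `tf.while_loop`, which is what enables the TPU-compatible path.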
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +from tensorflow.contrib.recurrent.python.ops import recurrent +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.util import nest + + +def _GetDTypesFromStructure(struct): + dtypes_list = [] + for x in nest.flatten(struct): + x = ops.convert_to_tensor(x) + dtypes_list.append(x.dtype) + return dtypes_list + + +def _SetShapeFromTemplate(struct, struct_template): + as_list = nest.flatten(struct) + template_as_list = nest.flatten(struct_template) + for element, template in zip(as_list, template_as_list): + element.set_shape(template.shape) + + +class _FunctionalRnnCell(object): + """Wrapper around RNNCell which separates state from computation. + + This class accomplishes the following: + * Turn the cell's `__call__` function into a pure function. The global + side effects are separated as `theta`. They are the variables created + for the weights of the computation. + * Unless the output is aliased as part of the state, extend the state to + contain the output so that we store the history in `Recurrent`. + * Set static shapes as required. + """ + + def __init__(self, rnn_cell, seq_inputs, initial_state): + assert initial_state is not None + + # TODO(drpng): Dtype needs to be configurable. + input_dtypes = [dtypes.float32] + _GetDTypesFromStructure(initial_state) + # See _index. + like_inputs_t = nest.map_structure( + lambda x: array_ops.stop_gradient(array_ops.gather(x, 0)), seq_inputs) + input_structure = (like_inputs_t, initial_state) + + @function.Defun(*input_dtypes) + def FlatCellStep(*flat_inputs): + """The flattened version of `rnn_cell`.""" + inputs_t, state0 = nest.pack_sequence_as(input_structure, flat_inputs) + _SetShapeFromTemplate(state0, initial_state) + _SetShapeFromTemplate(inputs_t, like_inputs_t) + outputs_t, state1 = rnn_cell(inputs_t, state0) + state_list = nest.flatten(state1) + self._output_shape = outputs_t.shape + + if outputs_t in state_list: + output_index_in_state = state_list.index(outputs_t) + else: + output_index_in_state = None + + if output_index_in_state is None: + self._prepend_output = True + self._output_state_idx = 0 + return [outputs_t] + state_list + else: + self._output_state_idx = output_index_in_state + self._prepend_output = False + # To save memory, we don't store return the output separately + # from the state list, since we know it's the same. + return state_list + + def _ToPureFunction(func): + # NOTE: This forces the creating of the function. + if func.captured_inputs: + pure_func = copy.copy(func) + # pylint: disable=protected-access + pure_func._extra_inputs = [] + return pure_func + return func + + pure_flat_cell_step = _ToPureFunction(FlatCellStep) + + def CellStep(theta, extended_state0, inputs_t): + """Performs one time steps on structured inputs. + + The purpose of this function is to turn the parameters into flattened + versions, and to resolve the parameter order difference between + `Recurrent` and `RNNCell`. + + In the event the cell returns a transformed output that is not aliased + within its state, the `extended_state0` also contains the output as its + first element. + + Args: + theta: Weights required for the computation. A structure of tensors. 
+ extended_state0: the state0, and possibly the output at the previous + time step. A structure of tensors. + inputs_t: the inputs at time t. + + Returns: + A pair of the next state (inclusive of the output), and an empty list + (unused `extras`). + The next state is congruent to state0. + """ + extended_state0_flat = nest.flatten(extended_state0) + state0_flat = self.MaybeRemoveOutputFromState(extended_state0_flat) + full_inputs = [inputs_t] + state0_flat + theta + # Note that the thetas are additional inputs appeneded as extra + # parameters. + cell_out = pure_flat_cell_step(*full_inputs) + return cell_out, [] + + self._cell_step = CellStep + self._theta = FlatCellStep.captured_inputs + self._zero_state = rnn_cell.zero_state + self._state_template = initial_state + self._output_size = rnn_cell.output_size + + @property + def extended_initial_state(self): + if self._prepend_output: + return [array_ops.zeros(self._output_shape), self._state_template] + else: + # The base case, where the output is just the hidden state. + return self._state_template + + @property + def cell_step(self): + return self._cell_step + + @property + def theta(self): + return self._theta + + @property + def state_template(self): + return self._state_template + + @property + def output_shape(self): + return self._output_shape + + def GetOutputFromState(self, state): + return nest.flatten(state)[self._output_state_idx] + + def MaybeRemoveOutputFromState(self, flat_state): + if self._prepend_output: + return flat_state[1:] + return flat_state + + +def _ApplyLengthsToBatch(sequence_lengths, tf_output): + # TODO(drpng): just use Update so that we don't carry over the gradients? + """Sets the output to be zero at the end of the sequence.""" + # output is batch major. + batch_size, max_time, vector_size = tf_output.shape + output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size]) + output_time = array_ops.reshape(output_time, [batch_size, max_time]) + lengths = array_ops.tile( + array_ops.reshape(sequence_lengths, [-1, 1]), [1, max_time]) + is_less = math_ops.cast( + math_ops.less(output_time, lengths), dtype=dtypes.float32) + keep_mask = array_ops.tile( + array_ops.expand_dims(is_less, -1), + [1, 1, vector_size]) + final_output = keep_mask * tf_output + return final_output + + +def _PickFinalStateFromHistory(acc_state, sequence_length): + """Implements acc_state[sequence_length - 1].""" + # This will work on all platforms, unlike the regular slice. + last_value = [] + for state_var in nest.flatten(acc_state): + # We compute the following with matrix operations: + # last_var = state_var[sequence_length - 1] + shape = array_ops.shape(state_var) + max_time, batch_size = shape[0], shape[1] + output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size]) + output_time = array_ops.reshape(output_time, [batch_size, max_time]) + lengths = array_ops.tile(array_ops.reshape(sequence_length, + [-1, 1]), [1, max_time]) + last_idx = math_ops.cast(math_ops.equal(output_time, lengths - 1), + dtype=dtypes.float32) + last_idx = array_ops.transpose(last_idx) + last_idx_for_bcast = array_ops.expand_dims(last_idx, -1) + sliced = math_ops.multiply(last_idx_for_bcast, state_var) + last_var = math_ops.reduce_sum(sliced, 0) + last_value += [last_var] + return nest.pack_sequence_as(acc_state, last_value) + + +def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell, + total_time, inputs_lengths): + """Post-process output of recurrent. 
+ + This function takes the accumulated extended state and extracts the requested + state and output. + + When `inputs_lengths` has been set, it extracts the output from the + accumulated state. It also sets outputs past. + + It also sets the static shape information. + + Args: + extended_acc_state: A structure containing the accumulated state at each + time. It may contain the output at each time as well. + extended_final_state: A structure containing the final state. It may + contain the output at the final time. + func_cell: The functional wrapper around the cell. + total_time: A scalar integer tensor. + inputs_lengths: An integer tensor with one entry per input. + + Returns: + A tuple with the outputs at each time, and the final state. + """ + if inputs_lengths is None: + flat_final_state = func_cell.MaybeRemoveOutputFromState( + nest.flatten(extended_final_state)) + tf_state = nest.pack_sequence_as(func_cell.state_template, flat_final_state) + else: + # The accumulated state is over the entire sequence, so we pick it + # out from the acc_state sequence. + flat_acc_state = func_cell.MaybeRemoveOutputFromState( + nest.flatten(extended_acc_state)) + acc_state = nest.pack_sequence_as( + func_cell.state_template, flat_acc_state) + tf_state = _PickFinalStateFromHistory(acc_state, inputs_lengths) + + output_from_state = func_cell.GetOutputFromState(extended_acc_state) + tf_output = array_ops.transpose(output_from_state, [1, 0, 2]) + tf_output.set_shape( + [func_cell.output_shape[0], total_time, func_cell.output_shape[1]]) + if inputs_lengths is not None: + # Need set the outputs to zero. + tf_output = _ApplyLengthsToBatch(inputs_lengths, tf_output) + # tf_output = array_ops.zeros([4, 3, 5]) + _SetShapeFromTemplate(tf_state, func_cell.state_template) + return tf_output, tf_state + + +# pylint: disable=invalid-name +def functional_rnn(cell, inputs, sequence_length=None, + initial_state=None, dtype=None, time_major=False, + scope=None, use_tpu=False): + """Same interface as `tf.nn.dynamic_rnn`.""" + with variable_scope.variable_scope(scope or 'rnn'): + if not time_major: + inputs = nest.map_structure( + lambda t: array_ops.transpose(t, [1, 0, 2]), inputs) + inputs_flat = nest.flatten(inputs) + batch_size = array_ops.shape(inputs_flat[0])[1] + if initial_state is None: + initial_state = cell.zero_state(batch_size, dtype) + func_cell = _FunctionalRnnCell(cell, inputs, initial_state) + extended_acc_state, extended_final_state = recurrent.Recurrent( + theta=func_cell.theta, + state0=func_cell.extended_initial_state, + inputs=inputs, + cell_fn=func_cell.cell_step, + use_tpu=use_tpu) + return _PostProcessOutput(extended_acc_state, extended_final_state, + func_cell, inputs_flat[0].shape[0], sequence_length) + + +def bidirectional_functional_rnn( + cell_fw, + cell_bw, + inputs, + initial_state_fw=None, + initial_state_bw=None, + dtype=None, + sequence_length=None, + time_major=False, + use_tpu=False, + scope=None): + """Creates a bidirectional recurrent neural network. + + Performs fully dynamic unrolling of inputs in both directions. Built to be API + compatible with `tf.nn.bidirectional_dynamic_rnn`, but implemented with + functional control flow for TPU compatibility. + + Args: + cell_fw: An instance of `tf.contrib.rnn.RNNCell`. + cell_bw: An instance of `tf.contrib.rnn.RNNCell`. + inputs: The RNN inputs. If time_major == False (default), this must be a + Tensor (or hierarchical structure of Tensors) of shape + [batch_size, max_time, ...]. 
If time_major == True, this must be a Tensor + (or hierarchical structure of Tensors) of shape: + [max_time, batch_size, ...]. The first two dimensions must match across + all the inputs, but otherwise the ranks and other shape components may + differ. + initial_state_fw: An optional initial state for `cell_fw`. Should match + `cell_fw.zero_state` in structure and type. + initial_state_bw: An optional initial state for `cell_bw`. Should match + `cell_bw.zero_state` in structure and type. + dtype: (optional) The data type for the initial state and expected output. + Required if initial_states are not provided or RNN state has a + heterogeneous dtype. + sequence_length: An optional int32/int64 vector sized [batch_size]. Used to + copy-through state and zero-out outputs when past a batch element's + sequence length. So it's more for correctness than performance. + time_major: Whether the `inputs` tensor is in "time major" format. + use_tpu: Whether to enable TPU-compatible operation. If True, does not truly + reverse `inputs` in the backwards RNN. Once b/69305369 is fixed, we can + remove this flag. + scope: An optional scope name for the dynamic RNN. + + Returns: + outputs: A tuple of `(output_fw, output_bw)`. The output of the forward and + backward RNN. If time_major == False (default), these will + be Tensors shaped: [batch_size, max_time, cell.output_size]. If + time_major == True, these will be Tensors shaped: + [max_time, batch_size, cell.output_size]. Note, if cell.output_size is a + (possibly nested) tuple of integers or TensorShape objects, then the + output for that direction will be a tuple having the same structure as + cell.output_size, containing Tensors having shapes corresponding to the + shape data in cell.output_size. + final_states: A tuple of `(final_state_fw, final_state_bw)`. A Tensor or + hierarchical structure of Tensors indicating the final cell state in each + direction. Must have the same structure and shape as cell.zero_state. + + Raises: + ValueError: If `initial_state_fw` is None or `initial_state_bw` is None and + `dtype` is not provided. + """ + # Keep this code in sync with tf.nn.dynamic_rnn for compatibility. + with variable_scope.variable_scope(scope or 'bidirectional_rnn'): + # Forward direction + with variable_scope.variable_scope('fw') as fw_scope: + output_fw, output_state_fw = functional_rnn( + cell=cell_fw, inputs=inputs, sequence_length=sequence_length, + initial_state=initial_state_fw, dtype=dtype, + time_major=time_major, scope=fw_scope, use_tpu=use_tpu) + # Backward direction + if not time_major: + time_dim = 1 + batch_dim = 0 + else: + time_dim = 0 + batch_dim = 1 + + def _reverse(input_, seq_lengths, seq_dim, batch_dim): + if seq_lengths is not None: + return array_ops.reverse_sequence( + input=input_, seq_lengths=seq_lengths, + seq_dim=seq_dim, batch_dim=batch_dim) + else: + # See b/69305369. 
+ assert not use_tpu, ( + 'Bidirectional with variable sequence lengths unsupported on TPU') + return array_ops.reverse(input_, axis=[seq_dim]) + + with variable_scope.variable_scope('bw') as bw_scope: + inputs_reverse = _reverse( + inputs, seq_lengths=sequence_length, + seq_dim=time_dim, batch_dim=batch_dim) + tmp, output_state_bw = functional_rnn( + cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length, + initial_state=initial_state_bw, dtype=dtype, + time_major=time_major, scope=bw_scope, use_tpu=use_tpu) + + output_bw = _reverse( + tmp, seq_lengths=sequence_length, + seq_dim=time_dim, batch_dim=batch_dim) + + outputs = (output_fw, output_bw) + output_states = (output_state_fw, output_state_bw) + + return (outputs, output_states) +# pylint: enable=invalid-name diff --git a/tensorflow/contrib/recurrent/python/ops/recurrent.py b/tensorflow/contrib/recurrent/python/ops/recurrent.py new file mode 100644 index 0000000..fa16b82 --- /dev/null +++ b/tensorflow/contrib/recurrent/python/ops/recurrent.py @@ -0,0 +1,720 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Recurrent computation. + +The main interface of this module is Recurrent(). +A recurrent computation describes an auto-regressive process, where outputs +of one time step are fed to the output of the next time step. + +This module uses: + theta: the "weights" each RNN uses. + state0: the initial state of each RNN. + cell_fn: A python function describing RNN cell. It must has the following + signature: + cell_fn: (theta, state0, inputs) -> (state1, extras) + state1 is the next RNN state, extras are computed by cell_fn + and the library forwards extras to cell_fn's gradient function. + cell_grad: A python function describing the backprop gradient function + for the RNN cell. It must has the following signature: + cell_grad: (theta, state0, inputs, extras, dstate1) -> ( + dtheta, dstate0, dinputs) + dstate1 is what the backprop algorithm provides representing + gradients of state1 w.r.t. the final loss. + +In this module, we handle structures of tensors for theta, state0, inputs, +and extras. The structure is an arbitrarily nested python structure, such +as a dictionary of named tuples. + +Because the computation is a left-to-right chain, a single in-place accumulator +can be used rather than a stack. Thus a special gradient was written to reduce +unnecessary memory usage. 
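
As a toy example (adapted from the accompanying recurrent_test.py), a cell
that evaluates the polynomial sum_t coeff[t] * x**t, consuming one
coefficient per time step, can be written as:

  PolyState = collections.namedtuple('PolyState', ('value', 'x_power'))
  PolyTheta = collections.namedtuple('PolyTheta', ('x',))
  PolyInputs = collections.namedtuple('PolyInputs', ('coeff',))

  def Poly(theta, state, inputs):
    next_state = PolyState(
        value=state.value + inputs.coeff * state.x_power,
        x_power=state.x_power * theta.x)
    return next_state, []

  acc_state, final_state = Recurrent(theta, state0, inputs, Poly)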
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import inplace_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.inplace_ops import alias_inplace_update +from tensorflow.python.util import nest + + +def _AssertIsCompatible(a, b): + """Checks that `a` and `b` are nested structures of the same type.""" + # TODO(drpng): implement. + del a + del b + + +def _Index(struct, index): + """Returns a structure with `x[index]` for each tensor `x` in the structure. + + Args: + struct: A structure of tensors. + index: A scalar integer tensor. Performance is better if `index` is + on the host memory. + + Returns: + A structure of tensors congruent to `struct`. + For each key in `ret`, `rets[key] = struct[key][index]`. + """ + index = ops.convert_to_tensor(index) + index.get_shape().assert_has_rank(0) + return nest.map_structure(lambda x: x[index], struct) + + +def _Update(struct_acc, struct_x, t): + """Updates t-th row in accumulators. + + Args: + struct_acc: The accumulators. A structure of tensors. + struct_x: The new values. A structure of tensors congruent to `struct_acc`. + t: A scalar integer. Performance is better if `t` is on the device + memory. + + Returns: + A structure of tensors. Say, ret is a returned dictionary. Then, for + each key, we have: + ret[key] = struct_acc[key]; + ret[key][t, :] = struct_x[key] + """ + to_skip_update = set() + acc_lst = nest.flatten(struct_acc) + x_lst = nest.flatten(struct_x) + t = math_ops.to_int32([t]) # tf.to_int32 casts on-device tensors. + lst = [] + for acc, x in zip(acc_lst, x_lst): + if acc in to_skip_update: + # Until b/62105730 is fixed, we need to avoid inplace update for tensors + # of rank 1. could reshape to handle it, but we don't really need the + # values applied to these, so just skip their modification. + lst += [acc] + else: + lst += [alias_inplace_update(acc, t, array_ops.expand_dims(x, 0))] + return nest.pack_sequence_as(struct_acc, lst) + + +def _SeqLenDim(struct): + """Returns the 0-th dim size of tensors in a structure of tensors. + + This is the max sequence length according to the shape of the inputs. + + Args: + struct: A structure of tensors. Every tensor's 0-th dim has the same size. + + Returns: + A scalar tensor which is the size of 0-th dim of every tensors in struct. + """ + xs = nest.flatten(struct) + assert xs + dim0 = array_ops.shape(xs[0])[0] + return dim0 + + +def _Flatten(struct): + """Flattens a structure.""" + return nest.flatten(struct) + + +def _Pack(elements, struct_template): + """Packs the list of tensors according to the structure. + + In the event that `elements` should be a scalar, `struct_template` must + contain exactly one non-trivial element (for instance, `[[], {'x':elt}]`). + + Args: + elements: Elements to be packed. A list of tensor, or a single tensor. + struct_template: The container structure in which to pack them. + Returns: + A python structure of the same type as `struct_template`, containing + `elements` as its contained elements. 
+ """ + if not nest.is_sequence(elements): + return nest.pack_sequence_as(struct_template, [elements]) + return nest.pack_sequence_as(struct_template, elements) + + +def _EmptyAcc(slen, struct_template): + """Creates a set of accumulators for tensors in structure. + + Args: + slen: The sequence length. A scalar tensor. + struct_template: A structure of tensors. + + Returns: + A structure congruent to `struct_template`. Say ret is a returned + dictionary. Then, `ret.key`, a tensor, has the same dtype as + `struct_template.key`. The tensor's shape has 1 more dimension + than the tensor `struct_template.key`. The extra 0-th dimension is of size + `slen`. E.g., if `slen=10` and `struct_template.key`'s shape is `[3, 5]`, + then, `ret.key`'s shape is `[10, 3, 5]`. + """ + + def _EmptyAccForTensor(tensor): + return inplace_ops.empty( + array_ops.concat([[slen], array_ops.shape(tensor)], axis=0), + tensor.dtype, + init=True) + + return nest.map_structure(_EmptyAccForTensor, struct_template) + + +def _EmptyLike(struct): + """Creates a set of empty initialized tensors. + + Args: + struct: A structure of tensors. + + Returns: + A struct of tensors. Each tensor has the same shape and dtype as + its corresponding tensor in `struct`. And each tensor is initialized. + """ + return nest.map_structure( + lambda x: inplace_ops.empty_like(x, init=True), struct) + + +def _Add(struct_x, struct_y): + """Adds tensors in `struct_x` with respective tensors in `struct_y`. + + Args: + struct_x: A struct of tensors. + struct_y: A struct of tensors congruent to `struct_x`. + + Returns: + A struct of tensors. Each element of the returned value + equals `x + y`, with corresponding values in `struct_x` and `struct_y`. + """ + list_x = nest.flatten(struct_x) + list_y = nest.flatten(struct_y) + z = [] + for x, y in zip(list_x, list_y): + z += [math_ops.add(x, y)] + return nest.pack_sequence_as(struct_x, z) + + +def _Dtypes(struct): + """Returns all tensors' data types in a list.""" + return [x.dtype for x in nest.flatten(struct)] + + +def _ConvertNoneGradientToZeros(xs, dxs): + """Sanitize dxs so that None becomes zeros appropriately. + + Args: + xs: A list of tensors. + dxs: A list of tensors. dxs[i] corresponds to xs[i]'s gradient. + + Returns: + A structure same as `dxs` with `None` replaced by a zero tensor. + """ + list_xs = nest.flatten(xs) + list_dxs = nest.flatten(dxs) + + # If x does not get any backprop-ed gradient, propagate zeros. + rets = [] + for (x, dx) in zip(list_xs, list_dxs): + if dx is None: + rets.append(array_ops.zeros_like(x)) + else: + rets.append(dx) + + return nest.pack_sequence_as(dxs, rets) + + +# All structures are flattened for use internally. This is for simplicity +# and also to use the Defun construct. +# In the forward pass (inference), the computation is structured as follows. +# Forward: [gradient = _Recurrent.Grad] +# Flatten structures, create accumulators. +# for t = 0..max_input_length: +# Defun ForwardLoopBody: +# Defun Fwd: flatten/pack around cell_fn +# state1 = Fwd(inputs[t], state0) +# acc_state += [state1] +# Pack structures. +# During the backward pass (backpropping the gradient from the last time +# step to the first, through the structure), the computation is structured +# as follows. +# Grad: +# Flatten structures. +# Defun Backward: +# Create create accumulated derivatives: d_theta, d_inputs, d_acc_state. +# Regarding the note at the top of the file, there is only one accumulator +# for d_theta accumulated over the whole sequence. 
+# for t = max_input_length -1..0: +# Defun BackwardLoopBody: +# Retrieve acc_state[t] computed in the forward pass. +# Defun Bak: flatten/back around cell_fn_grad. +# d_state1 is d_state0 from previous step (ie next time). +# d_acc_state[dev_t] += d_state1 +# d_theta_t, d_state0, d_inputs_t, = Bak() +# d_inputs[dev_t] += d_inputs +# d_theta += d_theta_t +# d_acc_state[t] += d_state1 +# Pack structures and return. +class _Recurrent(object): + """A helper class to construct a recurrent neural net.""" + + def __init__(self, cell_fn, cell_grad, theta, state0, inputs, + max_input_length, extras, use_tpu): + """RNN helper class. + + Args: + cell_fn: A python function, which computes: + state1, extras = cell_fn(theta, state0, inputs[t, :]) + cell_grad: A python function which computes: + dtheta, dstate0, dinputs[t, :] = cell_grad( + theta, state0, inputs[t, :], extras, dstate1) + theta: weights. A structure of tensors. + state0: initial state. A structure of tensors. + inputs: inputs. A structure of tensors. + max_input_length: None, or the maximum effective length of the input over + all batches. A scalar tensor. + extras: A structure of tensors. The 2nd return value of every + invocation of cell_fn is a structure of tensors with matching keys + and shapes of this `extras`. + use_tpu: A boolean indicating whether the computation is mean to + run on a TPU. + """ + self._theta = theta + self._state = state0 + self._inputs = inputs + self._max_input_length = self._MaybeComputeMaxInputLength( + inputs, max_input_length) + self._cell_fn = cell_fn + self._cell_grad = cell_grad + self._extras = extras + + # pylint: disable=unbalanced-tuple-unpacking + + # NOTE: TF Function (Fwd, Bak, ForwardLoopBody, BackwardLoopBody, + # Forward and Backward defined below) simply takes a list of + # Tensors and returns a list of Tensors. When we pass in a + # structure (a list of structures of Tensors), we use _Flatten to + # convert the structure into a list of tensor. Conversely, the + # following code often uses _Pack to formulate a structure from a + # list of tensors based on a "template". + + # Wraps cell_fn in a TF Function: + # state1 = cell_fn(theta, state0, inputs) + fwd_sig = [self._theta, self._state, self._inputs] + + compiled = use_tpu + noinline = not compiled + dev_t_type = dtypes.int32 if use_tpu else dtypes.int64 + + @function.Defun(*_Dtypes(fwd_sig)) + def Fwd(*args): + (theta, state0, inputs) = _Pack(args, fwd_sig) + state1, extras = self._cell_fn(theta, state0, inputs) + assert not function.get_extra_args(), ( + 'cell_fn is not pure with extra args: %s.' % + (function.get_extra_args())) + _AssertIsCompatible(state1, self._state) + _AssertIsCompatible(extras, self._extras) + return _Flatten([state1, extras]) + + # Wraps cell_fn in a TF Function as a for-loop's body. + # + # The loop state is composed of: + # t: The loop variable. Timestep id. + # dev_t: The loop variable mirrored on the device. + # theta: the recurrent net's weights. + # state0: the previous recurrent state. + # inputs: inputs to the recurrent net. inputs[t, :] are for the timestep t. + # acc_state: Each timestep's computed new state is also stashed into + # acc_state. 
+ # acc_extras: Each timestep's computed extras is stashed into acc_extras + fwdloop_sig = [ + self._theta, self._state, self._inputs, self._state, self._extras + ] + + @function.Defun(dtypes.int32, dev_t_type, *_Dtypes(fwdloop_sig)) + def ForwardLoopBody(*args): + """The body of forward loop.""" + t, dev_t = args[0], args[1] + (theta, state0, inputs, acc_state, acc_extras) = _Pack( + args[2:], fwdloop_sig) + inputs_t = _Index(inputs, t) # external input at time step t. + fwd = Fwd(*_Flatten([theta, state0, inputs_t])) + state1, extras = _Pack(fwd, [self._state, self._extras]) + # Saves state1 and extras in their accumulators. + acc_state = _Update(acc_state, state1, dev_t) + acc_extras = _Update(acc_extras, extras, dev_t) + + return [math_ops.add(dev_t, 1)] + _Flatten( + [theta, state1, inputs, acc_state, acc_extras]) + + def Grad(op, *args): + """The python grad function for the Forward function.""" + + # NOTE: tf.gradient backprops None for int32/int64 while zeros + # for float32/float64. For consistency, we always backprop + # zeros. + args = list(args) + for i, dy in enumerate(args): + if dy is None: + args[i] = array_ops.zeros_like(op.outputs[i]) + # TODO(drpng): getting the extra state here? + op_inputs = [x for x in op.inputs] + op_struct = [ + self._theta, self._state, self._inputs, self._max_input_length, + self._extras + ] + (theta, state0, inputs, max_input_length, _) = _Pack(op_inputs, op_struct) + # acc_state and acc_extras are computed by the Forward pass and + # needed by the Backward pass. + acc_state, _, acc_extras = _Pack([x for x in op.outputs], + [self._state, self._state, self._extras]) + + # Forward computes acc_state, the final state and + # acc_extras. tf.gradients gives us their gradients w.r.t. the + # final loss. Because acc_extras are not exposed by Compute(), + # it has no gradients w.r.t. the final loss (i.e., by + # construction, it must be zeros). + d_acc_state, d_state1, _ = _Pack(args, + [self._state, self._state, self._extras]) + return Backward(*_Flatten([ + theta, state0, inputs, max_input_length, acc_state, acc_extras, + d_acc_state, d_state1 + ])) + + # Forward calls ForwardLoopBody n times. Each time computes one + # time step of the recurrent net. + forward_sig = [ + self._theta, self._state, self._inputs, self._max_input_length, + self._extras + ] + + @function.Defun( + *_Dtypes(forward_sig), python_grad_func=Grad, noinline=noinline) + def Forward(*args): + """Forward pass of the recurrent net.""" + theta, state0, inputs, max_input_length, extras = _Pack(args, forward_sig) + + slen_dim = _SeqLenDim(inputs) + + # Creates accumulators for state0 and extras. + acc_state = _EmptyAcc(slen_dim, state0) + acc_extras = _EmptyAcc(slen_dim, extras) + + dev_t = array_ops.constant(0, dtype=dev_t_type) + run = functional_ops.For( + start=0, + limit=max_input_length, + delta=1, + inputs=[dev_t] + _Flatten( + [theta, state0, inputs, acc_state, acc_extras]), + body=ForwardLoopBody, + rewrite_with_while=compiled) + _, state1, _, acc_state, acc_extras = _Pack( + run[1:], + [self._theta, self._state, self._inputs, self._state, self._extras]) + + return _Flatten([acc_state, state1, acc_extras]) + + # The per-step backward computes: + # d_theta, d_state0, d_inputs = cell_grad( + # theta, state0, inputs, extras, d_state1) + # where d_state1 is the backprop-ed gradient for state1, and + # extras is the computed by the forward step to facilitate the + # backward step. 
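    # bak_sig mirrors Bak's argument order below:
    # (theta, state0, inputs, extras, d_state1).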
+ bak_sig = [ + self._theta, self._state, self._inputs, self._extras, self._state + ] + + @function.Defun(*_Dtypes(bak_sig)) + def Bak(*args): + """Backward step.""" + (theta, state0, inputs, extras, d_state1) = _Pack(args, bak_sig) + (dtheta, dstate0, dinputs) = self._cell_grad(theta, state0, inputs, + extras, d_state1) + assert not function.get_extra_args(), ( + 'cell_grad is not pure with extra args: %s.' % + (function.get_extra_args())) + _AssertIsCompatible(dtheta, self._theta) + _AssertIsCompatible(dstate0, self._state) + _AssertIsCompatible(dinputs, self._inputs) + return _Flatten( + _ConvertNoneGradientToZeros([theta, state0, inputs], + [dtheta, dstate0, dinputs])) + + # Define defuns used by a functional_ops.If in BackwardLoopBody. + state_if_sig = [self._state, self._state] + + @function.Defun(*_Dtypes(state_if_sig)) + def ReturnOrigState0(*args): + """Returns original state0 from inputs.""" + (_, orig_state0) = _Pack(args, state_if_sig) + return nest.flatten(orig_state0) + + @function.Defun(*_Dtypes(state_if_sig)) + def ReturnAccState(*args): + """Returns acc_state[t-1] from inputs.""" + (acc_state, _) = _Pack(args, state_if_sig) + return nest.flatten(acc_state) + + # Wraps cell_grad gradient function in a TF Function as a + # for-loop's body for the Backward pass. + # + # The loop state is composed of: + # t: The loop variable. Timestep id. + # state0: the initial state for the entire backward loop. + # dev_t: The loop variable mirrored on the device. + # theta: the recurrent net's weights. + # inputs: inputs to the recurrent net. inputs[t, :] are for the timestep t. + # acc_state: Each timestep's computed new state was stashed into + # acc_state by the Forward pass. + # acc_extras: Each timestep's computed extras was stashed into + # acc_extras by the Forward pass. + # d_theta: All timestep's gradient for theta is accumulated (added) into + # d_theta. + # d_state1: The backprop-ed gradient for the new stated computed by + # timestep t. + # d_inputs: d_inputs[t, :] is populated by the backward time step t. + # d_acc_state: The backprop-ed gradient for acc_state. + bakloop_sig = [ + self._theta, self._state, self._inputs, self._state, self._extras, + self._theta, self._state, self._inputs, self._state + ] + + @function.Defun(dtypes.int32, dev_t_type, *_Dtypes(bakloop_sig)) + def BackwardLoopBody(*args): + """Backward loop body function.""" + t, dev_t = args[0], args[1] + (theta, orig_state0, inputs, acc_state, acc_extras, d_theta, d_state1, + d_inputs, d_acc_state) = _Pack(args[2:], bakloop_sig) + + # The input recurrent state for time step t is previous time step's + # output, or the original state0 when on time step 0. + state_from_acc = _Index(acc_state, math_ops.maximum(0, t - 1)) + state0 = functional_ops.If( + math_ops.equal(t, array_ops.constant(0, dtypes.int32)), + _Flatten([state_from_acc, orig_state0]), ReturnOrigState0, + ReturnAccState) + state0 = nest.pack_sequence_as(orig_state0, state0) + + # The external inputs for time step t. + inputs_t = _Index(inputs, t) + # The extras for time step t. 
+ extras_t = _Index(acc_extras, t) + + d_state1 = _Add(_Index(d_acc_state, t), d_state1) + (d_theta_t, d_state0, d_inputs_t) = _Pack( + Bak(*_Flatten([theta, state0, inputs_t, extras_t, d_state1])), + [self._theta, self._state, self._inputs]) + d_theta = _Add(d_theta, d_theta_t) + d_inputs = _Update(d_inputs, d_inputs_t, dev_t) + return [math_ops.subtract(dev_t, 1)] + _Flatten([ + theta, orig_state0, inputs, acc_state, acc_extras, d_theta, d_state0, + d_inputs, d_acc_state + ]) + + # Backward calls BackwardLoopBody n times. Each time computes the backprop + # for one time step of the recurrent net. + backward_sig = [ + self._theta, self._state, self._inputs, self._max_input_length, + self._state, self._extras, self._state, self._state + ] + + @function.Defun(*_Dtypes(backward_sig), noinline=noinline) + def Backward(*args): + """Backward pass for the recurrent net.""" + # theta, state0, inputs are Forward's inputs. + # acc_state is the accumulated 1st output of Forward. + # acc_extras is the accumulated 2nd output of Forward. + # d_acc_state is the gradient for acc_state. + # d_state1 is the gradient for the final state computed by Forward. + (theta, state0, inputs, max_input_length, acc_state, acc_extras, + d_acc_state, d_state1) = _Pack(args, backward_sig) + + # Accumulators for gradients. + d_theta = _EmptyLike(theta) + d_inputs = _EmptyLike(inputs) + + # Loop backwards. Note the loop's limit is open-ended, so goes through + # t=0. + t = max_input_length - 1 + dev_t = math_ops.to_int32(t) if use_tpu else math_ops.to_int64(t) + run = functional_ops.For( + start=t, + limit=-1, + delta=-1, + inputs=[dev_t] + _Flatten([ + theta, state0, inputs, acc_state, acc_extras, d_theta, d_state1, + d_inputs, d_acc_state + ]), + body=BackwardLoopBody, + rewrite_with_while=compiled) + + (theta, state0, inputs, acc_state, acc_extras, d_theta, d_state0, + d_inputs, d_acc_state) = _Pack(run[1:], bakloop_sig) + + d_max_input_length = array_ops.constant(0, dtype=max_input_length.dtype) + return _Flatten( + [d_theta, d_state0, d_inputs, d_max_input_length, acc_extras]) + + self._forward = Forward + + def _MaybeComputeMaxInputLength(self, inputs, max_input_length): + if max_input_length is not None: + return max_input_length + return math_ops.reduce_max(array_ops.shape(nest.flatten(inputs)[0])[0]) + + def Compute(self): + return _Pack( + self._forward(*_Flatten([ + self._theta, self._state, self._inputs, self._max_input_length, + self._extras + ])), [self._state, self._state, self._extras])[:2] + + +def _GetCellGrad(cell_fn, cell_grad): + """Returns the gradient function for cell_fn. + + Args: + cell_fn: The recurrent neural net's cell function. + cell_grad: If not None, cell_fn's gradient function. + + Returns: + Returns cell_grad if not None. Otherwise, assume cell_fn is a python + function representing the recurrent neural net's cell function, i.e., + cell_fn: (theta, state0, inputs) -> (state1, extra) + returns its default gradient python function, i.e., + cell_grad: (theta, state0, inputs, extras, dstate1) -> ( + dtheta, dstate0, dinputs) + """ + + if cell_grad: + return cell_grad + + def CellGrad(theta, state0, inputs, extras, dstate1): + """Default gradient function for cell_fn.""" + # NOTE: The default grad function recomputes the forward + # function and does not take advantage of 'extras' returned by + # the forward function. 
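    # Recomputing the step trades extra compute for not keeping `extras`
    # alive; a user-supplied cell_grad can use `extras` to avoid the
    # recomputation.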
+ del extras + state1, extras = cell_fn(theta, state0, inputs) + ys = _Flatten([state1]) + xs = _Flatten([theta, state0, inputs]) + grad_ys = _Flatten([dstate1]) + grads = gradients_impl.gradients(ys=ys, xs=xs, grad_ys=grad_ys) + return _ConvertNoneGradientToZeros([theta, state0, inputs], + _Pack(grads, [theta, state0, inputs])) + + return CellGrad + + +def _IsSingleTimeStep(inputs, max_input_length): + """Returns True only if the time dimension of inputs is 1.""" + if not isinstance(max_input_length, ops.Tensor): + return max_input_length == 1 + for x in nest.flatten(inputs): + if x.shape.dims is None or x.shape[0].value != 1: + return False + return True + + +def Recurrent(theta, + state0, + inputs, + cell_fn, + cell_grad=None, + extras=None, + max_input_length=None, + use_tpu=False): + """Compute a recurrent neural net. + + Roughly, Recurrent() computes the following: + state = state0 + for t in inputs' sequence length: + state = cell_fn(theta, state, inputs[t, :]) + accumulate_state[t, :] = state + return accumulate_state, state + + theta, state, inputs are all structures of tensors. + + inputs[t, :] means taking a slice out from every tensor in the inputs. + + accumulate_state[t, :] = state means that we stash every tensor in + 'state' into a slice of the corresponding tensor in + accumulate_state. + + cell_fn is a python callable computing (building up a TensorFlow + graph) the recurrent neural network's one forward step. Two calls of + cell_fn must describe two identical computations. + + By construction, Recurrent()'s backward computation does not access + any intermediate values computed by cell_fn during forward + computation. We may extend Recurrent() to support that by taking a + customized backward function of cell_fn. + + Args: + theta: weights. A structure of tensors. + state0: initial state. A structure of tensors. + inputs: inputs. A structure of tensors. + cell_fn: A python function, which computes: + state1, extras = cell_fn(theta, state0, inputs[t, :]) + cell_grad: A python function which computes: + dtheta, dstate0, dinputs[t, :] = cell_grad( + theta, state0, inputs[t, :], extras, dstate1) + extras: A structure of tensors. The 2nd return value of every + invocation of cell_fn is a structure of tensors with matching keys + and shapes of this `extras`. + max_input_length: maximum length of effective input. This is used to + truncate the computation if the inputs have been allocated to a + larger size. A scalar tensor. + use_tpu: whether or not we are on TPU. + + Returns: + accumulate_state and the final state. + """ + if cell_grad is None and _IsSingleTimeStep(inputs, max_input_length): + # The seqlen length is staticly known as 1. Hence, we just need to + # call cell_fn once without putting it into a loop. + inputs = nest.map_structure(lambda x: array_ops.squeeze(x, axis=0), inputs) + state1, _ = cell_fn(theta, state0, inputs) + acc_state = nest.map_structure(lambda x: array_ops.expand_dims(x, axis=0), + state1) + return acc_state, state1 + + # If cell_grad is not given, derives the gradient function from + # cell_fn. + cell_grad = _GetCellGrad(cell_fn, cell_grad) + + if extras is None: + # Derives 'extras' so that we can allocate extras' accumulator. 
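    # (cell_fn is run on the first time step only to discover the structure
    # and dtypes of `extras`; zeros_like then clears the values before they
    # are used to size the accumulator.)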
+    _, extras = cell_fn(theta, state0, _Index(inputs, 0))
+    extras = nest.map_structure(array_ops.zeros_like, extras)
+  else:
+    _, actual = cell_fn(theta, state0, _Index(inputs, 0))
+    _AssertIsCompatible(extras, actual)
+
+  return _Recurrent(
+      cell_fn=cell_fn,
+      cell_grad=cell_grad,
+      theta=theta,
+      state0=state0,
+      inputs=inputs,
+      max_input_length=max_input_length,
+      extras=extras,
+      use_tpu=use_tpu).Compute()
diff --git a/tensorflow/contrib/recurrent/python/recurrent_api.py b/tensorflow/contrib/recurrent/python/recurrent_api.py
new file mode 100644
index 0000000..ffe1dcf
--- /dev/null
+++ b/tensorflow/contrib/recurrent/python/recurrent_api.py
@@ -0,0 +1,29 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Recurrent computations library."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import
+from tensorflow.contrib.recurrent.python.ops.functional_rnn import bidirectional_functional_rnn
+from tensorflow.contrib.recurrent.python.ops.functional_rnn import functional_rnn
+from tensorflow.contrib.recurrent.python.ops.recurrent import Recurrent
+# pylint: enable=unused-import
+
+del absolute_import
+del division
+del print_function
-- 
2.7.4