From a6eb244b2b8ee4d9592a705c4bc0771e4d708565 Mon Sep 17 00:00:00 2001
From: Akshay Modi
Date: Fri, 25 May 2018 17:37:01 -0700
Subject: [PATCH] Minor eager performance improvements

- Remove the linear regression example's dependence on the global step. This
  speeds things up a lot for the benchmark (since it removes a bunch of
  unnecessary code), but is obviously not a fair comparison. I think it's
  worth doing, since I don't see any reason to have a global step in eager.
- nn_ops dropout had an unnecessary convert_to_tensor and convert back to
  numpy (with a GPU this would copy out, then copy back).
- The cudnn_recurrent reshape would always fall back to the slow path, so I
  converted it to stay on the fast path - this will be low impact.
- tensor_shape.scalar() should not generate a new object every time.
- Remove unnecessary list creation and searching in some dtypes functions.

PiperOrigin-RevId: 198127757
---
 .../linear_regression/linear_regression.py         |  6 ++---
 tensorflow/python/framework/dtypes.py              | 28 +++++++++++-----------
 tensorflow/python/framework/tensor_shape.py        |  5 +++-
 tensorflow/python/keras/layers/cudnn_recurrent.py  |  6 +++--
 tensorflow/python/ops/nn_ops.py                    | 19 +++++++++++----
 5 files changed, 39 insertions(+), 25 deletions(-)

diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
index 2259c20..099b712 100644
--- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
+++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
@@ -75,7 +75,6 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None):
   mse = lambda xs, ys: mean_square_loss(model, xs, ys)
   loss_and_grads = tfe.implicit_value_and_gradients(mse)
 
-  tf.train.get_or_create_global_step()
   if logdir:
     # Support for TensorBoard summaries. Once training has started, use:
     #   tensorboard --logdir=<logdir>
@@ -87,12 +86,13 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None):
     if verbose:
       print("Iteration %d: loss = %s" % (i, loss.numpy()))
 
-    optimizer.apply_gradients(grads, global_step=tf.train.get_global_step())
+    optimizer.apply_gradients(grads)
 
     if logdir:
       with summary_writer.as_default():
         with tf.contrib.summary.always_record_summaries():
-          tf.contrib.summary.scalar("loss", loss)
+          tf.contrib.summary.scalar("loss", loss, step=i)
+          tf.contrib.summary.scalar("step", i, step=i)
 
 
 def synthetic_dataset(w, b, noise_level, batch_size, num_batches):
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index 7f9ef53..c3f70df 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -120,11 +120,7 @@ class DType(object):
 
   @property
   def is_numpy_compatible(self):
-    numpy_incompatible = [
-        types_pb2.DT_VARIANT, types_pb2.DT_VARIANT_REF, types_pb2.DT_RESOURCE,
-        types_pb2.DT_RESOURCE_REF
-    ]
-    return self._type_enum not in numpy_incompatible
+    return self._type_enum not in _NUMPY_INCOMPATIBLE
 
   @property
   def as_numpy_dtype(self):
@@ -162,7 +158,7 @@ class DType(object):
   @property
   def is_quantized(self):
     """Returns whether this is a quantized data type."""
-    return self.base_dtype in [qint8, quint8, qint16, quint16, qint32]
+    return self.base_dtype in _QUANTIZED_DTYPES_NO_REF
 
   @property
   def is_unsigned(self):
@@ -401,6 +397,11 @@ quint16_ref = DType(types_pb2.DT_QUINT16_REF)
 qint32_ref = DType(types_pb2.DT_QINT32_REF)
 bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF)
 
+_NUMPY_INCOMPATIBLE = frozenset([
+    types_pb2.DT_VARIANT, types_pb2.DT_VARIANT_REF, types_pb2.DT_RESOURCE,
+    types_pb2.DT_RESOURCE_REF
+])
+
 # Maintain an intern table so that we don't have to create a large
 # number of small objects.
 _INTERN_TABLE = {
@@ -645,10 +646,10 @@ _TF_TO_NP = {
         _np_bfloat16,
 }
 
-QUANTIZED_DTYPES = frozenset([
-    qint8, quint8, qint16, quint16, qint32, qint8_ref, quint8_ref, qint16_ref,
-    quint16_ref, qint32_ref
-])
+_QUANTIZED_DTYPES_NO_REF = frozenset([qint8, quint8, qint16, quint16, qint32])
+_QUANTIZED_DTYPES_REF = frozenset(
+    [qint8_ref, quint8_ref, qint16_ref, quint16_ref, qint32_ref])
+QUANTIZED_DTYPES = _QUANTIZED_DTYPES_REF.union(_QUANTIZED_DTYPES_NO_REF)
 tf_export("QUANTIZED_DTYPES").export_constant(__name__, "QUANTIZED_DTYPES")
 
 _PYTHON_TO_TF = {
@@ -662,10 +663,9 @@ def as_dtype(type_value):
   """Converts the given `type_value` to a `DType`.
 
   Args:
-    type_value: A value that can be converted to a `tf.DType`
-      object. This may currently be a `tf.DType` object, a
-      [`DataType`
-      enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto),
+    type_value: A value that can be converted to a `tf.DType` object. This may
+      currently be a `tf.DType` object, a [`DataType`
+      enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto),
       a string type name, or a `numpy.dtype`.
 
   Returns:
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
index 0dd2946..c9be3d5 100644
--- a/tensorflow/python/framework/tensor_shape.py
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -961,9 +961,12 @@ def unknown_shape(ndims=None):
     return TensorShape([Dimension(None)] * ndims)
 
 
+_SCALAR_SHAPE = TensorShape([])
+
+
 def scalar():
   """Returns a shape representing a scalar."""
-  return TensorShape([])
+  return _SCALAR_SHAPE
 
 
 def vector(length):
diff --git a/tensorflow/python/keras/layers/cudnn_recurrent.py b/tensorflow/python/keras/layers/cudnn_recurrent.py
index 5c4a2db..ad65942 100644
--- a/tensorflow/python/keras/layers/cudnn_recurrent.py
+++ b/tensorflow/python/keras/layers/cudnn_recurrent.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import collections
 
+from tensorflow.python.framework import constant_op
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import constraints
 from tensorflow.python.keras import initializers
@@ -71,10 +72,11 @@ class _CuDNNRNN(RNN):
     self.constants_spec = None
     self._states = None
     self._num_constants = None
+    self._vector_shape = constant_op.constant([-1])
 
   def _canonical_to_params(self, weights, biases):
-    weights = [array_ops.reshape(x, (-1,)) for x in weights]
-    biases = [array_ops.reshape(x, (-1,)) for x in biases]
+    weights = [array_ops.reshape(x, self._vector_shape) for x in weights]
+    biases = [array_ops.reshape(x, self._vector_shape) for x in biases]
     return array_ops.concat(weights + biases, axis=0)
 
   def call(self, inputs, mask=None, training=None, initial_state=None):
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 09a4425..a0b55eb 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -2311,13 +2311,22 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):  # pylint: di
     if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1:
       raise ValueError("keep_prob must be a scalar tensor or a float in the "
                        "range (0, 1], got %g" % keep_prob)
-    keep_prob = ops.convert_to_tensor(
-        keep_prob, dtype=x.dtype, name="keep_prob")
-    keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())
 
-    # Do nothing if we know keep_prob == 1
-    if tensor_util.constant_value(keep_prob) == 1:
+    # Early return if nothing needs to be dropped.
+    if isinstance(keep_prob, float) and keep_prob == 1:
       return x
+    if context.executing_eagerly():
+      if isinstance(keep_prob, ops.EagerTensor):
+        if keep_prob.numpy() == 1:
+          return x
+    else:
+      keep_prob = ops.convert_to_tensor(
+          keep_prob, dtype=x.dtype, name="keep_prob")
+      keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())
+
+      # Do nothing if we know keep_prob == 1
+      if tensor_util.constant_value(keep_prob) == 1:
+        return x
 
     noise_shape = _get_noise_shape(x, noise_shape)
 
-- 
2.7.4
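
Note: the dtypes.py and tensor_shape.py hunks above both apply the same micro-optimization of hoisting a per-call allocation into a module-level constant. Below is a minimal standalone sketch of that pattern, not TensorFlow code; the names is_quantized_slow, is_quantized_fast, and _QUANTIZED are illustrative only.

import timeit

# Illustrative stand-ins for the dtype objects searched in dtypes.py; using
# names (not literals) so the list below really is rebuilt on every call.
QINT8, QUINT8, QINT16, QUINT16, QINT32 = (
    "qint8", "quint8", "qint16", "quint16", "qint32")

# Hoisted module-level constant, analogous to _QUANTIZED_DTYPES_NO_REF.
_QUANTIZED = frozenset({QINT8, QUINT8, QINT16, QUINT16, QINT32})


def is_quantized_slow(dtype):
  # Old pattern: a fresh list is allocated and linearly searched on each call.
  return dtype in [QINT8, QUINT8, QINT16, QUINT16, QINT32]


def is_quantized_fast(dtype):
  # New pattern: a single hash lookup against a cached frozenset, with no
  # per-call allocation.
  return dtype in _QUANTIZED


if __name__ == "__main__":
  for fn_name in ("is_quantized_slow", "is_quantized_fast"):
    elapsed = timeit.timeit("%s('qint32')" % fn_name,
                            globals=globals(), number=1000000)
    print(fn_name, elapsed)

The same reasoning applies to tensor_shape.scalar(): returning the cached _SCALAR_SHAPE avoids constructing a new TensorShape object on every call.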