Minor eager performance improvements
authorAkshay Modi <nareshmodi@google.com>
Sat, 26 May 2018 00:37:01 +0000 (17:37 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 26 May 2018 00:39:34 +0000 (17:39 -0700)
- remove linear regression dependence on global step.
  This speeds things up a lot for the benchmark (since it removes a bunch of
  unnecessary code), but is obviously not a fair comparison.
  I think its worth doing, since I don't see any reason to have a global step
  in eager.

- nn_ops dropout had an unnecessary convert_to_tensor, convert back to numpy
  (with a GPU this would copy out, copy back).
- cudnn_recurrent reshape would always fallback to the slow path - so I just
  converted it to be in the fastpath - this will be low impact.

- tensor_shape should not generate a new object every time
- remove unnecessary list creation and searching in some dtypes functions

PiperOrigin-RevId: 198127757

tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
tensorflow/python/framework/dtypes.py
tensorflow/python/framework/tensor_shape.py
tensorflow/python/keras/layers/cudnn_recurrent.py
tensorflow/python/ops/nn_ops.py

index 2259c20..099b712 100644 (file)
@@ -75,7 +75,6 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None):
   mse = lambda xs, ys: mean_square_loss(model, xs, ys)
   loss_and_grads = tfe.implicit_value_and_gradients(mse)
 
-  tf.train.get_or_create_global_step()
   if logdir:
     # Support for TensorBoard summaries. Once training has started, use:
     #   tensorboard --logdir=<logdir>
@@ -87,12 +86,13 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None):
     if verbose:
       print("Iteration %d: loss = %s" % (i, loss.numpy()))
 
-    optimizer.apply_gradients(grads, global_step=tf.train.get_global_step())
+    optimizer.apply_gradients(grads)
 
     if logdir:
       with summary_writer.as_default():
         with tf.contrib.summary.always_record_summaries():
-          tf.contrib.summary.scalar("loss", loss)
+          tf.contrib.summary.scalar("loss", loss, step=i)
+          tf.contrib.summary.scalar("step", i, step=i)
 
 
 def synthetic_dataset(w, b, noise_level, batch_size, num_batches):
index 7f9ef53..c3f70df 100644 (file)
@@ -120,11 +120,7 @@ class DType(object):
 
   @property
   def is_numpy_compatible(self):
-    numpy_incompatible = [
-        types_pb2.DT_VARIANT, types_pb2.DT_VARIANT_REF, types_pb2.DT_RESOURCE,
-        types_pb2.DT_RESOURCE_REF
-    ]
-    return self._type_enum not in numpy_incompatible
+    return self._type_enum not in _NUMPY_INCOMPATIBLE
 
   @property
   def as_numpy_dtype(self):
@@ -162,7 +158,7 @@ class DType(object):
   @property
   def is_quantized(self):
     """Returns whether this is a quantized data type."""
-    return self.base_dtype in [qint8, quint8, qint16, quint16, qint32]
+    return self.base_dtype in _QUANTIZED_DTYPES_NO_REF
 
   @property
   def is_unsigned(self):
@@ -401,6 +397,11 @@ quint16_ref = DType(types_pb2.DT_QUINT16_REF)
 qint32_ref = DType(types_pb2.DT_QINT32_REF)
 bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF)
 
+_NUMPY_INCOMPATIBLE = frozenset([
+    types_pb2.DT_VARIANT, types_pb2.DT_VARIANT_REF, types_pb2.DT_RESOURCE,
+    types_pb2.DT_RESOURCE_REF
+])
+
 # Maintain an intern table so that we don't have to create a large
 # number of small objects.
 _INTERN_TABLE = {
@@ -645,10 +646,10 @@ _TF_TO_NP = {
         _np_bfloat16,
 }
 
-QUANTIZED_DTYPES = frozenset([
-    qint8, quint8, qint16, quint16, qint32, qint8_ref, quint8_ref, qint16_ref,
-    quint16_ref, qint32_ref
-])
+_QUANTIZED_DTYPES_NO_REF = frozenset([qint8, quint8, qint16, quint16, qint32])
+_QUANTIZED_DTYPES_REF = frozenset(
+    [qint8_ref, quint8_ref, qint16_ref, quint16_ref, qint32_ref])
+QUANTIZED_DTYPES = _QUANTIZED_DTYPES_REF.union(_QUANTIZED_DTYPES_NO_REF)
 tf_export("QUANTIZED_DTYPES").export_constant(__name__, "QUANTIZED_DTYPES")
 
 _PYTHON_TO_TF = {
@@ -662,10 +663,9 @@ def as_dtype(type_value):
   """Converts the given `type_value` to a `DType`.
 
   Args:
-    type_value: A value that can be converted to a `tf.DType`
-      object. This may currently be a `tf.DType` object, a
-      [`DataType`
-        enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto),
+    type_value: A value that can be converted to a `tf.DType` object. This may
+      currently be a `tf.DType` object, a [`DataType`
+      enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto),
       a string type name, or a `numpy.dtype`.
 
   Returns:
index 0dd2946..c9be3d5 100644 (file)
@@ -961,9 +961,12 @@ def unknown_shape(ndims=None):
     return TensorShape([Dimension(None)] * ndims)
 
 
+_SCALAR_SHAPE = TensorShape([])
+
+
 def scalar():
   """Returns a shape representing a scalar."""
-  return TensorShape([])
+  return _SCALAR_SHAPE
 
 
 def vector(length):
index 5c4a2db..ad65942 100644 (file)
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import collections
 
+from tensorflow.python.framework import constant_op
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import constraints
 from tensorflow.python.keras import initializers
@@ -71,10 +72,11 @@ class _CuDNNRNN(RNN):
     self.constants_spec = None
     self._states = None
     self._num_constants = None
+    self._vector_shape = constant_op.constant([-1])
 
   def _canonical_to_params(self, weights, biases):
-    weights = [array_ops.reshape(x, (-1,)) for x in weights]
-    biases = [array_ops.reshape(x, (-1,)) for x in biases]
+    weights = [array_ops.reshape(x, self._vector_shape) for x in weights]
+    biases = [array_ops.reshape(x, self._vector_shape) for x in biases]
     return array_ops.concat(weights + biases, axis=0)
 
   def call(self, inputs, mask=None, training=None, initial_state=None):
index 09a4425..a0b55eb 100644 (file)
@@ -2311,13 +2311,22 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):  # pylint: di
     if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1:
       raise ValueError("keep_prob must be a scalar tensor or a float in the "
                        "range (0, 1], got %g" % keep_prob)
-    keep_prob = ops.convert_to_tensor(
-        keep_prob, dtype=x.dtype, name="keep_prob")
-    keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())
 
-    # Do nothing if we know keep_prob == 1
-    if tensor_util.constant_value(keep_prob) == 1:
+    # Early return if nothing needs to be dropped.
+    if isinstance(keep_prob, float) and keep_prob == 1:
       return x
+    if context.executing_eagerly():
+      if isinstance(keep_prob, ops.EagerTensor):
+        if keep_prob.numpy() == 1:
+          return x
+    else:
+      keep_prob = ops.convert_to_tensor(
+          keep_prob, dtype=x.dtype, name="keep_prob")
+      keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())
+
+      # Do nothing if we know keep_prob == 1
+      if tensor_util.constant_value(keep_prob) == 1:
+        return x
 
     noise_shape = _get_noise_shape(x, noise_shape)