Improvement to eager linear regression benchmark
author     Akshay Modi <nareshmodi@google.com>          Tue, 6 Mar 2018 21:06:53 +0000 (13:06 -0800)
committer  TensorFlower Gardener <gardener@tensorflow.org>  Tue, 6 Mar 2018 21:11:11 +0000 (13:11 -0800)
Before:
entry {
  name: "EagerLinearRegressionBenchmark.eager_train_cpu"
  iters: 2000
  wall_time: 2.45178794861
  extras {
    key: "examples_per_sec"
    value {
      double_value: 52206.7987456
    }
  }
}

After:
entry {
  name: "EagerLinearRegressionBenchmark.eager_train_cpu"
  iters: 2000
  wall_time: 1.9873790741
  extras {
    key: "examples_per_sec"
    value {
      double_value: 64406.4344182
    }
  }
}
PiperOrigin-RevId: 188068838
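
As a quick consistency check on the two entries above (plain Python, using the numbers verbatim from the benchmark output):

before_wall, before_eps = 2.45178794861, 52206.7987456
after_wall, after_eps = 1.9873790741, 64406.4344182

print(after_eps / before_eps)    # ~1.23, i.e. roughly 23% more examples per second
print(before_wall / after_wall)  # ~1.23, consistent with the drop in wall_time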

13 files changed:
tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
tensorflow/python/eager/backprop.py
tensorflow/python/eager/context.py
tensorflow/python/eager/pywrap_tfe_src.cc
tensorflow/python/framework/tensor_shape.py
tensorflow/python/framework/test_util.py
tensorflow/python/layers/base.py
tensorflow/python/layers/core.py
tensorflow/python/ops/math_grad.py
tensorflow/python/ops/math_ops.py
tensorflow/python/ops/nn_ops.py
tensorflow/python/ops/resource_variable_ops.py
tensorflow/python/training/gradient_descent.py

diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
index 157a636..6ab847c 100644
@@ -54,7 +54,7 @@ class LinearModel(tf.keras.Model):
 
 
 def mean_square_loss(model, xs, ys):
-  return tf.reduce_mean(tf.square(model(xs) - ys))
+  return tf.reduce_mean(tf.square(tf.subtract(model(xs), ys)))
 
 
 def fit(model, dataset, optimizer, verbose=False, logdir=None):
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 14bcc60..88de1a9 100644
@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import functools
 import operator
 import threading
@@ -43,26 +42,6 @@ from tensorflow.python.util import nest
 from tensorflow.python.util import tf_inspect
 
 
-class _TensorCache(object):
-  """Simple cache which evicts items based on length in a FIFO manner."""
-
-  def __init__(self, max_items=256):
-    self._data = collections.OrderedDict()
-    self._max_items = max_items if max_items else 256
-
-  def put(self, key, value):
-    self._data[key] = value
-
-    if len(self._data) > self._max_items:
-      self._data.popitem(last=False)
-
-  def get(self, key):
-    return self._data.get(key, None)
-
-  def flush(self):
-    self._data = {}
-
-
 _op_attr_type_cache = {}
 
 
@@ -622,7 +601,7 @@ def _num_elements(grad):
   raise ValueError("`grad` not a Tensor or IndexedSlices.")
 
 
-_zeros_cache = _TensorCache()
+_zeros_cache = context._TensorCache()  # pylint: disable=protected-access
 
 
 def _fast_fill(value, shape, dtype):
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 0e9c21b..fb27ab6 100644
@@ -54,6 +54,26 @@ DEVICE_PLACEMENT_SILENT_FOR_INT32 = (
     pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
 
 
+class _TensorCache(object):
+  """Simple cache which evicts items based on length in a FIFO manner."""
+
+  def __init__(self, max_items=256):
+    self._data = collections.OrderedDict()
+    self._max_items = max_items if max_items else 256
+
+  def put(self, key, value):
+    self._data[key] = value
+
+    if len(self._data) > self._max_items:
+      self._data.popitem(last=False)
+
+  def get(self, key):
+    return self._data.get(key, None)
+
+  def flush(self):
+    self._data = {}
+
+
 # TODO(agarwal): better name ?
 class _EagerContext(threading.local):
   """Thread local eager context."""
@@ -67,6 +87,7 @@ class _EagerContext(threading.local):
     self.recording_summaries = False
     self.summary_writer_resource = None
     self.scalar_cache = {}
+    self.ones_rank_cache = _TensorCache()
 
 
 ContextStackEntry = collections.namedtuple(
@@ -251,6 +272,10 @@ class Context(object):
     """Per-device cache for scalars."""
     return self._eager_context.scalar_cache
 
+  def ones_rank_cache(self):
+    """Per-device cache for scalars."""
+    return self._eager_context.ones_rank_cache
+
   @property
   def scope_name(self):
     """Returns scope name for the current thread."""
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 27c9d05..9146e2b 100644
@@ -93,6 +93,34 @@ Py_ssize_t TensorShapeNumDims(PyObject* value) {
   return size;
 }
 
+bool IsInteger(PyObject* py_value) {
+#if PY_MAJOR_VERSION >= 3
+  return PyLong_Check(py_value);
+#else
+  return PyInt_Check(py_value);
+#endif
+}
+
+bool ParseDimensionValue(const string& key, PyObject* py_value,
+                         TF_Status* status, int64_t* value) {
+  if (IsInteger(py_value)) {
+    return ParseInt64Value(key, py_value, status, value);
+  }
+
+  tensorflow::Safe_PyObjectPtr dimension_value(
+      PyObject_GetAttrString(py_value, "_value"));
+  if (dimension_value == nullptr) {
+    TF_SetStatus(
+        status, TF_INVALID_ARGUMENT,
+        tensorflow::strings::StrCat("Expecting a Dimension for attr ", key,
+                                    ", got ", py_value->ob_type->tp_name)
+            .c_str());
+    return false;
+  }
+
+  return ParseInt64Value(key, dimension_value.get(), status, value);
+}
+
 bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
                       const char** value) {
   if (PyBytes_Check(py_value)) {
@@ -119,14 +147,6 @@ bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
   return true;
 }
 
-bool IsInteger(PyObject* py_value) {
-#if PY_MAJOR_VERSION >= 3
-  return PyLong_Check(py_value);
-#else
-  return PyInt_Check(py_value);
-#endif
-}
-
 // The passed in py_value is expected to be an object of the python type
 // dtypes.DType or an int.
 bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
@@ -135,7 +155,8 @@ bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
     return ParseIntValue(key, py_value, status, value);
   }
 
-  PyObject* py_type_enum = PyObject_GetAttrString(py_value, "_type_enum");
+  tensorflow::Safe_PyObjectPtr py_type_enum(
+      PyObject_GetAttrString(py_value, "_type_enum"));
   if (py_type_enum == nullptr) {
     TF_SetStatus(
         status, TF_INVALID_ARGUMENT,
@@ -145,13 +166,7 @@ bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
     return false;
   }
 
-  if (!ParseIntValue(key, py_type_enum, status, value)) {
-    Py_DECREF(py_type_enum);
-    return false;
-  }
-
-  Py_DECREF(py_type_enum);
-  return true;
+  return ParseIntValue(key, py_type_enum.get(), status, value);
 }
 
 bool SetOpAttrList(
@@ -240,7 +255,8 @@ bool SetOpAttrList(
           auto inner_py_value = PySequence_ITEM(py_value, j);
           if (inner_py_value == Py_None) {
             *offset = -1;
-          } else if (!ParseInt64Value(key, inner_py_value, status, offset)) {
+          } else if (!ParseDimensionValue(key, inner_py_value, status,
+                                          offset)) {
             return false;
           }
           ++offset;
@@ -424,7 +440,8 @@ bool SetOpAttrScalar(
         auto inner_py_value = PySequence_ITEM(py_value, i);
         if (inner_py_value == Py_None) {
           dims[i] = -1;
-        } else if (!ParseInt64Value(key, inner_py_value, status, &dims[i])) {
+        } else if (!ParseDimensionValue(key, inner_py_value, status,
+                                        &dims[i])) {
           return false;
         }
       }
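
Note: in Python terms, the new ParseDimensionValue above roughly amounts to the following sketch (parse_dimension_value is a hypothetical helper shown only to illustrate the accepted inputs; the real logic is the C++ above):

def parse_dimension_value(value):
  # Plain ints are used directly; Dimension objects are unwrapped through their
  # private _value attribute, mirroring PyObject_GetAttrString(py_value, "_value").
  if isinstance(value, int):
    return value
  if not hasattr(value, "_value"):
    raise ValueError("Expecting a Dimension, got %s" % type(value).__name__)
  return value._value
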
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
index 222071c..6f2ab84 100644
@@ -456,6 +456,7 @@ class TensorShape(object):
       else:
         # Got a list of dimensions
         self._dims = [as_dimension(d) for d in dims_iter]
+    self._ndims = None
 
   def __repr__(self):
     return "TensorShape(%r)" % self._dims
@@ -473,19 +474,26 @@ class TensorShape(object):
     """Returns a list of Dimensions, or None if the shape is unspecified."""
     return self._dims
 
+  @dims.setter
+  def dims(self, dims):
+    self._dims = dims
+    self._ndims = None
+
   @property
   def ndims(self):
     """Returns the rank of this shape, or None if it is unspecified."""
     if self._dims is None:
       return None
     else:
-      return len(self._dims)
+      if self._ndims is None:
+        self._ndims = len(self._dims)
+      return self._ndims
 
   def __len__(self):
     """Returns the rank of this shape, or raises ValueError if unspecified."""
     if self._dims is None:
       raise ValueError("Cannot take the length of Shape with unknown rank.")
-    return len(self._dims)
+    return self.ndims
 
   def __bool__(self):
     """Returns True if this shape contains non-zero information."""
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 78252e4..1c8398e 100644
@@ -472,6 +472,7 @@ def assert_no_new_tensors(f):
     # Make an effort to clear caches, which would otherwise look like leaked
     # Tensors.
     backprop._zeros_cache.flush()
+    context.get_default_context().ones_rank_cache().flush()
     context.get_default_context().scalar_cache().clear()
     gc.collect()
     tensors_after = [
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index c6d16a3..15f7278 100644
@@ -129,10 +129,10 @@ class Layer(checkpointable.CheckpointableBase):
     self._reuse = kwargs.get('_reuse')
     self._graph = None  # Will be set at build time.
     self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
-    call_fn_args = estimator_util.fn_args(self.call)
-    self._compute_previous_mask = ('mask' in call_fn_args or
+    self._call_fn_args = estimator_util.fn_args(self.call)
+    self._compute_previous_mask = ('mask' in self._call_fn_args or
                                    hasattr(self, 'compute_mask'))
-    self._call_has_scope_arg = 'scope' in call_fn_args
+    self._call_has_scope_arg = 'scope' in self._call_fn_args
 
     # These lists will be filled via successive calls
     # to self._add_inbound_node().
@@ -642,8 +642,9 @@ class Layer(checkpointable.CheckpointableBase):
     if (not hasattr(self, '_compute_previous_mask') or
         self._compute_previous_mask):
       previous_mask = _collect_previous_mask(inputs)
-      if ('mask' in estimator_util.fn_args(self.call) and
-          'mask' not in kwargs and
+      if not hasattr(self, '_call_fn_args'):
+        self._call_fn_args = estimator_util.fn_args(self.call)
+      if ('mask' in self._call_fn_args and 'mask' not in kwargs and
           not _is_all_none(previous_mask)):
         # The previous layer generated a mask, and mask was not explicitly passed
         # to __call__, hence we set previous_mask as the default value.
@@ -699,7 +700,9 @@ class Layer(checkpointable.CheckpointableBase):
           # TODO(agarwal): Fix the sub-classes and avoid this complexity.
           call_has_scope_arg = self._call_has_scope_arg
         except AttributeError:
-          call_has_scope_arg = 'scope' in estimator_util.fn_args(self.call)
+          self._call_fn_args = estimator_util.fn_args(self.call)
+          self._call_has_scope_arg = 'scope' in self._call_fn_args
+          call_has_scope_arg = self._call_has_scope_arg
         if call_has_scope_arg:
           kwargs['scope'] = scope
         # Check input assumptions set after layer building, e.g. input shape.
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index 6970bf9..bdbbc59 100644
@@ -35,6 +35,7 @@ from tensorflow.python.layers import utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import standard_ops
@@ -159,7 +160,7 @@ class Dense(base.Layer):
         output_shape = shape[:-1] + [self.units]
         outputs.set_shape(output_shape)
     else:
-      outputs = standard_ops.matmul(inputs, self.kernel)
+      outputs = gen_math_ops.mat_mul(inputs, self.kernel)
     if self.use_bias:
       outputs = nn.bias_add(outputs, self.bias)
     if self.activation is not None:
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 51e19b4..55dd0c0 100644
@@ -52,10 +52,18 @@ def _SumGrad(op, grad):
     if axes is not None:
       rank = len(input_0_shape)
       if np.array_equal(axes, np.arange(rank)):  # Reduce all dims.
-        grad = array_ops.reshape(grad, [1] * rank)
+        if context.in_graph_mode():
+          new_shape = [1] * rank
+        else:
+          ctx = context.context()
+          new_shape = ctx.ones_rank_cache().get(rank)
+          if new_shape is None:
+            new_shape = constant_op.constant([1] * rank, dtype=dtypes.int32)
+            ctx.ones_rank_cache().put(rank, new_shape)
+        grad = array_ops.reshape(grad, new_shape)
         # If shape is not fully defined (but rank is), we use Shape.
         if None not in input_0_shape:
-          input_shape = input_0_shape
+          input_shape = constant_op.constant(input_0_shape, dtype=dtypes.int32)
         else:
           input_shape = array_ops.shape(op.inputs[0])
         return [array_ops.tile(grad, input_shape), None]
@@ -338,7 +346,8 @@ def _SquareGrad(op, grad):
   # Added control dependencies to prevent 2*x from being computed too early.
   with ops.control_dependencies([grad]):
     x = math_ops.conj(x)
-    return math_ops.multiply(grad, math_ops.multiply(x, 2.0))
+    y = constant_op.constant(2.0, dtype=x.dtype)
+    return math_ops.multiply(grad, math_ops.multiply(x, y))
 
 
 @ops.RegisterGradient("Sqrt")
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 14d6862..c019a58 100644
@@ -176,6 +176,11 @@ arg_max = deprecation.deprecated(None, "Use `argmax` instead")(arg_max)  # pylin
 arg_min = deprecation.deprecated(None, "Use `argmin` instead")(arg_min)  # pylint: disable=used-before-assignment
 
 
+# This is set by resource_variable_ops.py. It is included in this way since
+# there is a circular dependency between math_ops and resource_variable_ops
+_resource_variable_type = None
+
+
 def _set_doc(doc):
 
   def _decorator(func):
@@ -2002,8 +2007,15 @@ def matmul(a,
     if transpose_b and adjoint_b:
       raise ValueError("Only one of transpose_b and adjoint_b can be True.")
 
-    a = ops.convert_to_tensor(a, name="a")
-    b = ops.convert_to_tensor(b, name="b")
+    if context.in_graph_mode():
+      a = ops.convert_to_tensor(a, name="a")
+      b = ops.convert_to_tensor(b, name="b")
+    else:
+      if not isinstance(a, (ops.EagerTensor, _resource_variable_type)):
+        a = ops.convert_to_tensor(a, name="a")
+      if not isinstance(b, (ops.EagerTensor, _resource_variable_type)):
+        b = ops.convert_to_tensor(b, name="b")
+
     # TODO(apassos) remove _shape_tuple here when it is not needed.
     a_shape = a._shape_tuple()  # pylint: disable=protected-access
     b_shape = b._shape_tuple()  # pylint: disable=protected-access
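
Note: the eager branch above skips convert_to_tensor for inputs that are already EagerTensors or ResourceVariables (the _resource_variable_type hook is filled in from resource_variable_ops.py, see that file's hunk below, to avoid a circular import). A hedged sketch of the same fast-path idea, using only public APIs rather than the internal checks:

import tensorflow as tf

def fast_matmul(a, b):
  # Only convert inputs that are not already Tensors.
  if not isinstance(a, tf.Tensor):
    a = tf.convert_to_tensor(a, name="a")
  if not isinstance(b, tf.Tensor):
    b = tf.convert_to_tensor(b, name="b")
  return tf.matmul(a, b)

print(fast_matmul([[1.0, 2.0]], [[3.0], [4.0]]))  # [[11.]]
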
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 852ab36..66a05f2 100644
@@ -1504,8 +1504,9 @@ def bias_add(value, bias, data_format=None, name=None):
     A `Tensor` with the same type as `value`.
   """
   with ops.name_scope(name, "BiasAdd", [value, bias]) as name:
-    value = ops.convert_to_tensor(value, name="input")
-    bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias")
+    if context.in_graph_mode():
+      value = ops.convert_to_tensor(value, name="input")
+      bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias")
     return gen_nn_ops.bias_add(value, bias, data_format=data_format, name=name)
 
 
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 6c5d692..5b8af80 100644
@@ -31,6 +31,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import gen_state_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 # go/tf-wildcard-import
 # pylint: disable=wildcard-import
@@ -483,6 +484,7 @@ class ResourceVariable(variables.Variable):
       # all in graph mode.
       self._handle_deleter = EagerResourceDeleter(
           handle=self._handle, handle_device=self._handle.device)
+    self._cached_shape_as_list = None
 
   def _init_from_proto(self, variable_def, import_scope=None):
     """Initializes from `VariableDef` proto."""
@@ -529,6 +531,7 @@ class ResourceVariable(variables.Variable):
     self._graph_element = g.get_tensor_by_name(
         self._handle.op.name + "/Read/ReadVariableOp:0")
     self._constraint = None
+    self._cached_shape_as_list = None
 
   def __nonzero__(self):
     return self.__bool__()
@@ -561,6 +564,20 @@ class ResourceVariable(variables.Variable):
     """The shape of this variable."""
     return self._shape
 
+  def _shape_as_list(self):
+    if self._cached_shape_as_list:
+      return self._cached_shape_as_list
+    if self.shape.ndims is None:
+      return None
+    self._cached_shape_as_list = [dim.value for dim in self.shape.dims]
+    return self._cached_shape_as_list
+
+  def _shape_tuple(self):
+    shape = self._shape_as_list()
+    if shape is None:
+      return None
+    return tuple(shape)
+
   @property
   def create(self):
     """The op responsible for initializing this variable."""
@@ -934,6 +951,7 @@ class ResourceVariable(variables.Variable):
 
 
 pywrap_tensorflow.TFE_Py_RegisterResourceVariableType(ResourceVariable)
+math_ops._resource_variable_type = ResourceVariable  # pylint: disable=protected-access
 
 
 def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
@@ -985,6 +1003,7 @@ class _UnreadVariable(ResourceVariable):
 
   def set_shape(self, shape):
     self._shape = shape
+    self._cached_shape_as_list = None
 
   @property
   def op(self):
diff --git a/tensorflow/python/training/gradient_descent.py b/tensorflow/python/training/gradient_descent.py
index 380e14e..538164a 100644
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -43,6 +44,7 @@ class GradientDescentOptimizer(optimizer.Optimizer):
     """
     super(GradientDescentOptimizer, self).__init__(use_locking, name)
     self._learning_rate = learning_rate
+    self._learning_rate_tensor = None
 
   def _apply_dense(self, grad, var):
     return training_ops.apply_gradient_descent(
@@ -69,5 +71,6 @@ class GradientDescentOptimizer(optimizer.Optimizer):
     return var.scatter_sub(delta, use_locking=self._use_locking)
 
   def _prepare(self):
-    self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
-                                                       name="learning_rate")
+    if context.in_graph_mode() or self._learning_rate_tensor is None:
+      self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
+                                                         name="learning_rate")