Use constants in tf.zeros if the constant won't be too big.
authorSkye Wanderman-Milne <skyewm@google.com>
Thu, 5 Apr 2018 15:47:47 +0000 (08:47 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 5 Apr 2018 15:54:14 +0000 (08:54 -0700)
Using fill saves on GraphDef size, but can slow down models since the
total number of ops is greater (fill + shape + constant op). This
change makes us only use fill for large shapes.

PiperOrigin-RevId: 191747456

tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
tensorflow/python/ops/array_ops.py

index 63fdd91..c7d8586 100644 (file)
@@ -842,12 +842,12 @@ class RNNCellTest(test.TestCase):
       batch_size = 3
       input_size = 4
       expected_state_c = np.array(
-          [[6.450831e-04, 4.697885e-04], [9.862894e-05, 7.212213e-04],
-           [4.401947e-04, 9.143004e-04]],
+          [[0.00072015, 0.00036633], [0.00083481, 0.00047266],
+           [0.00085111, 0.00053054]],
           dtype=np.float32)
       expected_state_h = np.array(
-          [[4.621217e-04, 3.365449e-04], [7.438179e-05, 5.439147e-04],
-           [3.347936e-04, 6.953785e-04]],
+          [[0.0005159, 0.00026243], [0.00062958, 0.00035646],
+           [0.00064732, 0.00040351]],
           dtype=np.float32)
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
index 68d4466..fa26e07 100644 (file)
@@ -1566,6 +1566,16 @@ def matrix_transpose(a, name="matrix_transpose", conjugate=False):
 # pylint: enable=invalid-name
 
 
+def _constant_if_small(value, shape, dtype, name):
+  try:
+    if np.prod(shape) < 1000:
+      return constant(value, shape=shape, dtype=dtype, name=name)
+  except TypeError:
+    # Happens when shape is a Tensor, list with Tensor elements, etc.
+    pass
+  return None
+
+
 @tf_export("zeros")
 def zeros(shape, dtype=dtypes.float32, name=None):
   """Creates a tensor with all elements set to zero.
@@ -1596,8 +1606,15 @@ def zeros(shape, dtype=dtypes.float32, name=None):
       zero = ""
     else:
       zero = 0
+
     if not isinstance(shape, ops.Tensor):
       try:
+        # Create a constant if it won't be very big. Otherwise create a fill op
+        # to prevent serialized GraphDefs from becoming too large.
+        output = _constant_if_small(zero, shape, dtype, name)
+        if output is not None:
+          return output
+
         # Go through tensor shapes to get int64-if-needed semantics
         shape = constant_op._tensor_shape_tensor_conversion_function(
             tensor_shape.TensorShape(shape))
@@ -1729,6 +1746,12 @@ def ones(shape, dtype=dtypes.float32, name=None):
     one = True if dtype == dtypes.bool else 1
     if not isinstance(shape, ops.Tensor):
       try:
+        # Create a constant if it won't be very big. Otherwise create a fill op
+        # to prevent serialized GraphDefs from becoming too large.
+        output = _constant_if_small(one, shape, dtype, name)
+        if output is not None:
+          return output
+
         # Go through tensor shapes to get int64-if-needed semantics
         shape = constant_op._tensor_shape_tensor_conversion_function(
             tensor_shape.TensorShape(shape))