Fix oom_test so that it doesn't try to allocate a giant host buffer when
authorJustin Lebar <jlebar@google.com>
Thu, 3 May 2018 22:58:43 +0000 (15:58 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 4 May 2018 17:33:25 +0000 (10:33 -0700)
run without --config=cuda.  Sadly the best way I could come up with is
pretty hacky.

PiperOrigin-RevId: 195325149

tensorflow/compiler/tests/oom_test.py

index 1434e96..d68d320 100644 (file)
@@ -22,6 +22,8 @@ from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
 from tensorflow.python.platform import googletest
 
 
@@ -42,20 +44,33 @@ class OutOfMemoryTest(xla_test.XLATestCase):
     """
 
     def test_loop():
-      size = 2e8
+      size = int(2e8)
       while True:
         with self.test_session():
-          # Force the compiled code to not be constant by feeding in an addend.
-          p = array_ops.placeholder(dtypes.float32, shape=[])
+          # Force the compiled code to not be constant by feeding in a
+          # parameter.
+          p = array_ops.placeholder(dtypes.float32, shape=[2, 1, 1])
           with self.test_scope():
-            # Create a large R1 tensor.
-            c = array_ops.zeros([size, 1]) + p
+            # Create a computation that produces a large R1 tensor as an
+            # intermediate result.  Reduce it down so that if this file was
+            # compiled without --config=cuda, we don't force a D2H copy of a
+            # large tensor and potentially OOM the host.
+            #
+            # This is a bit tricky because XLA:GPU doesn't currently support RNG
+            # ops.  Here we rely on the fact that XLA doesn't do algebraic
+            # simplifications on conv(<ones>, <filter>).
+            c = math_ops.reduce_sum(
+                nn_ops.convolution(
+                    array_ops.ones([1, size, 1]),
+                    p,
+                    padding='SAME',
+                    data_format='NWC'))
 
-            c.eval(feed_dict={p: 1.0})
+            c.eval(feed_dict={p: [[[1.0]], [[2.0]]]})
             size *= 2
 
     self.assertRaises(errors.ResourceExhaustedError, test_loop)
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
   googletest.main()