Change Softmax on CUDA to use fp32 for denominator when input/output are fp16.
authorJames Qin <jamesqin@google.com>
Wed, 21 Mar 2018 22:55:30 +0000 (15:55 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 21 Mar 2018 22:58:16 +0000 (15:58 -0700)
This avoids potential overflow in the denominator, also makes sure accumulation is done
in high precision.

PiperOrigin-RevId: 189982655

tensorflow/core/kernels/softmax_op_gpu.cu.cc
tensorflow/python/framework/test_util.py
tensorflow/python/kernel_tests/BUILD
tensorflow/python/kernel_tests/softmax_op_test.py

index 1f4a82a..130d693 100644 (file)
@@ -33,8 +33,42 @@ namespace tensorflow {
 
 namespace {
 
+template <typename U, typename T>
+__device__ __host__ EIGEN_STRONG_INLINE
+    typename std::enable_if<!std::is_same<T, U>::value, U>::type
+    strict_cast(T t);
+
+template <typename U, typename T>
+__device__ __host__ EIGEN_STRONG_INLINE
+    typename std::enable_if<std::is_same<T, U>::value, U>::type
+    strict_cast(T t) {
+  return t;
+}
+
+template <>
+__device__ __host__ EIGEN_STRONG_INLINE float strict_cast<float, Eigen::half>(
+    Eigen::half t) {
+  return functor::HalfToFloat()(t);
+}
+
+template <>
+__device__ __host__ EIGEN_STRONG_INLINE Eigen::half
+strict_cast<Eigen::half, float>(float t) {
+  return functor::FloatToHalf()(t);
+}
+
 template <typename T>
-__global__ void GenerateNormalizedProb(const T* logits, const T* sum_probs,
+struct softmax_traits {
+  using accumulator_type = T;
+};
+
+template <>
+struct softmax_traits<Eigen::half> {
+  using accumulator_type = float;
+};
+
+template <typename T, typename U>
+__global__ void GenerateNormalizedProb(const T* logits, const U* sum_probs,
                                        const T* max_logits, T* output,
                                        const int num_rows, const int num_cols,
                                        const bool in_log_space) {
@@ -43,25 +77,33 @@ __global__ void GenerateNormalizedProb(const T* logits, const T* sum_probs,
   const int row = tid / num_cols;
   const int col = tid % num_cols;
 
+  // TODO(jamesqin): change to half2 load when inputs are Eigen::half.
+  U input = strict_cast<U>(logits[tid]);
+  U max_val = strict_cast<U>(ldg(max_logits + row));
+  U result;
+
   if (row < num_rows && col < num_cols) {
-    if (in_log_space)
-      output[tid] =
-          logits[tid] - ldg(max_logits + row) - log(ldg(sum_probs + row));
-    else
-      output[tid] =
-          exp(logits[tid] - ldg(max_logits + row)) / ldg(sum_probs + row);
+    if (in_log_space) {
+      result = input - max_val - log(ldg(sum_probs + row));
+    } else {
+      result = exp(input - max_val) / ldg(sum_probs + row);
+    }
+    output[tid] = strict_cast<T>(result);
   }
 }
 
-template <typename T>
+template <typename T, typename U>
 struct SubtractAndExpFunctor {
   __host__ __device__ SubtractAndExpFunctor(const T* logits,
                                             const T* max_logits,
                                             const int num_cols)
       : logits_(logits), max_logits_(max_logits), num_cols_(num_cols) {}
 
-  __host__ __device__ T operator()(const int gid) const {
-    return exp(logits_[gid] - ldg(max_logits_ + gid / num_cols_));
+  __host__ __device__ U operator()(const int gid) const {
+    // TODO(jamesqin): change to half2 load when inputs are Eigen::half.
+    const U diff =
+        strict_cast<U>(logits_[gid] - ldg(max_logits_ + gid / num_cols_));
+    return exp(diff);
   }
 
   const T* logits_;
@@ -80,7 +122,6 @@ void DoRowReduction(OpKernelContext* context, T* output, InputIter input,
   functor::ReduceImpl<T, Op, T*, InputIter, ReductionAxes>(
       context, output, input, 2, rows, cols, 1, 1, constants.kOne, op);
 }
-
 }  // namespace
 
 template <typename T>
@@ -108,8 +149,10 @@ class SoftmaxOpGPU : public OpKernel {
       OP_REQUIRES_OK(context,
                      context->allocate_temp(DataTypeToEnum<T>::value,
                                             softmax_out->shape(), &max_logits));
+
+      typedef typename softmax_traits<T>::accumulator_type acc_type;
       OP_REQUIRES_OK(context,
-                     context->allocate_temp(DataTypeToEnum<T>::value,
+                     context->allocate_temp(DataTypeToEnum<acc_type>::value,
                                             softmax_out->shape(), &sum_probs));
 
       DoRowReduction<T, cub::Max, const T*>(
@@ -120,25 +163,28 @@ class SoftmaxOpGPU : public OpKernel {
       const int numBlocks = Eigen::divup(rows * cols, numThreads);
 
       cub::CountingInputIterator<int> counting_iterator(0);
-      typedef cub::TransformInputIterator<T, SubtractAndExpFunctor<T>,
+      typedef cub::TransformInputIterator<acc_type,
+                                          SubtractAndExpFunctor<T, acc_type>,
                                           cub::CountingInputIterator<int>>
           InputIterType;
 
       InputIterType input_itr(
           counting_iterator,
-          SubtractAndExpFunctor<T>(
+          SubtractAndExpFunctor<T, acc_type>(
               reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
               reinterpret_cast<const T*>(max_logits.flat<T>().data()), cols));
 
-      DoRowReduction<T, cub::Sum, InputIterType>(
-          context, const_cast<T*>(sum_probs.flat<T>().data()), input_itr, rows,
-          cols);
+      DoRowReduction<acc_type, cub::Sum, InputIterType>(
+          context, const_cast<acc_type*>(sum_probs.flat<acc_type>().data()),
+          input_itr, rows, cols);
 
-      GenerateNormalizedProb<<<numBlocks, numThreads, 0, cu_stream>>>(
-          reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
-          reinterpret_cast<const T*>(sum_probs.flat<T>().data()),
-          reinterpret_cast<const T*>(max_logits.flat<T>().data()),
-          const_cast<T*>(softmax_out->flat<T>().data()), rows, cols, log_);
+      GenerateNormalizedProb<T, acc_type>
+          <<<numBlocks, numThreads, 0, cu_stream>>>(
+              reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
+              reinterpret_cast<const acc_type*>(
+                  sum_probs.flat<acc_type>().data()),
+              reinterpret_cast<const T*>(max_logits.flat<T>().data()),
+              const_cast<T*>(softmax_out->flat<T>().data()), rows, cols, log_);
     }
   }
 
index d8f8569..43106b6 100644 (file)
@@ -53,6 +53,7 @@ from tensorflow.python.eager import tape  # pylint: disable=unused-import
 from tensorflow.python.framework import device as pydev
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import importer
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
@@ -201,6 +202,7 @@ def _strip_checkpoint_v2_randomized(graph_def):
 def IsGoogleCudaEnabled():
   return pywrap_tensorflow.IsGoogleCudaEnabled()
 
+
 def CudaSupportsHalfMatMulAndConv():
   return pywrap_tensorflow.CudaSupportsHalfMatMulAndConv()
 
@@ -335,6 +337,8 @@ def _use_c_api_wrapper(fn, use_c_api, *args, **kwargs):
     # Make sure default graph reflects prev_value in case next test doesn't call
     # reset_default_graph().
     ops.reset_default_graph()
+
+
 # pylint: disable=protected-access
 
 
@@ -451,7 +455,8 @@ def with_c_api(cls):
   # If the C API is already enabled, don't do anything. Some tests break if the
   # same test is run twice, so this allows us to turn on the C API by default
   # without breaking these tests.
-  if ops._USE_C_API: return cls
+  if ops._USE_C_API:
+    return cls
 
   for name, value in cls.__dict__.copy().items():
     if callable(value) and name.startswith("test"):
@@ -469,6 +474,7 @@ def assert_no_new_pyobjects_executing_eagerly(f):
   Useful for checking that there are no missing Py_DECREFs in the C exercised by
   a bit of Python.
   """
+
   def decorator(self, **kwargs):
     """Warms up, gets an object count, runs the test, checks for new objects."""
     with context.eager_mode():
@@ -483,8 +489,10 @@ def assert_no_new_pyobjects_executing_eagerly(f):
       new_count = len(gc.get_objects())
       self.assertEqual(previous_count, new_count)
       gc.enable()
+
   return decorator
 
+
 def assert_no_new_tensors(f):
   """Decorator for asserting that no new Tensors persist after a test.
 
@@ -508,17 +516,15 @@ def assert_no_new_tensors(f):
 
     def _is_tensorflow_object(obj):
       try:
-        return isinstance(obj, (
-            ops.Tensor,
-            variables.Variable,
-            tensor_shape.Dimension,
-            tensor_shape.TensorShape))
+        return isinstance(obj,
+                          (ops.Tensor, variables.Variable,
+                           tensor_shape.Dimension, tensor_shape.TensorShape))
       except ReferenceError:
         # If the object no longer exists, we don't care about it.
         return False
 
-    tensors_before = set(id(obj) for obj in gc.get_objects()
-                         if _is_tensorflow_object(obj))
+    tensors_before = set(
+        id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj))
     outside_graph_key = ops.get_default_graph()._graph_key
     with ops.Graph().as_default():
       # Run the test in a new graph so that collections get cleared when it's
@@ -572,18 +578,18 @@ def assert_no_garbage_created(f):
           "likely due to a reference cycle. New objects in cycle(s):")
       for i, obj in enumerate(gc.garbage[previous_garbage:]):
         try:
-          logging.error(
-              "Object %d of %d" % (i, len(gc.garbage) - previous_garbage))
+          logging.error("Object %d of %d", i,
+                        len(gc.garbage) - previous_garbage)
+
           def _safe_object_str(obj):
             return "<%s %d>" % (obj.__class__.__name__, id(obj))
-          logging.error("  Object type: %s" % (_safe_object_str(obj),))
-          logging.error("  Referrer types: %s" % (
-              ', '.join([_safe_object_str(ref)
-                         for ref in gc.get_referrers(obj)]),))
-          logging.error("  Referent types: %s" % (
-              ', '.join([_safe_object_str(ref)
-                         for ref in gc.get_referents(obj)]),))
-          logging.error("  Object attribute names: %s" % (dir(obj),))
+
+          logging.error("  Object type: %s", _safe_object_str(obj))
+          logging.error("  Referrer types: %s", ", ".join(
+              [_safe_object_str(ref) for ref in gc.get_referrers(obj)]))
+          logging.error("  Referent types: %s", ", ".join(
+              [_safe_object_str(ref) for ref in gc.get_referents(obj)]))
+          logging.error("  Object attribute names: %s", dir(obj))
           logging.error("  Object __str__:")
           logging.error(obj)
           logging.error("  Object __repr__:")
@@ -705,15 +711,23 @@ def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
       return 0, 0
     return int(match.group(1)), int(match.group(2))
 
-  for local_device in device_lib.list_local_devices():
-    if local_device.device_type == "GPU":
-      if (min_cuda_compute_capability is None or
-          compute_capability_from_device_desc(local_device.physical_device_desc)
-          >= min_cuda_compute_capability):
+  try:
+    for local_device in device_lib.list_local_devices():
+      if local_device.device_type == "GPU":
+        if (min_cuda_compute_capability is None or
+            compute_capability_from_device_desc(
+                local_device.physical_device_desc) >=
+            min_cuda_compute_capability):
+          return True
+      if local_device.device_type == "SYCL" and not cuda_only:
         return True
-    if local_device.device_type == "SYCL" and not cuda_only:
-      return True
-  return False
+    return False
+  except errors_impl.NotFoundError as e:
+    if not all([x in str(e) for x in ["CUDA", "not find"]]):
+      raise e
+    else:
+      logging.error(str(e))
+      return False
 
 
 @contextlib.contextmanager
@@ -1256,9 +1270,9 @@ class TensorFlowTestCase(googletest.TestCase):
             msg="Mismatched value: a%s is different from b%s." % (path_str,
                                                                   path_str))
       except TypeError as e:
-        msg = "Error: a%s has %s, but b%s has %s" % (
-            path_str, type(a), path_str, type(b))
-        e.args = ((e.args[0] + ' : ' + msg,) + e.args[1:])
+        msg = "Error: a%s has %s, but b%s has %s" % (path_str, type(a),
+                                                     path_str, type(b))
+        e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
         raise
 
   def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
@@ -1438,8 +1452,7 @@ class TensorFlowTestCase(googletest.TestCase):
     """
     device1 = pydev.canonical_name(device1)
     device2 = pydev.canonical_name(device2)
-    self.assertEqual(device1, device2,
-                     "Devices %s and %s are not equal. %s" % 
+    self.assertEqual(device1, device2, "Devices %s and %s are not equal. %s" %
                      (device1, device2, msg))
 
   # Fix Python 3 compatibility issues
index d9571fa..ece1da0 100644 (file)
@@ -1910,7 +1910,7 @@ cuda_py_test(
 
 cuda_py_test(
     name = "softmax_op_test",
-    size = "small",
+    size = "medium",
     srcs = ["softmax_op_test.py"],
     additional_deps = [
         "//third_party/py/numpy",
index 2b8e99e..981f96b 100644 (file)
@@ -18,14 +18,17 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import unittest
 import numpy as np
 
+
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
 
 
 @test_util.with_c_api
@@ -41,9 +44,10 @@ class SoftmaxTest(test.TestCase):
             features, axis=dim), one_only_on_dim))
     softmax = e / np.reshape(np.sum(e, axis=dim), one_only_on_dim)
     if log:
-      return np.log(softmax)
+      res = np.log(softmax)
     else:
-      return softmax
+      res = softmax
+    return res
 
   def _testSoftmax(self, np_features, dim=-1, log=False, use_gpu=False):
     # A previous version of the code checked the op name rather than the op type
@@ -53,9 +57,9 @@ class SoftmaxTest(test.TestCase):
     np_softmax = self._npSoftmax(np_features, dim=dim, log=log)
     with self.test_session(use_gpu=use_gpu):
       if log:
-        tf_softmax = nn_ops.log_softmax(np_features, dim=dim, name=name)
+        tf_softmax = nn_ops.log_softmax(np_features, axis=dim, name=name)
       else:
-        tf_softmax = nn_ops.softmax(np_features, dim=dim, name=name)
+        tf_softmax = nn_ops.softmax(np_features, axis=dim, name=name)
       out = tf_softmax.eval()
     self.assertAllCloseAccordingToType(np_softmax, out)
     self.assertShapeEqual(np_softmax, tf_softmax)
@@ -117,10 +121,32 @@ class SoftmaxTest(test.TestCase):
     self._testAll(
         np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float32))
 
+  @unittest.skipUnless(test.is_built_with_cuda(),
+                       "Test only applicable when running on GPUs")
+  def testFloatGPU(self):
+    if test.is_gpu_available(cuda_only=True):
+      rows = [2**x + np.random.randint(0, 1024) for x in range(1, 10)]
+      cols = [2**x + np.random.randint(0, 1024) for x in range(1, 10)]
+      for row, col in zip(rows, cols):
+        logging.info("Testing softmax float dtype in shape [%d, %d]", row, col)
+        data = np.random.rand(row, col)
+        self._testAll(data.astype(np.float32))
+
   def testHalf(self):
     self._testAll(
         np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float16))
 
+  @unittest.skipUnless(test.is_built_with_cuda(),
+                       "Test only applicable when running on GPUs")
+  def testHalfGPU(self):
+    if test.is_gpu_available(cuda_only=True):
+      rows = [2**x + np.random.randint(0, 1024) for x in range(1, 8)]
+      cols = [2**x + np.random.randint(0, 1024) for x in range(1, 8)]
+      for row, col in zip(rows, cols):
+        logging.info("Testing softmax half dtype in shape [%d, %d]", row, col)
+        data = np.random.rand(row, col)
+        self._testAll(data.astype(np.float16))
+
   def testDouble(self):
     self._testSoftmax(
         np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64))
@@ -169,7 +195,7 @@ class SoftmaxTest(test.TestCase):
       self.assertEqual(0, array_ops.size(x).eval())
       # reshape would raise if logits is empty
       with self.assertRaises(errors_impl.InvalidArgumentError):
-        nn_ops.softmax(x, dim=0).eval()
+        nn_ops.softmax(x, axis=0).eval()
 
   def testDimTooLarge(self):
     with self.test_session():
@@ -177,7 +203,7 @@ class SoftmaxTest(test.TestCase):
       # inference error.
       dim = array_ops.placeholder_with_default(100, shape=[])
       with self.assertRaises(errors_impl.InvalidArgumentError):
-        nn_ops.softmax([1., 2., 3., 4.], dim=dim).eval()
+        nn_ops.softmax([1., 2., 3., 4.], axis=dim).eval()
 
   def testLargeDims(self):
     # Make sure that we properly handle large inputs. See