Use high precision to compute softmax_cross_entropy_with_logits.
author A. Unique TensorFlower <gardener@tensorflow.org>
Wed, 28 Mar 2018 21:47:00 +0000 (14:47 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 28 Mar 2018 21:49:53 +0000 (14:49 -0700)
PiperOrigin-RevId: 190837379

tensorflow/core/kernels/cwise_op_log.cc
tensorflow/python/ops/nn_ops.py
tensorflow/python/ops/nn_test.py
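
In caller-facing terms, the change means that reduced-precision logits no longer degrade the loss computation: the cross entropy is evaluated in float32 internally and the result is cast back to the input dtype. A minimal usage sketch (illustrative only; assumes TensorFlow 1.x with this commit applied):

    import tensorflow as tf

    logits = tf.constant([[2.0, 1.0, 0.1]], dtype=tf.bfloat16)
    labels = tf.constant([[1.0, 0.0, 0.0]], dtype=tf.bfloat16)

    # Computed in float32 internally, returned in the caller's dtype.
    loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits)
    print(loss.dtype)  # tf.bfloat16

    with tf.Session() as sess:
        print(sess.run(loss))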

diff --git a/tensorflow/core/kernels/cwise_op_log.cc b/tensorflow/core/kernels/cwise_op_log.cc
index 98936e0..5d17c89 100644
--- a/tensorflow/core/kernels/cwise_op_log.cc
+++ b/tensorflow/core/kernels/cwise_op_log.cc
@@ -16,8 +16,8 @@ limitations under the License.
 #include "tensorflow/core/kernels/cwise_ops_common.h"
 
 namespace tensorflow {
-REGISTER5(UnaryOp, CPU, "Log", functor::log, float, Eigen::half, double,
-          complex64, complex128);
+REGISTER6(UnaryOp, CPU, "Log", functor::log, float, Eigen::half, double,
+          bfloat16, complex64, complex128);
 
 #if GOOGLE_CUDA
 REGISTER3(UnaryOp, GPU, "Log", functor::log, float, Eigen::half, double);
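
The kernel change above extends the CPU Log op registration from five to six dtypes (REGISTER5 -> REGISTER6), adding bfloat16; the new bfloat16 sampled softmax test below exercises this path through the log of the expected sampling counts. A small sketch of what the registration enables (illustrative; assumes TensorFlow 1.x with this commit):

    import tensorflow as tf

    x = tf.constant([0.5, 1.0, 2.0], dtype=tf.bfloat16)
    y = tf.log(x)  # dispatches to the newly registered bfloat16 CPU kernel

    with tf.Session() as sess:
        print(sess.run(y))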
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index a74de39..0c55386 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1836,8 +1836,10 @@ def softmax_cross_entropy_with_logits_v2(
                       [logits, labels]) as name:
     logits = ops.convert_to_tensor(logits, name="logits")
     labels = ops.convert_to_tensor(labels, name="labels")
+    convert_to_float32 = (
+        logits.dtype == dtypes.float16 or logits.dtype == dtypes.bfloat16)
     precise_logits = math_ops.cast(
-        logits, dtypes.float32) if (logits.dtype == dtypes.float16) else logits
+        logits, dtypes.float32) if convert_to_float32 else logits
     # labels and logits must be of the same type
     labels = math_ops.cast(labels, precise_logits.dtype)
     input_rank = array_ops.rank(precise_logits)
@@ -1883,8 +1885,8 @@ def softmax_cross_entropy_with_logits_v2(
       del shape[dim]
       cost.set_shape(shape)
 
-    if logits.dtype == dtypes.float16:
-      return math_ops.cast(cost, dtypes.float16)
+    if convert_to_float32:
+      return math_ops.cast(cost, logits.dtype)
     else:
       return cost
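
Taken together, the two hunks above replace the float16-only special case with a general up-cast/down-cast pattern: float16 and bfloat16 logits are promoted to float32 for the numerically sensitive softmax and log computation, and the cost is cast back to the caller's dtype on the way out. A standalone sketch of the pattern (names here are illustrative, not TensorFlow internals):

    import tensorflow as tf

    def _cross_entropy_in_float32(logits, labels):
        # Promote reduced-precision dtypes before the numerically sensitive part.
        convert_to_float32 = logits.dtype in (tf.float16, tf.bfloat16)
        precise_logits = tf.cast(logits, tf.float32) if convert_to_float32 else logits
        precise_labels = tf.cast(labels, precise_logits.dtype)
        cost = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=precise_labels, logits=precise_logits)
        # Hand the loss back in the dtype the caller passed in.
        return tf.cast(cost, logits.dtype) if convert_to_float32 else cost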
 
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index af9dae2..da86d5f 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -852,6 +852,57 @@ class ComputeSampledLogitsTest(test_lib.TestCase):
       self.assertAllClose(exp_sampled_softmax_loss,
                           got_sampled_softmax_loss.eval(), 1e-4)
 
+  def testSampledSoftmaxLossBf16(self):
+    # A simple test to verify the numerics for bfloat16.
+    def _SoftmaxCrossEntropyWithLogits(logits, targets):
+      # logits, targets: float arrays of the same shape.
+      assert logits.shape == targets.shape
+      stable_exp_logits = np.exp(
+          logits - np.amax(logits, axis=1, keepdims=True))
+      pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True)
+      return -np.sum(targets * np.log(pred + 1.0e-20), axis=1)
+
+    np.random.seed(0)
+    num_classes = 5
+    batch_size = 3
+    labels = [0, 1, 2]
+    sampled = [1, 0, 2, 3]
+    (weights, biases, hidden_acts, _, exp_logits,
+     exp_labels) = self._GenerateTestData(
+         num_classes=num_classes,
+         dim=10,
+         batch_size=batch_size,
+         num_true=1,
+         labels=labels,
+         sampled=sampled,
+         subtract_log_q=True)
+    exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits(
+        exp_logits, exp_labels)
+
+    with self.test_session():
+      true_exp_bf16 = np.full(
+          [batch_size, 1], fill_value=0.5, dtype=dtypes.bfloat16.as_numpy_dtype)
+      sampled_exp_bf16 = np.full(
+          [len(sampled)], fill_value=0.5, dtype=dtypes.bfloat16.as_numpy_dtype)
+      sampled_vals_bf16 = (sampled, true_exp_bf16, sampled_exp_bf16)
+
+      got_sampled_softmax_loss = math_ops.cast(
+          nn_impl.sampled_softmax_loss(
+              weights=constant_op.constant(weights, dtype=dtypes.bfloat16),
+              biases=constant_op.constant(biases, dtype=dtypes.bfloat16),
+              labels=constant_op.constant(
+                  labels, shape=(batch_size, 1), dtype=dtypes.bfloat16),
+              inputs=constant_op.constant(hidden_acts, dtype=dtypes.bfloat16),
+              num_sampled=4,
+              num_classes=num_classes,
+              num_true=1,
+              sampled_values=sampled_vals_bf16,
+              remove_accidental_hits=False,
+              partition_strategy="div"), dtypes.float32)
+
+      self.assertAllClose(exp_sampled_softmax_loss,
+                          got_sampled_softmax_loss.eval(), 1e-1)
+
 
 class CReluTest(test_lib.TestCase):
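
A note on the tolerance in the new test: the float32 variant above asserts to 1e-4, while the bfloat16 variant uses 1e-1, since bfloat16 keeps only an 8-bit significand (roughly two to three decimal digits of precision). A quick illustration of that rounding (illustrative only):

    import numpy as np
    import tensorflow as tf

    bf16 = tf.bfloat16.as_numpy_dtype
    # 1.001 is not representable in bfloat16; it rounds to 1.0.
    print(np.array([1.001], dtype=bf16).astype(np.float32))  # [1.]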