From: A. Unique TensorFlower Date: Wed, 28 Mar 2018 21:47:00 +0000 (-0700) Subject: Use high precision to compute softmax_cross_entropy_with_logits. X-Git-Tag: tflite-v0.1.7~67^2^2~37 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=355c88503a3a998aef3c1dc51045409778afd578;p=platform%2Fupstream%2Ftensorflow.git Use high precision to compute softmax_cross_entropy_with_logits. PiperOrigin-RevId: 190837379 --- diff --git a/tensorflow/core/kernels/cwise_op_log.cc b/tensorflow/core/kernels/cwise_op_log.cc index 98936e0..5d17c89 100644 --- a/tensorflow/core/kernels/cwise_op_log.cc +++ b/tensorflow/core/kernels/cwise_op_log.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER5(UnaryOp, CPU, "Log", functor::log, float, Eigen::half, double, - complex64, complex128); +REGISTER6(UnaryOp, CPU, "Log", functor::log, float, Eigen::half, double, + bfloat16, complex64, complex128); #if GOOGLE_CUDA REGISTER3(UnaryOp, GPU, "Log", functor::log, float, Eigen::half, double); diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index a74de39..0c55386 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1836,8 +1836,10 @@ def softmax_cross_entropy_with_logits_v2( [logits, labels]) as name: logits = ops.convert_to_tensor(logits, name="logits") labels = ops.convert_to_tensor(labels, name="labels") + convert_to_float32 = ( + logits.dtype == dtypes.float16 or logits.dtype == dtypes.bfloat16) precise_logits = math_ops.cast( - logits, dtypes.float32) if (logits.dtype == dtypes.float16) else logits + logits, dtypes.float32) if convert_to_float32 else logits # labels and logits must be of the same type labels = math_ops.cast(labels, precise_logits.dtype) input_rank = array_ops.rank(precise_logits) @@ -1883,8 +1885,8 @@ def softmax_cross_entropy_with_logits_v2( del shape[dim] cost.set_shape(shape) - if logits.dtype == dtypes.float16: - return math_ops.cast(cost, dtypes.float16) + if convert_to_float32: + return math_ops.cast(cost, logits.dtype) else: return cost diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index af9dae2..da86d5f 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -852,6 +852,57 @@ class ComputeSampledLogitsTest(test_lib.TestCase): self.assertAllClose(exp_sampled_softmax_loss, got_sampled_softmax_loss.eval(), 1e-4) + def testSampledSoftmaxLossBf16(self): + # A simple test to verify the numerics for bfloat16. + def _SoftmaxCrossEntropyWithLogits(logits, targets): + # logits, targets: float arrays of the same shape. + assert logits.shape == targets.shape + stable_exp_logits = np.exp( + logits - np.amax(logits, axis=1, keepdims=True)) + pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True) + return -np.sum(targets * np.log(pred + 1.0e-20), axis=1) + + np.random.seed(0) + num_classes = 5 + batch_size = 3 + labels = [0, 1, 2] + sampled = [1, 0, 2, 3] + (weights, biases, hidden_acts, _, exp_logits, + exp_labels) = self._GenerateTestData( + num_classes=num_classes, + dim=10, + batch_size=batch_size, + num_true=1, + labels=labels, + sampled=sampled, + subtract_log_q=True) + exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits( + exp_logits, exp_labels) + + with self.test_session(): + true_exp_bf16 = np.full( + [batch_size, 1], fill_value=0.5, dtype=dtypes.bfloat16.as_numpy_dtype) + sampled_exp_bf16 = np.full( + [len(sampled)], fill_value=0.5, dtype=dtypes.bfloat16.as_numpy_dtype) + sampled_vals_bf16 = (sampled, true_exp_bf16, sampled_exp_bf16) + + got_sampled_softmax_loss = math_ops.cast( + nn_impl.sampled_softmax_loss( + weights=constant_op.constant(weights, dtype=dtypes.bfloat16), + biases=constant_op.constant(biases, dtype=dtypes.bfloat16), + labels=constant_op.constant( + labels, shape=(batch_size, 1), dtype=dtypes.bfloat16), + inputs=constant_op.constant(hidden_acts, dtype=dtypes.bfloat16), + num_sampled=4, + num_classes=num_classes, + num_true=1, + sampled_values=sampled_vals_bf16, + remove_accidental_hits=False, + partition_strategy="div"), dtypes.float32) + + self.assertAllClose(exp_sampled_softmax_loss, + got_sampled_softmax_loss.eval(), 1e-1) + class CReluTest(test_lib.TestCase):