Adds checks to tf.nn.sparse_softmax_cross_entropy_with_logits to make sure that shape...
author A. Unique TensorFlower <gardener@tensorflow.org>
Mon, 5 Mar 2018 10:45:58 +0000 (02:45 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Mar 2018 10:49:52 +0000 (02:49 -0800)
In sparse_softmax_cross_entropy_with_logits the inputs are flattened before the loss is computed, which can silently produce wrong results if the dimension order of labels and logits does not match (e.g. if one is time-major and the other is batch-major). This change adds a static shape check, plus a runtime assertion when the static shapes are not fully defined, to catch such mistakes early.
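
For example (a minimal sketch with illustrative shapes, using the TF 1.x API current at the time of this commit):

    import tensorflow as tf

    labels = tf.zeros([32, 10], dtype=tf.int32)       # batch-major: [batch, time]
    logits = tf.zeros([10, 32, 5], dtype=tf.float32)  # time-major: [time, batch, num_classes]

    # The ranks are compatible (2 vs. 3), so previously both inputs were
    # flattened to 320 elements and each label was silently paired with the
    # wrong row of logits. With this change, graph construction raises:
    #   ValueError: Shape mismatch: The shape of labels (received (32, 10))
    #   should equal the shape of logits except for the last dimension
    #   (received (10, 32, 5)).
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)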

PiperOrigin-RevId: 187841750

tensorflow/python/estimator/canned/head_test.py
tensorflow/python/ops/nn_ops.py

tensorflow/python/estimator/canned/head_test.py
index a300f315c18f60e77f262a3b961c5ef6306bc235..23158c76e7b5dbe579fc54a585572f041d7a42d6 100644
@@ -300,7 +300,12 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
     features = {'x': values_2x3}
 
     # Static shape.
-    with self.assertRaisesRegexp(ValueError, 'Dimensions must be equal'):
+    with self.assertRaisesRegexp(
+        ValueError,
+        r'Shape mismatch: The shape of labels \(received \(3,\)\) should equal '
+        r'the shape of logits except for the last dimension '
+        r'\(received \(2, 3\)\)\.'
+    ):
       head.create_loss(
           features=features,
           mode=model_fn.ModeKeys.EVAL,
tensorflow/python/ops/nn_ops.py
index a0d500afce7c3a98e1e6ab8d5e80bd8748af6b0b..852ab365bbc57c0591198fda67401aa6399ffbee 100644
@@ -29,6 +29,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
@@ -2025,6 +2026,9 @@ def sparse_softmax_cross_entropy_with_logits(
     # Store label shape for result later.
     labels_static_shape = labels.get_shape()
     labels_shape = array_ops.shape(labels)
+    static_shapes_fully_defined = (
+        labels_static_shape.is_fully_defined() and
+        logits.get_shape()[:-1].is_fully_defined())
     if logits.get_shape().ndims is not None and logits.get_shape().ndims == 0:
       raise ValueError(
           "Logits cannot be scalars - received shape %s." % logits.get_shape())
@@ -2034,6 +2038,12 @@ def sparse_softmax_cross_entropy_with_logits(
       raise ValueError("Rank mismatch: Rank of labels (received %s) should "
                        "equal rank of logits minus 1 (received %s)." %
                        (labels_static_shape.ndims, logits.get_shape().ndims))
+    if (static_shapes_fully_defined and
+        labels_static_shape != logits.get_shape()[:-1]):
+      raise ValueError("Shape mismatch: The shape of labels (received %s) "
+                       "should equal the shape of logits except for the last "
+                       "dimension (received %s)." % (labels_static_shape,
+                                                     logits.get_shape()))
     # Check if no reshapes are required.
     if logits.get_shape().ndims == 2:
       cost, _ = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
@@ -2043,20 +2053,29 @@ def sparse_softmax_cross_entropy_with_logits(
       else:
         return cost
 
-    # Reshape logits to 2 dim, labels to 1 dim.
-    num_classes = array_ops.shape(logits)[array_ops.rank(logits) - 1]
-    precise_logits = array_ops.reshape(precise_logits, [-1, num_classes])
-    labels = array_ops.reshape(labels, [-1])
-    # The second output tensor contains the gradients.  We use it in
-    # _CrossEntropyGrad() in nn_grad but not here.
-    cost, _ = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
-        precise_logits, labels, name=name)
-    cost = array_ops.reshape(cost, labels_shape)
-    cost.set_shape(labels_static_shape)
-    if logits.dtype == dtypes.float16:
-      return math_ops.cast(cost, dtypes.float16)
-    else:
-      return cost
+    # Perform a check of the dynamic shapes if the static shapes are not fully
+    # defined.
+    shape_checks = []
+    if not static_shapes_fully_defined:
+      shape_checks.append(
+          check_ops.assert_equal(
+              array_ops.shape(labels),
+              array_ops.shape(logits)[:-1]))
+    with ops.control_dependencies(shape_checks):
+      # Reshape logits to 2 dim, labels to 1 dim.
+      num_classes = array_ops.shape(logits)[array_ops.rank(logits) - 1]
+      precise_logits = array_ops.reshape(precise_logits, [-1, num_classes])
+      labels = array_ops.reshape(labels, [-1])
+      # The second output tensor contains the gradients.  We use it in
+      # _CrossEntropyGrad() in nn_grad but not here.
+      cost, _ = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
+          precise_logits, labels, name=name)
+      cost = array_ops.reshape(cost, labels_shape)
+      cost.set_shape(labels_static_shape)
+      if logits.dtype == dtypes.float16:
+        return math_ops.cast(cost, dtypes.float16)
+      else:
+        return cost
 
 
 @tf_export("nn.avg_pool")