Handle non-broadcastables shapes in eager assert_equal
authorIgor Ganichev <iga@google.com>
Tue, 20 Mar 2018 02:52:06 +0000 (19:52 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 20 Mar 2018 02:56:22 +0000 (19:56 -0700)
Before this change assert_equal would fail when producing an error
message for non-equal shapes because array_ops.boolean_mask only
works for equal shapes.

This part of the error message is fairly confusing in presence
of non-equal shapes. This change removes it.

PiperOrigin-RevId: 189682518

tensorflow/python/kernel_tests/check_ops_test.py
tensorflow/python/ops/check_ops.py

index 26d3df9..5a83ec8 100644 (file)
@@ -213,6 +213,12 @@ First 2 elements of y:
       self.evaluate(out)
 
   @test_util.run_in_graph_and_eager_modes()
+  def test_raises_when_not_equal_and_broadcastable_shapes(self):
+    cond = constant_op.constant([True, False], name="small")
+    with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"):
+      check_ops.assert_equal(cond, False, message="fail")
+
+  @test_util.run_in_graph_and_eager_modes()
   def test_doesnt_raise_when_both_empty(self):
     larry = constant_op.constant([])
     curly = constant_op.constant([])
index d6d75e4..9cea3e9 100644 (file)
@@ -363,27 +363,30 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
                          (x_sum, x_np[:x_sum],
                           y_sum, y_np[:y_sum]))
 
-        # Get the values that actually differed and their indices.
-        mask = math_ops.logical_not(eq)
-        indices = array_ops.where(mask)
-        indices_np = indices.numpy()
-        x_vals = array_ops.boolean_mask(x, mask)
-        y_vals = array_ops.boolean_mask(y, mask)
-        summarize = min(summarize, indices_np.shape[0])
+        index_and_values_str = ''
+        if x.shape == y.shape:
+          # If the shapes of x and y are the same,
+          # Get the values that actually differed and their indices.
+          # If shapes are different this information is more confusing
+          # than useful.
+          mask = math_ops.logical_not(eq)
+          indices = array_ops.where(mask)
+          indices_np = indices.numpy()
+          x_vals = array_ops.boolean_mask(x, mask)
+          y_vals = array_ops.boolean_mask(y, mask)
+          summarize = min(summarize, indices_np.shape[0])
+          index_and_values_str = (
+              'Indices of first %s different values:\n%s\n'
+              'Corresponding x values:\n%s\n'
+              'Corresponding y values:\n%s\n' %
+              (summarize, indices_np[:summarize],
+               x_vals.numpy().reshape((-1,))[:summarize],
+               y_vals.numpy().reshape((-1,))[:summarize]))
 
         raise errors.InvalidArgumentError(
             node_def=None, op=None,
-            message=('%s\nCondition x == y did not hold.\n'
-                     'Indices of first %s different values:\n%s\n'
-                     'Corresponding x values:\n%s\n'
-                     'Corresponding y values:\n%s\n'
-                     '%s'
-                     %
-                     (message or '',
-                      summarize, indices_np[:summarize],
-                      x_vals.numpy().reshape((-1,))[:summarize],
-                      y_vals.numpy().reshape((-1,))[:summarize],
-                      summary_msg)))
+            message=('%s\nCondition x == y did not hold.\n%s%s' %
+                     (message or '', index_and_values_str, summary_msg)))
       return
 
     if data is None: