Handle scalar input to assert_equal in eager.
authorTom Hennigan <tomhennigan@google.com>
Tue, 5 Jun 2018 10:56:47 +0000 (03:56 -0700)
committerChristopher Suter <cgs@google.com>
Fri, 22 Jun 2018 17:54:28 +0000 (13:54 -0400)
PiperOrigin-RevId: 199274329

tensorflow/python/kernel_tests/check_ops_test.py
tensorflow/python/ops/check_ops.py

index 5a83ec8..7ef841c 100644 (file)
@@ -88,6 +88,13 @@ class AssertEqualTest(test.TestCase):
       out = array_ops.identity(small)
     self.evaluate(out)
 
+  @test_util.run_in_graph_and_eager_modes()
+  def test_scalar_comparison(self):
+    const_true = constant_op.constant(True, name="true")
+    const_false = constant_op.constant(False, name="false")
+    with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"):
+      check_ops.assert_equal(const_true, const_false, message="fail")
+
   def test_returns_none_with_eager(self):
     with context.eager_mode():
       small = constant_op.constant([1, 2], name="small")
index cabc1e7..375a5ec 100644 (file)
@@ -341,8 +341,8 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
                           y_sum, y_np[:y_sum]))
 
         index_and_values_str = ''
-        if x.shape == y.shape:
-          # If the shapes of x and y are the same,
+        if x.shape == y.shape and x.shape.as_list():
+          # If the shapes of x and y are the same (and not scalars),
           # Get the values that actually differed and their indices.
           # If shapes are different this information is more confusing
           # than useful.