From f283e65a1bdb797070be9b84a69ef323268f7c3c Mon Sep 17 00:00:00 2001 From: Tom Hennigan Date: Tue, 5 Jun 2018 03:56:47 -0700 Subject: [PATCH] Handle scalar input to assert_equal in eager. PiperOrigin-RevId: 199274329 --- tensorflow/python/kernel_tests/check_ops_test.py | 7 +++++++ tensorflow/python/ops/check_ops.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index 5a83ec8..7ef841c 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -88,6 +88,13 @@ class AssertEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) + @test_util.run_in_graph_and_eager_modes() + def test_scalar_comparison(self): + const_true = constant_op.constant(True, name="true") + const_false = constant_op.constant(False, name="false") + with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"): + check_ops.assert_equal(const_true, const_false, message="fail") + def test_returns_none_with_eager(self): with context.eager_mode(): small = constant_op.constant([1, 2], name="small") diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index cabc1e7..375a5ec 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -341,8 +341,8 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None): y_sum, y_np[:y_sum])) index_and_values_str = '' - if x.shape == y.shape: - # If the shapes of x and y are the same, + if x.shape == y.shape and x.shape.as_list(): + # If the shapes of x and y are the same (and not scalars), # Get the values that actually differed and their indices. # If shapes are different this information is more confusing # than useful. -- 2.7.4