From a2e0f8c24776f63b04a29fad9c66bf3d66e94f4d Mon Sep 17 00:00:00 2001 From: Igor Ganichev Date: Mon, 19 Mar 2018 19:52:06 -0700 Subject: [PATCH] Handle non-broadcastables shapes in eager assert_equal Before this change assert_equal would fail when producing an error message for non-equal shapes because array_ops.boolean_mask only works for equal shapes. This part of the error message is fairly confusing in presence of non-equal shapes. This change removes it. PiperOrigin-RevId: 189682518 --- tensorflow/python/kernel_tests/check_ops_test.py | 6 ++++ tensorflow/python/ops/check_ops.py | 39 +++++++++++++----------- 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index 26d3df9..5a83ec8 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -213,6 +213,12 @@ First 2 elements of y: self.evaluate(out) @test_util.run_in_graph_and_eager_modes() + def test_raises_when_not_equal_and_broadcastable_shapes(self): + cond = constant_op.constant([True, False], name="small") + with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"): + check_ops.assert_equal(cond, False, message="fail") + + @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) curly = constant_op.constant([]) diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index d6d75e4..9cea3e9 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -363,27 +363,30 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None): (x_sum, x_np[:x_sum], y_sum, y_np[:y_sum])) - # Get the values that actually differed and their indices. - mask = math_ops.logical_not(eq) - indices = array_ops.where(mask) - indices_np = indices.numpy() - x_vals = array_ops.boolean_mask(x, mask) - y_vals = array_ops.boolean_mask(y, mask) - summarize = min(summarize, indices_np.shape[0]) + index_and_values_str = '' + if x.shape == y.shape: + # If the shapes of x and y are the same, + # Get the values that actually differed and their indices. + # If shapes are different this information is more confusing + # than useful. + mask = math_ops.logical_not(eq) + indices = array_ops.where(mask) + indices_np = indices.numpy() + x_vals = array_ops.boolean_mask(x, mask) + y_vals = array_ops.boolean_mask(y, mask) + summarize = min(summarize, indices_np.shape[0]) + index_and_values_str = ( + 'Indices of first %s different values:\n%s\n' + 'Corresponding x values:\n%s\n' + 'Corresponding y values:\n%s\n' % + (summarize, indices_np[:summarize], + x_vals.numpy().reshape((-1,))[:summarize], + y_vals.numpy().reshape((-1,))[:summarize])) raise errors.InvalidArgumentError( node_def=None, op=None, - message=('%s\nCondition x == y did not hold.\n' - 'Indices of first %s different values:\n%s\n' - 'Corresponding x values:\n%s\n' - 'Corresponding y values:\n%s\n' - '%s' - % - (message or '', - summarize, indices_np[:summarize], - x_vals.numpy().reshape((-1,))[:summarize], - y_vals.numpy().reshape((-1,))[:summarize], - summary_msg))) + message=('%s\nCondition x == y did not hold.\n%s%s' % + (message or '', index_and_values_str, summary_msg))) return if data is None: -- 2.7.4