From 62c50e197e25c661048fe90fdd177a87eda47376 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Tue, 8 May 2018 14:00:30 -0700 Subject: [PATCH] Avoid string formatting in assert_same_float_dtype unless there's an error Especially helpful when executing eagerly PiperOrigin-RevId: 195871887 --- tensorflow/python/ops/check_ops.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 306055d..cabc1e7 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -1169,19 +1169,35 @@ def _assert_same_base_type(items, expected_type=None): Raises: ValueError: If any types do not match. """ - original_item_str = None + original_expected_type = expected_type + mismatch = False for item in items: if item is not None: item_type = item.dtype.base_dtype if not expected_type: expected_type = item_type - original_item_str = item.name if hasattr(item, 'name') else str(item) elif expected_type != item_type: - raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % ( - item.name if hasattr(item, 'name') else str(item), - item_type, expected_type, - (' as %s' % original_item_str) if original_item_str else '')) - return expected_type + mismatch = True + break + if mismatch: + # Loop back through and build up an informative error message (this is very + # slow, so we don't do it unless we found an error above). + expected_type = original_expected_type + original_item_str = None + for item in items: + if item is not None: + item_type = item.dtype.base_dtype + if not expected_type: + expected_type = item_type + original_item_str = item.name if hasattr(item, 'name') else str(item) + elif expected_type != item_type: + raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % ( + item.name if hasattr(item, 'name') else str(item), + item_type, expected_type, + (' as %s' % original_item_str) if original_item_str else '')) + return expected_type # Should be unreachable + else: + return expected_type @tf_export('assert_same_float_dtype') -- 2.7.4