From ef8c0863ca653f235ec2b79beaea32fe6ddee7a9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 6 Feb 2018 08:50:41 -0800 Subject: [PATCH] Make namedtuples with identical name and field names to be considered as the same shallow structure in assert_shallow_structure PiperOrigin-RevId: 184687609 --- tensorflow/python/util/nest.py | 23 ++++++++++++++++++----- tensorflow/python/util/nest_test.py | 9 +++++++++ 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index c8525ed..23c2c48 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -497,7 +497,9 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True): shallow_tree: an arbitrarily nested structure. input_tree: an arbitrarily nested structure. check_types: if `True` (default) the sequence types of `shallow_tree` and - `input_tree` have to be the same. + `input_tree` have to be the same. Note that even with check_types==True, + this function will consider two different namedtuple classes with the same + name and _fields attribute to be the same class. Raises: TypeError: If `shallow_tree` is a sequence but `input_tree` is not. @@ -513,10 +515,21 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True): "Input has type: %s." % type(input_tree)) if check_types and not isinstance(input_tree, type(shallow_tree)): - raise TypeError( - "The two structures don't have the same sequence type. Input " - "structure has type %s, while shallow structure has type %s." - % (type(input_tree), type(shallow_tree))) + # Duck-typing means that nest should be fine with two different + # namedtuples with identical name and fields. + shallow_is_namedtuple = _is_namedtuple(shallow_tree, False) + input_is_namedtuple = _is_namedtuple(input_tree, False) + if shallow_is_namedtuple and input_is_namedtuple: + if not _same_namedtuples(shallow_tree, input_tree): + raise TypeError( + "The two namedtuples don't have the same sequence type. Input " + "structure has type %s, while shallow structure has type %s." + % (type(input_tree), type(shallow_tree))) + else: + raise TypeError( + "The two structures don't have the same sequence type. Input " + "structure has type %s, while shallow structure has type %s." + % (type(input_tree), type(shallow_tree))) if len(input_tree) != len(shallow_tree): raise ValueError( diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py index 8aaf799..4439d62 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -429,6 +429,15 @@ class NestTest(test.TestCase): inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) nest.assert_shallow_structure(inp_ab, inp_ba) + # This assertion is expected to pass: two namedtuples with the same + # name and field names are considered to be identical. + same_name_type_0 = collections.namedtuple("same_name", ("a", "b")) + same_name_type_1 = collections.namedtuple("same_name", ("a", "b")) + inp_shallow = same_name_type_0(1, 2) + inp_deep = same_name_type_1(1, [1, 2, 3]) + nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) + nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) + def testFlattenUpTo(self): # Shallow tree ends at scalar. input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] -- 2.7.4