Make namedtuples with identical name and field names to be considered as the same...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 6 Feb 2018 16:50:41 +0000 (08:50 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 6 Feb 2018 16:54:27 +0000 (08:54 -0800)
PiperOrigin-RevId: 184687609

tensorflow/python/util/nest.py
tensorflow/python/util/nest_test.py

index c8525ed..23c2c48 100644 (file)
@@ -497,7 +497,9 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
     shallow_tree: an arbitrarily nested structure.
     input_tree: an arbitrarily nested structure.
     check_types: if `True` (default) the sequence types of `shallow_tree` and
-      `input_tree` have to be the same.
+      `input_tree` have to be the same. Note that even with check_types==True,
+      this function will consider two different namedtuple classes with the same
+      name and _fields attribute to be the same class.
 
   Raises:
     TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
@@ -513,10 +515,21 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
           "Input has type: %s." % type(input_tree))
 
     if check_types and not isinstance(input_tree, type(shallow_tree)):
-      raise TypeError(
-          "The two structures don't have the same sequence type. Input "
-          "structure has type %s, while shallow structure has type %s."
-          % (type(input_tree), type(shallow_tree)))
+      # Duck-typing means that nest should be fine with two different
+      # namedtuples with identical name and fields.
+      shallow_is_namedtuple = _is_namedtuple(shallow_tree, False)
+      input_is_namedtuple = _is_namedtuple(input_tree, False)
+      if shallow_is_namedtuple and input_is_namedtuple:
+        if not _same_namedtuples(shallow_tree, input_tree):
+          raise TypeError(
+              "The two namedtuples don't have the same sequence type. Input "
+              "structure has type %s, while shallow structure has type %s."
+              % (type(input_tree), type(shallow_tree)))
+      else:
+        raise TypeError(
+            "The two structures don't have the same sequence type. Input "
+            "structure has type %s, while shallow structure has type %s."
+            % (type(input_tree), type(shallow_tree)))
 
     if len(input_tree) != len(shallow_tree):
       raise ValueError(
index 8aaf799..4439d62 100644 (file)
@@ -429,6 +429,15 @@ class NestTest(test.TestCase):
     inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
     nest.assert_shallow_structure(inp_ab, inp_ba)
 
+    # This assertion is expected to pass: two namedtuples with the same
+    # name and field names are considered to be identical.
+    same_name_type_0 = collections.namedtuple("same_name", ("a", "b"))
+    same_name_type_1 = collections.namedtuple("same_name", ("a", "b"))
+    inp_shallow = same_name_type_0(1, 2)
+    inp_deep = same_name_type_1(1, [1, 2, 3])
+    nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False)
+    nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True)
+
   def testFlattenUpTo(self):
     # Shallow tree ends at scalar.
     input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]