Implement assert_same_structure in C++
authorIgor Ganichev <iga@google.com>
Thu, 29 Mar 2018 01:26:30 +0000 (18:26 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 29 Mar 2018 01:29:02 +0000 (18:29 -0700)
Also implements helper functions nest._is_namedtuple
nest._same_namedtuple.

Also, fix a bug in FlattenHelper where error from recursive
calls were not propagated up immediately.

This change implements a good chunk of machinery that will
allow us to move map_structure to C++.

Before:
entry {
  name: "NestBenchmark.assert_same_structure_6_elem"
  iters: 30000
  wall_time: 4.79532718658e-05
}

entry {
  name: "NestBenchmark.assert_same_structure_60_elem"
  iters: 30000
  wall_time: 0.000403008667628
}

After:
entry {
  name: "NestBenchmark.assert_same_structure_6_elem"
  iters: 30000
  wall_time: 1.65301720301e-05
}

entry {
  name: "NestBenchmark.assert_same_structure_60_elem"
  iters: 30000
  wall_time: 0.000147621099154
}

PiperOrigin-RevId: 190869007

tensorflow/python/BUILD
tensorflow/python/framework/test_util.py
tensorflow/python/kernel_tests/functional_ops_test.py
tensorflow/python/util/nest.py
tensorflow/python/util/nest_test.py
tensorflow/python/util/util.cc
tensorflow/python/util/util.h
tensorflow/python/util/util.i

index 4f61c01..09c1965 100644 (file)
@@ -298,6 +298,7 @@ cc_library(
     srcs = ["util/util.cc"],
     hdrs = ["util/util.h"],
     deps = [
+        ":safe_ptr",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//util/python:python_headers",
index 4192a27..bf00fa6 100644 (file)
@@ -487,7 +487,13 @@ def assert_no_new_pyobjects_executing_eagerly(f):
       gc.collect()
       # There should be no new Python objects hanging around.
       new_count = len(gc.get_objects())
-      self.assertEqual(previous_count, new_count)
+      # In some cases (specifacally on MacOS), new_count is somehow
+      # smaller than previous_count.
+      # Using plain assert because not all classes using this decorator
+      # have assertLessEqual
+      assert new_count <= previous_count, (
+          "new_count(%d) is not less than or equal to previous_count(%d)" % (
+              new_count, previous_count))
       gc.enable()
 
   return decorator
index f5717a5..1301ef9 100644 (file)
@@ -229,7 +229,7 @@ class FunctionalOpsTest(test.TestCase):
     with self.test_session():
       nums = np.array([1, 2, 3, 4, 5, 6])
       with self.assertRaisesRegexp(
-          TypeError, r"two structures don't have the same sequence type."):
+          TypeError, r"two structures don't have the same nested structure"):
         # lambda emits tuple, but dtype is a list
         functional_ops.map_fn(
             lambda x: ((x + 3) * 2, -(x + 3) * 2),
@@ -316,7 +316,7 @@ class FunctionalOpsTest(test.TestCase):
       initializer = np.array(1.0)
       # Multiply a * 1 each time
       with self.assertRaisesRegexp(
-          ValueError, "two structures don't have the same number of elements"):
+          ValueError, "two structures don't have the same nested structure"):
         functional_ops.scan(lambda a, x: (a, -a), elems, initializer)
 
   def testScan_Scoped(self):
index 23c2c48..5622431 100644 (file)
@@ -60,15 +60,7 @@ def _is_namedtuple(instance, strict=False):
   Returns:
     True if `instance` is a `namedtuple`.
   """
-  # Attemp to limit the test to plain namedtuple (not stuff inheriting from it).
-  if not isinstance(instance, tuple):
-    return False
-  if strict and instance.__class__.__base__ != tuple:
-    return False
-  return (
-      hasattr(instance, "_fields") and
-      isinstance(instance._fields, _collections.Sequence) and
-      all(isinstance(f, _six.string_types) for f in instance._fields))
+  return _pywrap_tensorflow.IsNamedtuple(instance, strict)
 
 
 def _sequence_like(instance, args):
@@ -157,76 +149,7 @@ def flatten(nest):
 
 def _same_namedtuples(nest1, nest2):
   """Returns True if the two namedtuples have the same name and fields."""
-  if nest1._fields != nest2._fields:
-    return False
-  if nest1.__class__.__name__ != nest2.__class__.__name__:
-    return False
-  return True
-
-
-def _recursive_assert_same_structure(nest1, nest2, check_types):
-  """Helper function for `assert_same_structure`.
-
-  See `assert_same_structure` for further information about namedtuples.
-
-  Args:
-    nest1: An arbitrarily nested structure.
-    nest2: An arbitrarily nested structure.
-    check_types: If `True` (default) types of sequences are checked as
-        well, including the keys of dictionaries. If set to `False`, for example
-        a list and a tuple of objects will look the same if they have the same
-        size. Note that namedtuples with identical name and fields are always
-        considered to have the same shallow structure.
-
-  Returns:
-    True if `nest1` and `nest2` have the same structure.
-
-  Raises:
-    ValueError: If the two structure don't have the same nested structre.
-    TypeError: If the two structure don't have the same sequence type.
-    ValueError: If the two dictionaries don't have the same set of keys.
-  """
-  is_sequence_nest1 = is_sequence(nest1)
-  if is_sequence_nest1 != is_sequence(nest2):
-    raise ValueError(
-        "The two structures don't have the same nested structure.\n\n"
-        "First structure: %s\n\nSecond structure: %s." % (nest1, nest2))
-
-  if not is_sequence_nest1:
-    return  # finished checking
-
-  if check_types:
-    type_nest1 = type(nest1)
-    type_nest2 = type(nest2)
-
-    # Duck-typing means that nest should be fine with two different namedtuples
-    # with identical name and fields.
-    if _is_namedtuple(nest1, True) and _is_namedtuple(nest2, True):
-      if not _same_namedtuples(nest1, nest2):
-        raise TypeError(
-            "The two namedtuples don't have the same sequence type. First "
-            "structure has type %s, while second structure has type %s."
-            % (type_nest1, type_nest2))
-    else:
-      if type_nest1 != type_nest2:
-        raise TypeError(
-            "The two structures don't have the same sequence type. First "
-            "structure has type %s, while second structure has type %s."
-            % (type_nest1, type_nest2))
-
-    if isinstance(nest1, dict):
-      keys1 = set(_six.iterkeys(nest1))
-      keys2 = set(_six.iterkeys(nest2))
-      if keys1 != keys2:
-        raise ValueError(
-            "The two dictionaries don't have the same set of keys. First "
-            "structure has keys {}, while second structure has keys {}."
-            .format(keys1, keys2))
-
-  nest1_as_sequence = [n for n in _yield_value(nest1)]
-  nest2_as_sequence = [n for n in _yield_value(nest2)]
-  for n1, n2 in zip(nest1_as_sequence, nest2_as_sequence):
-    _recursive_assert_same_structure(n1, n2, check_types)
+  return _pywrap_tensorflow.SameNamedtuples(nest1, nest2)
 
 
 def assert_same_structure(nest1, nest2, check_types=True):
@@ -257,14 +180,7 @@ def assert_same_structure(nest1, nest2, check_types=True):
     TypeError: If the two structures differ in the type of sequence in any of
       their substructures. Only possible if `check_types` is `True`.
   """
-  len_nest1 = len(flatten(nest1)) if is_sequence(nest1) else 1
-  len_nest2 = len(flatten(nest2)) if is_sequence(nest2) else 1
-  if len_nest1 != len_nest2:
-    raise ValueError("The two structures don't have the same number of "
-                     "elements.\n\nFirst structure (%i elements): %s\n\n"
-                     "Second structure (%i elements): %s"
-                     % (len_nest1, nest1, len_nest2, nest2))
-  _recursive_assert_same_structure(nest1, nest2, check_types)
+  _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types)
 
 
 def flatten_dict_items(dictionary):
index 4439d62..2f12b25 100644 (file)
@@ -19,11 +19,14 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+import time
 
 import numpy as np
+from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
@@ -32,6 +35,9 @@ from tensorflow.python.util import nest
 
 class NestTest(test.TestCase):
 
+  PointXY = collections.namedtuple("Point", ["x", "y"])  # pylint: disable=invalid-name
+
+  @test_util.assert_no_new_pyobjects_executing_eagerly
   def testFlattenAndPack(self):
     structure = ((3, 4), 5, (6, 7, (9, 10), 8))
     flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
@@ -39,8 +45,8 @@ class NestTest(test.TestCase):
     self.assertEqual(
         nest.pack_sequence_as(structure, flat), (("a", "b"), "c",
                                                  ("d", "e", ("f", "g"), "h")))
-    point = collections.namedtuple("Point", ["x", "y"])
-    structure = (point(x=4, y=2), ((point(x=1, y=0),),))
+    structure = (NestTest.PointXY(x=4, y=2),
+                 ((NestTest.PointXY(x=1, y=0),),))
     flat = [4, 2, 1, 0]
     self.assertEqual(nest.flatten(structure), flat)
     restructured_from_flat = nest.pack_sequence_as(structure, flat)
@@ -66,6 +72,7 @@ class NestTest(test.TestCase):
     with self.assertRaises(ValueError):
       nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
   def testFlattenDictOrder(self):
     """`flatten` orders dicts by key, including OrderedDicts."""
     ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
@@ -87,12 +94,14 @@ class NestTest(test.TestCase):
         ordered_reconstruction)
     self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
 
+  Abc = collections.namedtuple("A", ("b", "c"))  # pylint: disable=invalid-name
+
+  @test_util.assert_no_new_pyobjects_executing_eagerly
   def testFlattenAndPack_withDicts(self):
     # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
-    named_tuple = collections.namedtuple("A", ("b", "c"))
     mess = [
         "z",
-        named_tuple(3, 4),
+        NestTest.Abc(3, 4),
         {
             "c": [
                 1,
@@ -111,7 +120,7 @@ class NestTest(test.TestCase):
 
     structure_of_mess = [
         14,
-        named_tuple("a", True),
+        NestTest.Abc("a", True),
         {
             "c": [
                 0,
@@ -157,6 +166,7 @@ class NestTest(test.TestCase):
       nest.pack_sequence_as(["hello", "world"],
                             ["and", "goodbye", "again"])
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
   def testIsSequence(self):
     self.assertFalse(nest.is_sequence("1234"))
     self.assertTrue(nest.is_sequence([1, 3, [4, 5]]))
@@ -186,6 +196,23 @@ class NestTest(test.TestCase):
         ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"):
       nest.flatten_dict_items(another_bad_dictionary)
 
+  # pylint does not correctly recognize these as class names and
+  # suggests to use variable style under_score naming.
+  # pylint: disable=invalid-name
+  Named0ab = collections.namedtuple("named_0", ("a", "b"))
+  Named1ab = collections.namedtuple("named_1", ("a", "b"))
+  SameNameab = collections.namedtuple("same_name", ("a", "b"))
+  SameNameab2 = collections.namedtuple("same_name", ("a", "b"))
+  SameNamexy = collections.namedtuple("same_name", ("x", "y"))
+  SameName1xy = collections.namedtuple("same_name_1", ("x", "y"))
+  SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y"))
+  NotSameName = collections.namedtuple("not_same_name", ("a", "b"))
+  # pylint: enable=invalid-name
+
+  class SameNamedType1(SameNameab):
+    pass
+
+  @test_util.assert_no_new_pyobjects_executing_eagerly
   def testAssertSameStructure(self):
     structure1 = (((1, 2), 3), 4, (5, 6))
     structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
@@ -198,23 +225,32 @@ class NestTest(test.TestCase):
 
     with self.assertRaisesRegexp(
         ValueError,
-        ("don't have the same number of elements\\.\n\n"
-         "First structure \\(6 elements\\):.*?"
-         "\n\nSecond structure \\(2 elements\\):")):
+        ("The two structures don't have the same nested structure\\.\n\n"
+         "First structure:.*?\n\n"
+         "Second structure:.*\n\n"
+         "More specifically: Substructure "
+         r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
+         'substructure "type=str str=spam" is not')):
       nest.assert_same_structure(structure1, structure_different_num_elements)
 
     with self.assertRaisesRegexp(
         ValueError,
-        ("don't have the same number of elements\\.\n\n"
-         "First structure \\(2 elements\\):.*?"
-         "\n\nSecond structure \\(1 elements\\):")):
+        ("The two structures don't have the same nested structure\\.\n\n"
+         "First structure:.*?\n\n"
+         "Second structure:.*\n\n"
+         r'More specifically: Substructure "type=list str=\[0, 1\]" '
+         r'is a sequence, while substructure "type=ndarray str=\[0 1\]" '
+         "is not")):
       nest.assert_same_structure([0, 1], np.array([0, 1]))
 
     with self.assertRaisesRegexp(
         ValueError,
-        ("don't have the same number of elements\\.\n\n"
-         "First structure \\(1 elements\\):.*"
-         "\n\nSecond structure \\(2 elements\\):")):
+        ("The two structures don't have the same nested structure\\.\n\n"
+         "First structure:.*?\n\n"
+         "Second structure:.*\n\n"
+         r'More specifically: Substructure "type=list str=\[0, 1\]" '
+         'is a sequence, while substructure "type=int str=0" '
+         "is not")):
       nest.assert_same_structure(0, [0, 1])
 
     self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1])
@@ -225,21 +261,21 @@ class NestTest(test.TestCase):
          "First structure: .*?\n\nSecond structure: ")):
       nest.assert_same_structure(structure1, structure_different_nesting)
 
-    named_type_0 = collections.namedtuple("named_0", ("a", "b"))
-    named_type_1 = collections.namedtuple("named_1", ("a", "b"))
     self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
-                      named_type_0("a", "b"))
+                      NestTest.Named0ab("a", "b"))
 
-    nest.assert_same_structure(named_type_0(3, 4), named_type_0("a", "b"))
+    nest.assert_same_structure(NestTest.Named0ab(3, 4),
+                               NestTest.Named0ab("a", "b"))
 
     self.assertRaises(TypeError, nest.assert_same_structure,
-                      named_type_0(3, 4), named_type_1(3, 4))
+                      NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4))
 
     with self.assertRaisesRegexp(
         ValueError,
         ("don't have the same nested structure\\.\n\n"
          "First structure: .*?\n\nSecond structure: ")):
-      nest.assert_same_structure(named_type_0(3, 4), named_type_0([3], 4))
+      nest.assert_same_structure(NestTest.Named0ab(3, 4),
+                                 NestTest.Named0ab([3], 4))
 
     with self.assertRaisesRegexp(
         ValueError,
@@ -258,36 +294,33 @@ class NestTest(test.TestCase):
                                  "don't have the same set of keys"):
       nest.assert_same_structure({"a": 1}, {"b": 1})
 
-    same_name_type_0 = collections.namedtuple("same_name", ("a", "b"))
-    same_name_type_1 = collections.namedtuple("same_name", ("a", "b"))
-    nest.assert_same_structure(same_name_type_0(0, 1), same_name_type_1(2, 3))
+    nest.assert_same_structure(NestTest.SameNameab(0, 1),
+                               NestTest.SameNameab2(2, 3))
 
     # This assertion is expected to pass: two namedtuples with the same
     # name and field names are considered to be identical.
-    same_name_type_2 = collections.namedtuple("same_name_1", ("x", "y"))
-    same_name_type_3 = collections.namedtuple("same_name_1", ("x", "y"))
     nest.assert_same_structure(
-        same_name_type_0(same_name_type_2(0, 1), 2),
-        same_name_type_1(same_name_type_3(2, 3), 4))
+        NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2),
+        NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4))
 
     expected_message = "The two structures don't have the same.*"
     with self.assertRaisesRegexp(ValueError, expected_message):
-      nest.assert_same_structure(same_name_type_0(0, same_name_type_1(1, 2)),
-                                 same_name_type_1(same_name_type_0(0, 1), 2))
+      nest.assert_same_structure(
+          NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)),
+          NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2))
 
-    same_name_type_1 = collections.namedtuple("not_same_name", ("a", "b"))
     self.assertRaises(TypeError, nest.assert_same_structure,
-                      same_name_type_0(0, 1), same_name_type_1(2, 3))
+                      NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3))
 
-    same_name_type_1 = collections.namedtuple("same_name", ("x", "y"))
     self.assertRaises(TypeError, nest.assert_same_structure,
-                      same_name_type_0(0, 1), same_name_type_1(2, 3))
+                      NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3))
 
-    class SameNamedType1(collections.namedtuple("same_name", ("a", "b"))):
-      pass
     self.assertRaises(TypeError, nest.assert_same_structure,
-                      same_name_type_0(0, 1), SameNamedType1(2, 3))
+                      NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3))
 
+  EmptyNT = collections.namedtuple("empty_nt", "")  # pylint: disable=invalid-name
+
+  @test_util.assert_no_new_pyobjects_executing_eagerly
   def testMapStructure(self):
     structure1 = (((1, 2), 3), 4, (5, 6))
     structure2 = (((7, 8), 9), 10, (11, 12))
@@ -310,9 +343,8 @@ class NestTest(test.TestCase):
     self.assertEqual((), nest.map_structure(lambda x: x + 1, ()))
     self.assertEqual([], nest.map_structure(lambda x: x + 1, []))
     self.assertEqual({}, nest.map_structure(lambda x: x + 1, {}))
-    empty_nt = collections.namedtuple("empty_nt", "")
-    self.assertEqual(empty_nt(), nest.map_structure(lambda x: x + 1,
-                                                    empty_nt()))
+    self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1,
+                                                            NestTest.EmptyNT()))
 
     # This is checking actual equality of types, empty list != empty tuple
     self.assertNotEqual((), nest.map_structure(lambda x: x + 1, []))
@@ -352,10 +384,12 @@ class NestTest(test.TestCase):
     with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
       nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
 
+  ABTuple = collections.namedtuple("ab_tuple", "a, b")  # pylint: disable=invalid-name
+
+  @test_util.assert_no_new_pyobjects_executing_eagerly
   def testMapStructureWithStrings(self):
-    ab_tuple = collections.namedtuple("ab_tuple", "a, b")
-    inp_a = ab_tuple(a="foo", b=("bar", "baz"))
-    inp_b = ab_tuple(a=2, b=(1, 3))
+    inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz"))
+    inp_b = NestTest.ABTuple(a=2, b=(1, 3))
     out = nest.map_structure(lambda string, repeats: string * repeats,
                              inp_a,
                              inp_b)
@@ -363,8 +397,8 @@ class NestTest(test.TestCase):
     self.assertEqual("bar", out.b[0])
     self.assertEqual("bazbazbaz", out.b[1])
 
-    nt = ab_tuple(a=("something", "something_else"),
-                  b="yet another thing")
+    nt = NestTest.ABTuple(a=("something", "something_else"),
+                          b="yet another thing")
     rev_nt = nest.map_structure(lambda x: x[::-1], nt)
     # Check the output is the correct structure, and all strings are reversed.
     nest.assert_same_structure(nt, rev_nt)
@@ -431,10 +465,8 @@ class NestTest(test.TestCase):
 
     # This assertion is expected to pass: two namedtuples with the same
     # name and field names are considered to be identical.
-    same_name_type_0 = collections.namedtuple("same_name", ("a", "b"))
-    same_name_type_1 = collections.namedtuple("same_name", ("a", "b"))
-    inp_shallow = same_name_type_0(1, 2)
-    inp_deep = same_name_type_1(1, [1, 2, 3])
+    inp_shallow = NestTest.SameNameab(1, 2)
+    inp_deep = NestTest.SameNameab2(1, [1, 2, 3])
     nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False)
     nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True)
 
@@ -466,7 +498,7 @@ class NestTest(test.TestCase):
                      [1, {"c": 2}, 3, (4, 5)])
 
     # Namedtuples.
-    ab_tuple = collections.namedtuple("ab_tuple", "a, b")
+    ab_tuple = NestTest.ABTuple
     input_tree = ab_tuple(a=[0, 1], b=2)
     shallow_tree = ab_tuple(a=0, b=1)
     input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
@@ -681,5 +713,31 @@ class NestTest(test.TestCase):
           list(nest.flatten_with_joined_string_paths(inputs)), expected)
 
 
+class NestBenchmark(test.Benchmark):
+
+  def run_and_report(self, s1, s2, name):
+    burn_iter, test_iter = 100, 30000
+
+    for _ in xrange(burn_iter):
+      nest.assert_same_structure(s1, s2)
+
+    t0 = time.time()
+    for _ in xrange(test_iter):
+      nest.assert_same_structure(s1, s2)
+    t1 = time.time()
+
+    self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter,
+                          name=name)
+
+  def benchmark_assert_structure(self):
+    s1 = (((1, 2), 3), 4, (5, 6))
+    s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
+    self.run_and_report(s1, s2, "assert_same_structure_6_elem")
+
+    s1 = (((1, 2), 3), 4, (5, 6)) * 10
+    s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10
+    self.run_and_report(s1, s2, "assert_same_structure_60_elem")
+
+
 if __name__ == "__main__":
   test.main()
index a41fa7d..70aee4a 100644 (file)
@@ -16,6 +16,7 @@ limitations under the License.
 
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/python/lib/core/safe_ptr.h"
 
 namespace tensorflow {
 namespace swig {
@@ -27,6 +28,113 @@ PyObject* CollectionsSequenceType = nullptr;
 
 bool WarnedThatSetIsNotSequence = false;
 
+bool IsString(PyObject* o) {
+  return PyBytes_Check(o) ||
+#if PY_MAJOR_VERSION < 3
+         PyString_Check(o) ||
+#endif
+         PyUnicode_Check(o);
+}
+
+// Equivalent to Python's 'o.__class__.__name__'
+// Note that '__class__' attribute is set only in new-style classes.
+// A lot of tensorflow code uses __class__ without checks, so it seems like
+// we only support new-style classes.
+StringPiece GetClassName(PyObject* o) {
+  // __class__ is equivalent to type() for new style classes.
+  // type() is equivalent to PyObject_Type()
+  // (https://docs.python.org/3.5/c-api/object.html#c.PyObject_Type)
+  // PyObject_Type() is equivalent to o->ob_type except for Py_INCREF, which
+  // we don't need here.
+  PyTypeObject* type = o->ob_type;
+
+  // __name__ is the value of `tp_name` after the last '.'
+  // (https://docs.python.org/2/c-api/typeobj.html#c.PyTypeObject.tp_name)
+  StringPiece name(type->tp_name);
+  size_t pos = name.rfind('.');
+  if (pos != StringPiece::npos) {
+    name.remove_prefix(pos + 1);
+  }
+  return name;
+}
+
+string PyObjectToString(PyObject* o) {
+  if (o == nullptr) {
+    return "<null object>";
+  }
+  PyObject* str = PyObject_Str(o);
+  if (str) {
+#if PY_MAJOR_VERSION < 3
+    string s(PyString_AS_STRING(str));
+#else
+    string s(PyUnicode_AsUTF8(str));
+#endif
+    Py_DECREF(str);
+    return tensorflow::strings::StrCat("type=", GetClassName(o), " str=", s);
+  } else {
+    return "<failed to execute str() on object>";
+  }
+}
+
+// Implements the same idea as tensorflow.util.nest._yield_value
+// During construction we check if the iterable is a dictionary.
+// If so, we construct a sequence from its sorted keys that will be used
+// for iteration.
+// If not, we construct a sequence directly from the iterable.
+// At each step, we get the next element from the sequence and use it
+// either as a key or return it directly.
+//
+// 'iterable' must not be modified while ValIterator is used.
+class ValIterator {
+ public:
+  explicit ValIterator(PyObject* iterable) : dict_(nullptr), index_(0) {
+    if (PyDict_Check(iterable)) {
+      dict_ = iterable;
+      // PyDict_Keys returns a list, which can be used with
+      // PySequence_Fast_GET_ITEM.
+      seq_ = PyDict_Keys(iterable);
+      // Iterate through dictionaries in a deterministic order by sorting the
+      // keys. Notice this means that we ignore the original order of
+      // `OrderedDict` instances. This is intentional, to avoid potential
+      // bugs caused by mixing ordered and plain dicts (e.g., flattening
+      // a dict but using a corresponding `OrderedDict` to pack it back).
+      PyList_Sort(seq_);
+    } else {
+      seq_ = PySequence_Fast(iterable, "");
+    }
+    size_ = PySequence_Fast_GET_SIZE(seq_);
+  }
+
+  ~ValIterator() { Py_DECREF(seq_); }
+
+  // Return a borrowed reference to the next element from iterable.
+  // Return nullptr when iteration is over.
+  PyObject* next() {
+    PyObject* element = nullptr;
+    if (index_ < size_) {
+      // Both PySequence_Fast_GET_ITEM and PyDict_GetItem return borrowed
+      // references.
+      element = PySequence_Fast_GET_ITEM(seq_, index_);
+      ++index_;
+      if (dict_ != nullptr) {
+        element = PyDict_GetItem(dict_, element);
+        if (element == nullptr) {
+          PyErr_SetString(PyExc_RuntimeError,
+                          "Dictionary was modified during iteration over it");
+          return nullptr;
+        }
+      }
+    }
+    return element;
+  }
+
+ private:
+  PyObject* seq_;
+  PyObject* dict_;
+  Py_ssize_t size_;
+  Py_ssize_t index_;
+};
+
 // Returns 1 if `o` is considered a sequence for the purposes of Flatten().
 // Returns 0 otherwise.
 // Returns -1 if an error occurred.
@@ -38,7 +146,7 @@ int IsSequenceHelper(PyObject* o) {
                     "so consider avoiding using them.";
     WarnedThatSetIsNotSequence = true;
   }
-  if (CollectionsSequenceType == nullptr) {
+  if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
     PyErr_SetString(
         PyExc_RuntimeError,
         tensorflow::strings::StrCat(
@@ -49,11 +157,7 @@ int IsSequenceHelper(PyObject* o) {
   }
   int is_instance = PyObject_IsInstance(o, CollectionsSequenceType);
   if (is_instance == -1) return -1;
-  return static_cast<int>(is_instance != 0 && !PyBytes_Check(o) &&
-#if PY_MAJOR_VERSION < 3
-                          !PyString_Check(o) &&
-#endif
-                          !PyUnicode_Check(o));
+  return static_cast<int>(is_instance != 0 && !IsString(o));
 }
 
 bool FlattenHelper(PyObject* nested, PyObject* list) {
@@ -75,12 +179,16 @@ bool FlattenHelper(PyObject* nested, PyObject* list) {
       // while the method is running.
       PyObject* key = PyList_GET_ITEM(keys, i);
       PyObject* val = PyDict_GetItem(nested, key);
-      if (Py_EnterRecursiveCall(" in Flatten")) {
+      if (Py_EnterRecursiveCall(" in flatten")) {
         Py_DECREF(keys);
         return false;
       }
-      FlattenHelper(val, list);
+      const bool success = FlattenHelper(val, list);
       Py_LeaveRecursiveCall();
+      if (!success) {
+        Py_DECREF(keys);
+        return false;
+      }
     }
     Py_DECREF(keys);
     return true;
@@ -90,13 +198,159 @@ bool FlattenHelper(PyObject* nested, PyObject* list) {
   PyObject* item;
   PyObject* iterator = PyObject_GetIter(nested);
   while ((item = PyIter_Next(iterator)) != nullptr) {
-    FlattenHelper(item, list);
+    if (Py_EnterRecursiveCall(" in flatten")) {
+      Py_DECREF(iterator);
+      Py_DECREF(item);
+      return false;
+    }
+    bool success = FlattenHelper(item, list);
+    Py_LeaveRecursiveCall();
+    if (!success) {
+      Py_DECREF(iterator);
+      Py_DECREF(item);
+      return false;
+    }
     Py_DECREF(item);
   }
   Py_DECREF(iterator);
   return true;
 }
 
+// Sets error using keys of 'dict1' and 'dict2'.
+// 'dict1' and 'dict2' are assumed to be Python dictionaries.
+void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
+                           bool* is_type_error) {
+  PyObject* k1 = PyDict_Keys(dict1);
+  PyObject* k2 = PyDict_Keys(dict2);
+  *is_type_error = false;
+  *error_msg = tensorflow::strings::StrCat(
+      "The two dictionaries don't have the same set of keys. "
+      "First structure has keys ",
+      PyObjectToString(k1), ", while second structure has keys ",
+      PyObjectToString(k2));
+  Py_DECREF(k1);
+  Py_DECREF(k2);
+}
+
+// Returns true iff there were no "internal" errors. In other words,
+// errors that has nothing to do with structure checking.
+// If an "internal" error occured, the appropriate Python error will be
+// set and the caller can propage it directly to the user.
+//
+// Both `error_msg` and `is_type_error` must be non-null. `error_msg` must
+// be empty.
+// Leaves `error_msg` empty if structures matched. Else, fills `error_msg`
+// with appropriate error and sets `is_type_error` to true iff
+// the error to be raised should be TypeError.
+bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
+                               string* error_msg, bool* is_type_error) {
+  DCHECK(error_msg);
+  DCHECK(is_type_error);
+  const bool is_seq1 = IsSequence(o1);
+  const bool is_seq2 = IsSequence(o2);
+  if (PyErr_Occurred()) return false;
+  if (is_seq1 != is_seq2) {
+    string seq_str = is_seq1 ? PyObjectToString(o1) : PyObjectToString(o2);
+    string non_seq_str = is_seq1 ? PyObjectToString(o2) : PyObjectToString(o1);
+    *is_type_error = false;
+    *error_msg = tensorflow::strings::StrCat(
+        "Substructure \"", seq_str, "\" is a sequence, while substructure \"",
+        non_seq_str, "\" is not");
+    return true;
+  }
+
+  // Got to scalars, so finished checking. Structures are the same.
+  if (!is_seq1) return true;
+
+  if (check_types) {
+    const PyTypeObject* type1 = o1->ob_type;
+    const PyTypeObject* type2 = o2->ob_type;
+
+    // We treat two different namedtuples with identical name and fields
+    // as having the same type.
+    const PyObject* o1_tuple = IsNamedtuple(o1, true);
+    if (o1_tuple == nullptr) return false;
+    const PyObject* o2_tuple = IsNamedtuple(o2, true);
+    if (o2_tuple == nullptr) {
+      Py_DECREF(o1_tuple);
+      return false;
+    }
+    bool both_tuples = o1_tuple == Py_True && o2_tuple == Py_True;
+    Py_DECREF(o1_tuple);
+    Py_DECREF(o2_tuple);
+
+    if (both_tuples) {
+      const PyObject* same_tuples = SameNamedtuples(o1, o2);
+      if (same_tuples == nullptr) return false;
+      bool not_same_tuples = same_tuples != Py_True;
+      Py_DECREF(same_tuples);
+      if (not_same_tuples) {
+        *is_type_error = true;
+        *error_msg = tensorflow::strings::StrCat(
+            "The two namedtuples don't have the same sequence type. "
+            "First structure ",
+            PyObjectToString(o1), " has type ", type1->tp_name,
+            ", while second structure ", PyObjectToString(o2), " has type ",
+            type2->tp_name);
+        return true;
+      }
+    } else if (type1 != type2) {
+      *is_type_error = true;
+      *error_msg = tensorflow::strings::StrCat(
+          "The two namedtuples don't have the same sequence type. "
+          "First structure ",
+          PyObjectToString(o1), " has type ", type1->tp_name,
+          ", while second structure ", PyObjectToString(o2), " has type ",
+          type2->tp_name);
+      return true;
+    }
+
+    if (PyDict_Check(o1)) {
+      if (PyDict_Size(o1) != PyDict_Size(o2)) {
+        SetDifferentKeysError(o1, o2, error_msg, is_type_error);
+        return true;
+      }
+
+      PyObject* key;
+      Py_ssize_t pos = 0;
+      while (PyDict_Next(o1, &pos, &key, nullptr)) {
+        if (PyDict_GetItem(o2, key) == nullptr) {
+          SetDifferentKeysError(o1, o2, error_msg, is_type_error);
+          return true;
+        }
+      }
+    }
+  }
+
+  ValIterator iter1(o1);
+  ValIterator iter2(o2);
+
+  while (true) {
+    PyObject* v1 = iter1.next();
+    PyObject* v2 = iter2.next();
+    if (v1 != nullptr && v2 != nullptr) {
+      if (Py_EnterRecursiveCall(" in assert_same_structure")) {
+        return false;
+      }
+      bool no_internal_errors = AssertSameStructureHelper(
+          v1, v2, check_types, error_msg, is_type_error);
+      Py_LeaveRecursiveCall();
+      if (!no_internal_errors) return false;
+      if (!error_msg->empty()) return true;
+    } else if (v1 == nullptr && v2 == nullptr) {
+      // Done with all recursive calls. Structure matched.
+      return true;
+    } else {
+      *is_type_error = false;
+      *error_msg = tensorflow::strings::StrCat(
+          "The two structures don't have the same number of elements. ",
+          "First structure: ", PyObjectToString(o1),
+          ". Second structure: ", PyObjectToString(o2));
+      return true;
+    }
+  }
+}
+
 }  // anonymous namespace
 
 void RegisterSequenceClass(PyObject* sequence_class) {
@@ -123,5 +377,107 @@ PyObject* Flatten(PyObject* nested) {
     return nullptr;
   }
 }
+
+PyObject* IsNamedtuple(PyObject* o, bool strict) {
+  // Must be subclass of tuple
+  if (!PyTuple_Check(o)) {
+    Py_RETURN_FALSE;
+  }
+
+  // If strict, o.__class__.__base__ must be tuple
+  if (strict) {
+    PyObject* klass = PyObject_GetAttrString(o, "__class__");
+    if (klass == nullptr) return nullptr;
+    PyObject* base = PyObject_GetAttrString(klass, "__base__");
+    Py_DECREF(klass);
+    if (base == nullptr) return nullptr;
+
+    const PyTypeObject* base_type = reinterpret_cast<PyTypeObject*>(base);
+    // built-in object types are singletons
+    bool tuple_base = base_type == &PyTuple_Type;
+    Py_DECREF(base);
+    if (!tuple_base) {
+      Py_RETURN_FALSE;
+    }
+  }
+
+  if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
+    PyErr_SetString(
+        PyExc_RuntimeError,
+        tensorflow::strings::StrCat(
+            "collections.Sequence type has not been set. "
+            "Please call RegisterSequenceClass before using this module")
+            .c_str());
+    return nullptr;
+  }
+
+  // o must have attribute '_fields' and every element in
+  // '_fields' must be a string.
+  int has_fields = PyObject_HasAttrString(o, "_fields");
+  if (!has_fields) {
+    Py_RETURN_FALSE;
+  }
+
+  Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields"));
+  int is_instance = PyObject_IsInstance(fields.get(), CollectionsSequenceType);
+  if (is_instance == 0) {
+    Py_RETURN_FALSE;
+  } else if (is_instance == -1) {
+    return nullptr;
+  }
+
+  Safe_PyObjectPtr seq = make_safe(PySequence_Fast(fields.get(), ""));
+  const Py_ssize_t s = PySequence_Fast_GET_SIZE(seq.get());
+  for (Py_ssize_t i = 0; i < s; ++i) {
+    // PySequence_Fast_GET_ITEM returns borrowed ref
+    PyObject* elem = PySequence_Fast_GET_ITEM(seq.get(), i);
+    if (!IsString(elem)) {
+      Py_RETURN_FALSE;
+    }
+  }
+
+  Py_RETURN_TRUE;
+}
+
+PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) {
+  PyObject* f1 = PyObject_GetAttrString(o1, "_fields");
+  PyObject* f2 = PyObject_GetAttrString(o2, "_fields");
+  if (f1 == nullptr || f2 == nullptr) {
+    Py_XDECREF(f1);
+    Py_XDECREF(f2);
+    PyErr_SetString(
+        PyExc_RuntimeError,
+        "Expected namedtuple-like objects (that have _fields attr)");
+    return nullptr;
+  }
+
+  if (PyObject_RichCompareBool(f1, f2, Py_NE)) {
+    Py_RETURN_FALSE;
+  }
+
+  if (GetClassName(o1).compare(GetClassName(o2)) == 0) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+}
+
+PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types) {
+  string error_msg;
+  bool is_type_error = false;
+  AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error);
+  if (!error_msg.empty()) {
+    PyErr_SetString(
+        is_type_error ? PyExc_TypeError : PyExc_ValueError,
+        tensorflow::strings::StrCat(
+            "The two structures don't have the same nested structure.\n\n",
+            "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
+            PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
+            .c_str());
+    return nullptr;
+  }
+  Py_RETURN_NONE;
+}
+
 }  // namespace swig
 }  // namespace tensorflow
index 2af71dc..c325baa 100644 (file)
@@ -33,6 +33,57 @@ namespace swig {
 //   dict.
 bool IsSequence(PyObject* o);
 
+// Implements the same interface as tensorflow.util.nest._is_namedtuple
+// Returns Py_True iff `instance` should be considered a `namedtuple`.
+//
+// Args:
+//   instance: An instance of a Python object.
+//   strict: If True, `instance` is considered to be a `namedtuple` only if
+//       it is a "plain" namedtuple. For instance, a class inheriting
+//       from a `namedtuple` will be considered to be a `namedtuple`
+//       iff `strict=False`.
+//
+// Returns:
+//   True if `instance` is a `namedtuple`.
+PyObject* IsNamedtuple(PyObject* o, bool strict);
+
+// Implements the same interface as tensorflow.util.nest._same_namedtuples
+// Returns Py_True iff the two namedtuples have the same name and fields.
+// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have
+// '_fields' attribute).
+PyObject* SameNamedtuples(PyObject* o1, PyObject* o2);
+
+// Asserts that two structures are nested in the same way.
+//
+// Note that namedtuples with identical name and fields are always considered
+// to have the same shallow structure (even with `check_types=True`).
+// For intance, this code will print `True`:
+//
+// ```python
+// def nt(a, b):
+//   return collections.namedtuple('foo', 'a b')(a, b)
+// print(assert_same_structure(nt(0, 1), nt(2, 3)))
+// ```
+//
+// Args:
+//  nest1: an arbitrarily nested structure.
+//  nest2: an arbitrarily nested structure.
+//  check_types: if `true`, types of sequences are checked as
+//      well, including the keys of dictionaries. If set to `false`, for example
+//      a list and a tuple of objects will look the same if they have the same
+//      size. Note that namedtuples with identical name and fields are always
+//      considered to have the same shallow structure.
+//
+// Raises:
+//  ValueError: If the two structures do not have the same number of elements or
+//    if the two structures are not nested in the same way.
+//  TypeError: If the two structures differ in the type of sequence in any of
+//    their substructures. Only possible if `check_types` is `True`.
+//
+// Returns:
+//  Py_None on success, nullptr on error.
+PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types);
+
 // Implements the same interface as tensorflow.util.nest.flatten
 //
 // Returns a flat list from a given nested structure.
index d69084f..b7f201b 100644 (file)
@@ -34,6 +34,15 @@ limitations under the License.
 %unignore tensorflow::swig::IsSequence;
 %noexception tensorflow::swig::IsSequence;
 
+%unignore tensorflow::swig::IsNamedtuple;
+%noexception tensorflow::swig::IsNamedtuple;
+
+%unignore tensorflow::swig::SameNamedtuples;
+%noexception tensorflow::swig::SameNamedtuples;
+
+%unignore tensorflow::swig::AssertSameStructure;
+%noexception tensorflow::swig::AssertSameStructure;
+
 %unignore tensorflow::swig::Flatten;
 %noexception tensorflow::swig::Flatten;