From 17dfe3ed7db7fb4d41f8933adead4737c30a92c9 Mon Sep 17 00:00:00 2001 From: Igor Ganichev Date: Wed, 28 Mar 2018 18:26:30 -0700 Subject: [PATCH] Implement assert_same_structure in C++ Also implements helper functions nest._is_namedtuple nest._same_namedtuple. Also, fix a bug in FlattenHelper where error from recursive calls were not propagated up immediately. This change implements a good chunk of machinery that will allow us to move map_structure to C++. Before: entry { name: "NestBenchmark.assert_same_structure_6_elem" iters: 30000 wall_time: 4.79532718658e-05 } entry { name: "NestBenchmark.assert_same_structure_60_elem" iters: 30000 wall_time: 0.000403008667628 } After: entry { name: "NestBenchmark.assert_same_structure_6_elem" iters: 30000 wall_time: 1.65301720301e-05 } entry { name: "NestBenchmark.assert_same_structure_60_elem" iters: 30000 wall_time: 0.000147621099154 } PiperOrigin-RevId: 190869007 --- tensorflow/python/BUILD | 1 + tensorflow/python/framework/test_util.py | 8 +- .../python/kernel_tests/functional_ops_test.py | 4 +- tensorflow/python/util/nest.py | 90 +---- tensorflow/python/util/nest_test.py | 156 ++++++--- tensorflow/python/util/util.cc | 374 ++++++++++++++++++++- tensorflow/python/util/util.h | 51 +++ tensorflow/python/util/util.i | 9 + 8 files changed, 545 insertions(+), 148 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 4f61c01..09c1965 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -298,6 +298,7 @@ cc_library( srcs = ["util/util.cc"], hdrs = ["util/util.h"], deps = [ + ":safe_ptr", "//tensorflow/core:framework", "//tensorflow/core:lib", "//util/python:python_headers", diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 4192a27..bf00fa6 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -487,7 +487,13 @@ def assert_no_new_pyobjects_executing_eagerly(f): gc.collect() # There should be no new Python objects hanging around. new_count = len(gc.get_objects()) - self.assertEqual(previous_count, new_count) + # In some cases (specifacally on MacOS), new_count is somehow + # smaller than previous_count. + # Using plain assert because not all classes using this decorator + # have assertLessEqual + assert new_count <= previous_count, ( + "new_count(%d) is not less than or equal to previous_count(%d)" % ( + new_count, previous_count)) gc.enable() return decorator diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index f5717a5..1301ef9 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -229,7 +229,7 @@ class FunctionalOpsTest(test.TestCase): with self.test_session(): nums = np.array([1, 2, 3, 4, 5, 6]) with self.assertRaisesRegexp( - TypeError, r"two structures don't have the same sequence type."): + TypeError, r"two structures don't have the same nested structure"): # lambda emits tuple, but dtype is a list functional_ops.map_fn( lambda x: ((x + 3) * 2, -(x + 3) * 2), @@ -316,7 +316,7 @@ class FunctionalOpsTest(test.TestCase): initializer = np.array(1.0) # Multiply a * 1 each time with self.assertRaisesRegexp( - ValueError, "two structures don't have the same number of elements"): + ValueError, "two structures don't have the same nested structure"): functional_ops.scan(lambda a, x: (a, -a), elems, initializer) def testScan_Scoped(self): diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index 23c2c48..5622431 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -60,15 +60,7 @@ def _is_namedtuple(instance, strict=False): Returns: True if `instance` is a `namedtuple`. """ - # Attemp to limit the test to plain namedtuple (not stuff inheriting from it). - if not isinstance(instance, tuple): - return False - if strict and instance.__class__.__base__ != tuple: - return False - return ( - hasattr(instance, "_fields") and - isinstance(instance._fields, _collections.Sequence) and - all(isinstance(f, _six.string_types) for f in instance._fields)) + return _pywrap_tensorflow.IsNamedtuple(instance, strict) def _sequence_like(instance, args): @@ -157,76 +149,7 @@ def flatten(nest): def _same_namedtuples(nest1, nest2): """Returns True if the two namedtuples have the same name and fields.""" - if nest1._fields != nest2._fields: - return False - if nest1.__class__.__name__ != nest2.__class__.__name__: - return False - return True - - -def _recursive_assert_same_structure(nest1, nest2, check_types): - """Helper function for `assert_same_structure`. - - See `assert_same_structure` for further information about namedtuples. - - Args: - nest1: An arbitrarily nested structure. - nest2: An arbitrarily nested structure. - check_types: If `True` (default) types of sequences are checked as - well, including the keys of dictionaries. If set to `False`, for example - a list and a tuple of objects will look the same if they have the same - size. Note that namedtuples with identical name and fields are always - considered to have the same shallow structure. - - Returns: - True if `nest1` and `nest2` have the same structure. - - Raises: - ValueError: If the two structure don't have the same nested structre. - TypeError: If the two structure don't have the same sequence type. - ValueError: If the two dictionaries don't have the same set of keys. - """ - is_sequence_nest1 = is_sequence(nest1) - if is_sequence_nest1 != is_sequence(nest2): - raise ValueError( - "The two structures don't have the same nested structure.\n\n" - "First structure: %s\n\nSecond structure: %s." % (nest1, nest2)) - - if not is_sequence_nest1: - return # finished checking - - if check_types: - type_nest1 = type(nest1) - type_nest2 = type(nest2) - - # Duck-typing means that nest should be fine with two different namedtuples - # with identical name and fields. - if _is_namedtuple(nest1, True) and _is_namedtuple(nest2, True): - if not _same_namedtuples(nest1, nest2): - raise TypeError( - "The two namedtuples don't have the same sequence type. First " - "structure has type %s, while second structure has type %s." - % (type_nest1, type_nest2)) - else: - if type_nest1 != type_nest2: - raise TypeError( - "The two structures don't have the same sequence type. First " - "structure has type %s, while second structure has type %s." - % (type_nest1, type_nest2)) - - if isinstance(nest1, dict): - keys1 = set(_six.iterkeys(nest1)) - keys2 = set(_six.iterkeys(nest2)) - if keys1 != keys2: - raise ValueError( - "The two dictionaries don't have the same set of keys. First " - "structure has keys {}, while second structure has keys {}." - .format(keys1, keys2)) - - nest1_as_sequence = [n for n in _yield_value(nest1)] - nest2_as_sequence = [n for n in _yield_value(nest2)] - for n1, n2 in zip(nest1_as_sequence, nest2_as_sequence): - _recursive_assert_same_structure(n1, n2, check_types) + return _pywrap_tensorflow.SameNamedtuples(nest1, nest2) def assert_same_structure(nest1, nest2, check_types=True): @@ -257,14 +180,7 @@ def assert_same_structure(nest1, nest2, check_types=True): TypeError: If the two structures differ in the type of sequence in any of their substructures. Only possible if `check_types` is `True`. """ - len_nest1 = len(flatten(nest1)) if is_sequence(nest1) else 1 - len_nest2 = len(flatten(nest2)) if is_sequence(nest2) else 1 - if len_nest1 != len_nest2: - raise ValueError("The two structures don't have the same number of " - "elements.\n\nFirst structure (%i elements): %s\n\n" - "Second structure (%i elements): %s" - % (len_nest1, nest1, len_nest2, nest2)) - _recursive_assert_same_structure(nest1, nest2, check_types) + _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types) def flatten_dict_items(dictionary): diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py index 4439d62..2f12b25 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -19,11 +19,14 @@ from __future__ import division from __future__ import print_function import collections +import time import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -32,6 +35,9 @@ from tensorflow.python.util import nest class NestTest(test.TestCase): + PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name + + @test_util.assert_no_new_pyobjects_executing_eagerly def testFlattenAndPack(self): structure = ((3, 4), 5, (6, 7, (9, 10), 8)) flat = ["a", "b", "c", "d", "e", "f", "g", "h"] @@ -39,8 +45,8 @@ class NestTest(test.TestCase): self.assertEqual( nest.pack_sequence_as(structure, flat), (("a", "b"), "c", ("d", "e", ("f", "g"), "h"))) - point = collections.namedtuple("Point", ["x", "y"]) - structure = (point(x=4, y=2), ((point(x=1, y=0),),)) + structure = (NestTest.PointXY(x=4, y=2), + ((NestTest.PointXY(x=1, y=0),),)) flat = [4, 2, 1, 0] self.assertEqual(nest.flatten(structure), flat) restructured_from_flat = nest.pack_sequence_as(structure, flat) @@ -66,6 +72,7 @@ class NestTest(test.TestCase): with self.assertRaises(ValueError): nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) + @test_util.assert_no_new_pyobjects_executing_eagerly def testFlattenDictOrder(self): """`flatten` orders dicts by key, including OrderedDicts.""" ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) @@ -87,12 +94,14 @@ class NestTest(test.TestCase): ordered_reconstruction) self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) + Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name + + @test_util.assert_no_new_pyobjects_executing_eagerly def testFlattenAndPack_withDicts(self): # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. - named_tuple = collections.namedtuple("A", ("b", "c")) mess = [ "z", - named_tuple(3, 4), + NestTest.Abc(3, 4), { "c": [ 1, @@ -111,7 +120,7 @@ class NestTest(test.TestCase): structure_of_mess = [ 14, - named_tuple("a", True), + NestTest.Abc("a", True), { "c": [ 0, @@ -157,6 +166,7 @@ class NestTest(test.TestCase): nest.pack_sequence_as(["hello", "world"], ["and", "goodbye", "again"]) + @test_util.assert_no_new_pyobjects_executing_eagerly def testIsSequence(self): self.assertFalse(nest.is_sequence("1234")) self.assertTrue(nest.is_sequence([1, 3, [4, 5]])) @@ -186,6 +196,23 @@ class NestTest(test.TestCase): ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): nest.flatten_dict_items(another_bad_dictionary) + # pylint does not correctly recognize these as class names and + # suggests to use variable style under_score naming. + # pylint: disable=invalid-name + Named0ab = collections.namedtuple("named_0", ("a", "b")) + Named1ab = collections.namedtuple("named_1", ("a", "b")) + SameNameab = collections.namedtuple("same_name", ("a", "b")) + SameNameab2 = collections.namedtuple("same_name", ("a", "b")) + SameNamexy = collections.namedtuple("same_name", ("x", "y")) + SameName1xy = collections.namedtuple("same_name_1", ("x", "y")) + SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y")) + NotSameName = collections.namedtuple("not_same_name", ("a", "b")) + # pylint: enable=invalid-name + + class SameNamedType1(SameNameab): + pass + + @test_util.assert_no_new_pyobjects_executing_eagerly def testAssertSameStructure(self): structure1 = (((1, 2), 3), 4, (5, 6)) structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) @@ -198,23 +225,32 @@ class NestTest(test.TestCase): with self.assertRaisesRegexp( ValueError, - ("don't have the same number of elements\\.\n\n" - "First structure \\(6 elements\\):.*?" - "\n\nSecond structure \\(2 elements\\):")): + ("The two structures don't have the same nested structure\\.\n\n" + "First structure:.*?\n\n" + "Second structure:.*\n\n" + "More specifically: Substructure " + r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ' + 'substructure "type=str str=spam" is not')): nest.assert_same_structure(structure1, structure_different_num_elements) with self.assertRaisesRegexp( ValueError, - ("don't have the same number of elements\\.\n\n" - "First structure \\(2 elements\\):.*?" - "\n\nSecond structure \\(1 elements\\):")): + ("The two structures don't have the same nested structure\\.\n\n" + "First structure:.*?\n\n" + "Second structure:.*\n\n" + r'More specifically: Substructure "type=list str=\[0, 1\]" ' + r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ' + "is not")): nest.assert_same_structure([0, 1], np.array([0, 1])) with self.assertRaisesRegexp( ValueError, - ("don't have the same number of elements\\.\n\n" - "First structure \\(1 elements\\):.*" - "\n\nSecond structure \\(2 elements\\):")): + ("The two structures don't have the same nested structure\\.\n\n" + "First structure:.*?\n\n" + "Second structure:.*\n\n" + r'More specifically: Substructure "type=list str=\[0, 1\]" ' + 'is a sequence, while substructure "type=int str=0" ' + "is not")): nest.assert_same_structure(0, [0, 1]) self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1]) @@ -225,21 +261,21 @@ class NestTest(test.TestCase): "First structure: .*?\n\nSecond structure: ")): nest.assert_same_structure(structure1, structure_different_nesting) - named_type_0 = collections.namedtuple("named_0", ("a", "b")) - named_type_1 = collections.namedtuple("named_1", ("a", "b")) self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), - named_type_0("a", "b")) + NestTest.Named0ab("a", "b")) - nest.assert_same_structure(named_type_0(3, 4), named_type_0("a", "b")) + nest.assert_same_structure(NestTest.Named0ab(3, 4), + NestTest.Named0ab("a", "b")) self.assertRaises(TypeError, nest.assert_same_structure, - named_type_0(3, 4), named_type_1(3, 4)) + NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4)) with self.assertRaisesRegexp( ValueError, ("don't have the same nested structure\\.\n\n" "First structure: .*?\n\nSecond structure: ")): - nest.assert_same_structure(named_type_0(3, 4), named_type_0([3], 4)) + nest.assert_same_structure(NestTest.Named0ab(3, 4), + NestTest.Named0ab([3], 4)) with self.assertRaisesRegexp( ValueError, @@ -258,36 +294,33 @@ class NestTest(test.TestCase): "don't have the same set of keys"): nest.assert_same_structure({"a": 1}, {"b": 1}) - same_name_type_0 = collections.namedtuple("same_name", ("a", "b")) - same_name_type_1 = collections.namedtuple("same_name", ("a", "b")) - nest.assert_same_structure(same_name_type_0(0, 1), same_name_type_1(2, 3)) + nest.assert_same_structure(NestTest.SameNameab(0, 1), + NestTest.SameNameab2(2, 3)) # This assertion is expected to pass: two namedtuples with the same # name and field names are considered to be identical. - same_name_type_2 = collections.namedtuple("same_name_1", ("x", "y")) - same_name_type_3 = collections.namedtuple("same_name_1", ("x", "y")) nest.assert_same_structure( - same_name_type_0(same_name_type_2(0, 1), 2), - same_name_type_1(same_name_type_3(2, 3), 4)) + NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2), + NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)) expected_message = "The two structures don't have the same.*" with self.assertRaisesRegexp(ValueError, expected_message): - nest.assert_same_structure(same_name_type_0(0, same_name_type_1(1, 2)), - same_name_type_1(same_name_type_0(0, 1), 2)) + nest.assert_same_structure( + NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)), + NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)) - same_name_type_1 = collections.namedtuple("not_same_name", ("a", "b")) self.assertRaises(TypeError, nest.assert_same_structure, - same_name_type_0(0, 1), same_name_type_1(2, 3)) + NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3)) - same_name_type_1 = collections.namedtuple("same_name", ("x", "y")) self.assertRaises(TypeError, nest.assert_same_structure, - same_name_type_0(0, 1), same_name_type_1(2, 3)) + NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3)) - class SameNamedType1(collections.namedtuple("same_name", ("a", "b"))): - pass self.assertRaises(TypeError, nest.assert_same_structure, - same_name_type_0(0, 1), SameNamedType1(2, 3)) + NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3)) + EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name + + @test_util.assert_no_new_pyobjects_executing_eagerly def testMapStructure(self): structure1 = (((1, 2), 3), 4, (5, 6)) structure2 = (((7, 8), 9), 10, (11, 12)) @@ -310,9 +343,8 @@ class NestTest(test.TestCase): self.assertEqual((), nest.map_structure(lambda x: x + 1, ())) self.assertEqual([], nest.map_structure(lambda x: x + 1, [])) self.assertEqual({}, nest.map_structure(lambda x: x + 1, {})) - empty_nt = collections.namedtuple("empty_nt", "") - self.assertEqual(empty_nt(), nest.map_structure(lambda x: x + 1, - empty_nt())) + self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1, + NestTest.EmptyNT())) # This is checking actual equality of types, empty list != empty tuple self.assertNotEqual((), nest.map_structure(lambda x: x + 1, [])) @@ -352,10 +384,12 @@ class NestTest(test.TestCase): with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): nest.map_structure(lambda x: None, structure1, check_types=False, foo="a") + ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name + + @test_util.assert_no_new_pyobjects_executing_eagerly def testMapStructureWithStrings(self): - ab_tuple = collections.namedtuple("ab_tuple", "a, b") - inp_a = ab_tuple(a="foo", b=("bar", "baz")) - inp_b = ab_tuple(a=2, b=(1, 3)) + inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz")) + inp_b = NestTest.ABTuple(a=2, b=(1, 3)) out = nest.map_structure(lambda string, repeats: string * repeats, inp_a, inp_b) @@ -363,8 +397,8 @@ class NestTest(test.TestCase): self.assertEqual("bar", out.b[0]) self.assertEqual("bazbazbaz", out.b[1]) - nt = ab_tuple(a=("something", "something_else"), - b="yet another thing") + nt = NestTest.ABTuple(a=("something", "something_else"), + b="yet another thing") rev_nt = nest.map_structure(lambda x: x[::-1], nt) # Check the output is the correct structure, and all strings are reversed. nest.assert_same_structure(nt, rev_nt) @@ -431,10 +465,8 @@ class NestTest(test.TestCase): # This assertion is expected to pass: two namedtuples with the same # name and field names are considered to be identical. - same_name_type_0 = collections.namedtuple("same_name", ("a", "b")) - same_name_type_1 = collections.namedtuple("same_name", ("a", "b")) - inp_shallow = same_name_type_0(1, 2) - inp_deep = same_name_type_1(1, [1, 2, 3]) + inp_shallow = NestTest.SameNameab(1, 2) + inp_deep = NestTest.SameNameab2(1, [1, 2, 3]) nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) @@ -466,7 +498,7 @@ class NestTest(test.TestCase): [1, {"c": 2}, 3, (4, 5)]) # Namedtuples. - ab_tuple = collections.namedtuple("ab_tuple", "a, b") + ab_tuple = NestTest.ABTuple input_tree = ab_tuple(a=[0, 1], b=2) shallow_tree = ab_tuple(a=0, b=1) input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, @@ -681,5 +713,31 @@ class NestTest(test.TestCase): list(nest.flatten_with_joined_string_paths(inputs)), expected) +class NestBenchmark(test.Benchmark): + + def run_and_report(self, s1, s2, name): + burn_iter, test_iter = 100, 30000 + + for _ in xrange(burn_iter): + nest.assert_same_structure(s1, s2) + + t0 = time.time() + for _ in xrange(test_iter): + nest.assert_same_structure(s1, s2) + t1 = time.time() + + self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter, + name=name) + + def benchmark_assert_structure(self): + s1 = (((1, 2), 3), 4, (5, 6)) + s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) + self.run_and_report(s1, s2, "assert_same_structure_6_elem") + + s1 = (((1, 2), 3), 4, (5, 6)) * 10 + s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10 + self.run_and_report(s1, s2, "assert_same_structure_60_elem") + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index a41fa7d..70aee4a 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/python/lib/core/safe_ptr.h" namespace tensorflow { namespace swig { @@ -27,6 +28,113 @@ PyObject* CollectionsSequenceType = nullptr; bool WarnedThatSetIsNotSequence = false; +bool IsString(PyObject* o) { + return PyBytes_Check(o) || +#if PY_MAJOR_VERSION < 3 + PyString_Check(o) || +#endif + PyUnicode_Check(o); +} + +// Equivalent to Python's 'o.__class__.__name__' +// Note that '__class__' attribute is set only in new-style classes. +// A lot of tensorflow code uses __class__ without checks, so it seems like +// we only support new-style classes. +StringPiece GetClassName(PyObject* o) { + // __class__ is equivalent to type() for new style classes. + // type() is equivalent to PyObject_Type() + // (https://docs.python.org/3.5/c-api/object.html#c.PyObject_Type) + // PyObject_Type() is equivalent to o->ob_type except for Py_INCREF, which + // we don't need here. + PyTypeObject* type = o->ob_type; + + // __name__ is the value of `tp_name` after the last '.' + // (https://docs.python.org/2/c-api/typeobj.html#c.PyTypeObject.tp_name) + StringPiece name(type->tp_name); + size_t pos = name.rfind('.'); + if (pos != StringPiece::npos) { + name.remove_prefix(pos + 1); + } + return name; +} + +string PyObjectToString(PyObject* o) { + if (o == nullptr) { + return ""; + } + PyObject* str = PyObject_Str(o); + if (str) { +#if PY_MAJOR_VERSION < 3 + string s(PyString_AS_STRING(str)); +#else + string s(PyUnicode_AsUTF8(str)); +#endif + Py_DECREF(str); + return tensorflow::strings::StrCat("type=", GetClassName(o), " str=", s); + } else { + return ""; + } +} + +// Implements the same idea as tensorflow.util.nest._yield_value +// During construction we check if the iterable is a dictionary. +// If so, we construct a sequence from its sorted keys that will be used +// for iteration. +// If not, we construct a sequence directly from the iterable. +// At each step, we get the next element from the sequence and use it +// either as a key or return it directly. +// +// 'iterable' must not be modified while ValIterator is used. +class ValIterator { + public: + explicit ValIterator(PyObject* iterable) : dict_(nullptr), index_(0) { + if (PyDict_Check(iterable)) { + dict_ = iterable; + // PyDict_Keys returns a list, which can be used with + // PySequence_Fast_GET_ITEM. + seq_ = PyDict_Keys(iterable); + // Iterate through dictionaries in a deterministic order by sorting the + // keys. Notice this means that we ignore the original order of + // `OrderedDict` instances. This is intentional, to avoid potential + // bugs caused by mixing ordered and plain dicts (e.g., flattening + // a dict but using a corresponding `OrderedDict` to pack it back). + PyList_Sort(seq_); + } else { + seq_ = PySequence_Fast(iterable, ""); + } + size_ = PySequence_Fast_GET_SIZE(seq_); + } + + ~ValIterator() { Py_DECREF(seq_); } + + // Return a borrowed reference to the next element from iterable. + // Return nullptr when iteration is over. + PyObject* next() { + PyObject* element = nullptr; + if (index_ < size_) { + // Both PySequence_Fast_GET_ITEM and PyDict_GetItem return borrowed + // references. + element = PySequence_Fast_GET_ITEM(seq_, index_); + ++index_; + if (dict_ != nullptr) { + element = PyDict_GetItem(dict_, element); + if (element == nullptr) { + PyErr_SetString(PyExc_RuntimeError, + "Dictionary was modified during iteration over it"); + return nullptr; + } + } + } + return element; + } + + private: + PyObject* seq_; + PyObject* dict_; + Py_ssize_t size_; + Py_ssize_t index_; +}; + // Returns 1 if `o` is considered a sequence for the purposes of Flatten(). // Returns 0 otherwise. // Returns -1 if an error occurred. @@ -38,7 +146,7 @@ int IsSequenceHelper(PyObject* o) { "so consider avoiding using them."; WarnedThatSetIsNotSequence = true; } - if (CollectionsSequenceType == nullptr) { + if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) { PyErr_SetString( PyExc_RuntimeError, tensorflow::strings::StrCat( @@ -49,11 +157,7 @@ int IsSequenceHelper(PyObject* o) { } int is_instance = PyObject_IsInstance(o, CollectionsSequenceType); if (is_instance == -1) return -1; - return static_cast(is_instance != 0 && !PyBytes_Check(o) && -#if PY_MAJOR_VERSION < 3 - !PyString_Check(o) && -#endif - !PyUnicode_Check(o)); + return static_cast(is_instance != 0 && !IsString(o)); } bool FlattenHelper(PyObject* nested, PyObject* list) { @@ -75,12 +179,16 @@ bool FlattenHelper(PyObject* nested, PyObject* list) { // while the method is running. PyObject* key = PyList_GET_ITEM(keys, i); PyObject* val = PyDict_GetItem(nested, key); - if (Py_EnterRecursiveCall(" in Flatten")) { + if (Py_EnterRecursiveCall(" in flatten")) { Py_DECREF(keys); return false; } - FlattenHelper(val, list); + const bool success = FlattenHelper(val, list); Py_LeaveRecursiveCall(); + if (!success) { + Py_DECREF(keys); + return false; + } } Py_DECREF(keys); return true; @@ -90,13 +198,159 @@ bool FlattenHelper(PyObject* nested, PyObject* list) { PyObject* item; PyObject* iterator = PyObject_GetIter(nested); while ((item = PyIter_Next(iterator)) != nullptr) { - FlattenHelper(item, list); + if (Py_EnterRecursiveCall(" in flatten")) { + Py_DECREF(iterator); + Py_DECREF(item); + return false; + } + bool success = FlattenHelper(item, list); + Py_LeaveRecursiveCall(); + if (!success) { + Py_DECREF(iterator); + Py_DECREF(item); + return false; + } Py_DECREF(item); } Py_DECREF(iterator); return true; } +// Sets error using keys of 'dict1' and 'dict2'. +// 'dict1' and 'dict2' are assumed to be Python dictionaries. +void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg, + bool* is_type_error) { + PyObject* k1 = PyDict_Keys(dict1); + PyObject* k2 = PyDict_Keys(dict2); + *is_type_error = false; + *error_msg = tensorflow::strings::StrCat( + "The two dictionaries don't have the same set of keys. " + "First structure has keys ", + PyObjectToString(k1), ", while second structure has keys ", + PyObjectToString(k2)); + Py_DECREF(k1); + Py_DECREF(k2); +} + +// Returns true iff there were no "internal" errors. In other words, +// errors that has nothing to do with structure checking. +// If an "internal" error occured, the appropriate Python error will be +// set and the caller can propage it directly to the user. +// +// Both `error_msg` and `is_type_error` must be non-null. `error_msg` must +// be empty. +// Leaves `error_msg` empty if structures matched. Else, fills `error_msg` +// with appropriate error and sets `is_type_error` to true iff +// the error to be raised should be TypeError. +bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types, + string* error_msg, bool* is_type_error) { + DCHECK(error_msg); + DCHECK(is_type_error); + const bool is_seq1 = IsSequence(o1); + const bool is_seq2 = IsSequence(o2); + if (PyErr_Occurred()) return false; + if (is_seq1 != is_seq2) { + string seq_str = is_seq1 ? PyObjectToString(o1) : PyObjectToString(o2); + string non_seq_str = is_seq1 ? PyObjectToString(o2) : PyObjectToString(o1); + *is_type_error = false; + *error_msg = tensorflow::strings::StrCat( + "Substructure \"", seq_str, "\" is a sequence, while substructure \"", + non_seq_str, "\" is not"); + return true; + } + + // Got to scalars, so finished checking. Structures are the same. + if (!is_seq1) return true; + + if (check_types) { + const PyTypeObject* type1 = o1->ob_type; + const PyTypeObject* type2 = o2->ob_type; + + // We treat two different namedtuples with identical name and fields + // as having the same type. + const PyObject* o1_tuple = IsNamedtuple(o1, true); + if (o1_tuple == nullptr) return false; + const PyObject* o2_tuple = IsNamedtuple(o2, true); + if (o2_tuple == nullptr) { + Py_DECREF(o1_tuple); + return false; + } + bool both_tuples = o1_tuple == Py_True && o2_tuple == Py_True; + Py_DECREF(o1_tuple); + Py_DECREF(o2_tuple); + + if (both_tuples) { + const PyObject* same_tuples = SameNamedtuples(o1, o2); + if (same_tuples == nullptr) return false; + bool not_same_tuples = same_tuples != Py_True; + Py_DECREF(same_tuples); + if (not_same_tuples) { + *is_type_error = true; + *error_msg = tensorflow::strings::StrCat( + "The two namedtuples don't have the same sequence type. " + "First structure ", + PyObjectToString(o1), " has type ", type1->tp_name, + ", while second structure ", PyObjectToString(o2), " has type ", + type2->tp_name); + return true; + } + } else if (type1 != type2) { + *is_type_error = true; + *error_msg = tensorflow::strings::StrCat( + "The two namedtuples don't have the same sequence type. " + "First structure ", + PyObjectToString(o1), " has type ", type1->tp_name, + ", while second structure ", PyObjectToString(o2), " has type ", + type2->tp_name); + return true; + } + + if (PyDict_Check(o1)) { + if (PyDict_Size(o1) != PyDict_Size(o2)) { + SetDifferentKeysError(o1, o2, error_msg, is_type_error); + return true; + } + + PyObject* key; + Py_ssize_t pos = 0; + while (PyDict_Next(o1, &pos, &key, nullptr)) { + if (PyDict_GetItem(o2, key) == nullptr) { + SetDifferentKeysError(o1, o2, error_msg, is_type_error); + return true; + } + } + } + } + + ValIterator iter1(o1); + ValIterator iter2(o2); + + while (true) { + PyObject* v1 = iter1.next(); + PyObject* v2 = iter2.next(); + if (v1 != nullptr && v2 != nullptr) { + if (Py_EnterRecursiveCall(" in assert_same_structure")) { + return false; + } + bool no_internal_errors = AssertSameStructureHelper( + v1, v2, check_types, error_msg, is_type_error); + Py_LeaveRecursiveCall(); + if (!no_internal_errors) return false; + if (!error_msg->empty()) return true; + } else if (v1 == nullptr && v2 == nullptr) { + // Done with all recursive calls. Structure matched. + return true; + } else { + *is_type_error = false; + *error_msg = tensorflow::strings::StrCat( + "The two structures don't have the same number of elements. ", + "First structure: ", PyObjectToString(o1), + ". Second structure: ", PyObjectToString(o2)); + return true; + } + } +} + } // anonymous namespace void RegisterSequenceClass(PyObject* sequence_class) { @@ -123,5 +377,107 @@ PyObject* Flatten(PyObject* nested) { return nullptr; } } + +PyObject* IsNamedtuple(PyObject* o, bool strict) { + // Must be subclass of tuple + if (!PyTuple_Check(o)) { + Py_RETURN_FALSE; + } + + // If strict, o.__class__.__base__ must be tuple + if (strict) { + PyObject* klass = PyObject_GetAttrString(o, "__class__"); + if (klass == nullptr) return nullptr; + PyObject* base = PyObject_GetAttrString(klass, "__base__"); + Py_DECREF(klass); + if (base == nullptr) return nullptr; + + const PyTypeObject* base_type = reinterpret_cast(base); + // built-in object types are singletons + bool tuple_base = base_type == &PyTuple_Type; + Py_DECREF(base); + if (!tuple_base) { + Py_RETURN_FALSE; + } + } + + if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) { + PyErr_SetString( + PyExc_RuntimeError, + tensorflow::strings::StrCat( + "collections.Sequence type has not been set. " + "Please call RegisterSequenceClass before using this module") + .c_str()); + return nullptr; + } + + // o must have attribute '_fields' and every element in + // '_fields' must be a string. + int has_fields = PyObject_HasAttrString(o, "_fields"); + if (!has_fields) { + Py_RETURN_FALSE; + } + + Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields")); + int is_instance = PyObject_IsInstance(fields.get(), CollectionsSequenceType); + if (is_instance == 0) { + Py_RETURN_FALSE; + } else if (is_instance == -1) { + return nullptr; + } + + Safe_PyObjectPtr seq = make_safe(PySequence_Fast(fields.get(), "")); + const Py_ssize_t s = PySequence_Fast_GET_SIZE(seq.get()); + for (Py_ssize_t i = 0; i < s; ++i) { + // PySequence_Fast_GET_ITEM returns borrowed ref + PyObject* elem = PySequence_Fast_GET_ITEM(seq.get(), i); + if (!IsString(elem)) { + Py_RETURN_FALSE; + } + } + + Py_RETURN_TRUE; +} + +PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) { + PyObject* f1 = PyObject_GetAttrString(o1, "_fields"); + PyObject* f2 = PyObject_GetAttrString(o2, "_fields"); + if (f1 == nullptr || f2 == nullptr) { + Py_XDECREF(f1); + Py_XDECREF(f2); + PyErr_SetString( + PyExc_RuntimeError, + "Expected namedtuple-like objects (that have _fields attr)"); + return nullptr; + } + + if (PyObject_RichCompareBool(f1, f2, Py_NE)) { + Py_RETURN_FALSE; + } + + if (GetClassName(o1).compare(GetClassName(o2)) == 0) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + +PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types) { + string error_msg; + bool is_type_error = false; + AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error); + if (!error_msg.empty()) { + PyErr_SetString( + is_type_error ? PyExc_TypeError : PyExc_ValueError, + tensorflow::strings::StrCat( + "The two structures don't have the same nested structure.\n\n", + "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ", + PyObjectToString(o2), "\n\nMore specifically: ", error_msg) + .c_str()); + return nullptr; + } + Py_RETURN_NONE; +} + } // namespace swig } // namespace tensorflow diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h index 2af71dc..c325baa 100644 --- a/tensorflow/python/util/util.h +++ b/tensorflow/python/util/util.h @@ -33,6 +33,57 @@ namespace swig { // dict. bool IsSequence(PyObject* o); +// Implements the same interface as tensorflow.util.nest._is_namedtuple +// Returns Py_True iff `instance` should be considered a `namedtuple`. +// +// Args: +// instance: An instance of a Python object. +// strict: If True, `instance` is considered to be a `namedtuple` only if +// it is a "plain" namedtuple. For instance, a class inheriting +// from a `namedtuple` will be considered to be a `namedtuple` +// iff `strict=False`. +// +// Returns: +// True if `instance` is a `namedtuple`. +PyObject* IsNamedtuple(PyObject* o, bool strict); + +// Implements the same interface as tensorflow.util.nest._same_namedtuples +// Returns Py_True iff the two namedtuples have the same name and fields. +// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have +// '_fields' attribute). +PyObject* SameNamedtuples(PyObject* o1, PyObject* o2); + +// Asserts that two structures are nested in the same way. +// +// Note that namedtuples with identical name and fields are always considered +// to have the same shallow structure (even with `check_types=True`). +// For intance, this code will print `True`: +// +// ```python +// def nt(a, b): +// return collections.namedtuple('foo', 'a b')(a, b) +// print(assert_same_structure(nt(0, 1), nt(2, 3))) +// ``` +// +// Args: +// nest1: an arbitrarily nested structure. +// nest2: an arbitrarily nested structure. +// check_types: if `true`, types of sequences are checked as +// well, including the keys of dictionaries. If set to `false`, for example +// a list and a tuple of objects will look the same if they have the same +// size. Note that namedtuples with identical name and fields are always +// considered to have the same shallow structure. +// +// Raises: +// ValueError: If the two structures do not have the same number of elements or +// if the two structures are not nested in the same way. +// TypeError: If the two structures differ in the type of sequence in any of +// their substructures. Only possible if `check_types` is `True`. +// +// Returns: +// Py_None on success, nullptr on error. +PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types); + // Implements the same interface as tensorflow.util.nest.flatten // // Returns a flat list from a given nested structure. diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i index d69084f..b7f201b 100644 --- a/tensorflow/python/util/util.i +++ b/tensorflow/python/util/util.i @@ -34,6 +34,15 @@ limitations under the License. %unignore tensorflow::swig::IsSequence; %noexception tensorflow::swig::IsSequence; +%unignore tensorflow::swig::IsNamedtuple; +%noexception tensorflow::swig::IsNamedtuple; + +%unignore tensorflow::swig::SameNamedtuples; +%noexception tensorflow::swig::SameNamedtuples; + +%unignore tensorflow::swig::AssertSameStructure; +%noexception tensorflow::swig::AssertSameStructure; + %unignore tensorflow::swig::Flatten; %noexception tensorflow::swig::Flatten; -- 2.7.4