srcs = ["util/util.cc"],
hdrs = ["util/util.h"],
deps = [
+ ":safe_ptr",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//util/python:python_headers",
gc.collect()
# There should be no new Python objects hanging around.
new_count = len(gc.get_objects())
- self.assertEqual(previous_count, new_count)
+ # In some cases (specifacally on MacOS), new_count is somehow
+ # smaller than previous_count.
+ # Using plain assert because not all classes using this decorator
+ # have assertLessEqual
+ assert new_count <= previous_count, (
+ "new_count(%d) is not less than or equal to previous_count(%d)" % (
+ new_count, previous_count))
gc.enable()
return decorator
with self.test_session():
nums = np.array([1, 2, 3, 4, 5, 6])
with self.assertRaisesRegexp(
- TypeError, r"two structures don't have the same sequence type."):
+ TypeError, r"two structures don't have the same nested structure"):
# lambda emits tuple, but dtype is a list
functional_ops.map_fn(
lambda x: ((x + 3) * 2, -(x + 3) * 2),
initializer = np.array(1.0)
# Multiply a * 1 each time
with self.assertRaisesRegexp(
- ValueError, "two structures don't have the same number of elements"):
+ ValueError, "two structures don't have the same nested structure"):
functional_ops.scan(lambda a, x: (a, -a), elems, initializer)
def testScan_Scoped(self):
Returns:
True if `instance` is a `namedtuple`.
"""
- # Attemp to limit the test to plain namedtuple (not stuff inheriting from it).
- if not isinstance(instance, tuple):
- return False
- if strict and instance.__class__.__base__ != tuple:
- return False
- return (
- hasattr(instance, "_fields") and
- isinstance(instance._fields, _collections.Sequence) and
- all(isinstance(f, _six.string_types) for f in instance._fields))
+ return _pywrap_tensorflow.IsNamedtuple(instance, strict)
def _sequence_like(instance, args):
def _same_namedtuples(nest1, nest2):
"""Returns True if the two namedtuples have the same name and fields."""
- if nest1._fields != nest2._fields:
- return False
- if nest1.__class__.__name__ != nest2.__class__.__name__:
- return False
- return True
-
-
-def _recursive_assert_same_structure(nest1, nest2, check_types):
- """Helper function for `assert_same_structure`.
-
- See `assert_same_structure` for further information about namedtuples.
-
- Args:
- nest1: An arbitrarily nested structure.
- nest2: An arbitrarily nested structure.
- check_types: If `True` (default) types of sequences are checked as
- well, including the keys of dictionaries. If set to `False`, for example
- a list and a tuple of objects will look the same if they have the same
- size. Note that namedtuples with identical name and fields are always
- considered to have the same shallow structure.
-
- Returns:
- True if `nest1` and `nest2` have the same structure.
-
- Raises:
- ValueError: If the two structure don't have the same nested structre.
- TypeError: If the two structure don't have the same sequence type.
- ValueError: If the two dictionaries don't have the same set of keys.
- """
- is_sequence_nest1 = is_sequence(nest1)
- if is_sequence_nest1 != is_sequence(nest2):
- raise ValueError(
- "The two structures don't have the same nested structure.\n\n"
- "First structure: %s\n\nSecond structure: %s." % (nest1, nest2))
-
- if not is_sequence_nest1:
- return # finished checking
-
- if check_types:
- type_nest1 = type(nest1)
- type_nest2 = type(nest2)
-
- # Duck-typing means that nest should be fine with two different namedtuples
- # with identical name and fields.
- if _is_namedtuple(nest1, True) and _is_namedtuple(nest2, True):
- if not _same_namedtuples(nest1, nest2):
- raise TypeError(
- "The two namedtuples don't have the same sequence type. First "
- "structure has type %s, while second structure has type %s."
- % (type_nest1, type_nest2))
- else:
- if type_nest1 != type_nest2:
- raise TypeError(
- "The two structures don't have the same sequence type. First "
- "structure has type %s, while second structure has type %s."
- % (type_nest1, type_nest2))
-
- if isinstance(nest1, dict):
- keys1 = set(_six.iterkeys(nest1))
- keys2 = set(_six.iterkeys(nest2))
- if keys1 != keys2:
- raise ValueError(
- "The two dictionaries don't have the same set of keys. First "
- "structure has keys {}, while second structure has keys {}."
- .format(keys1, keys2))
-
- nest1_as_sequence = [n for n in _yield_value(nest1)]
- nest2_as_sequence = [n for n in _yield_value(nest2)]
- for n1, n2 in zip(nest1_as_sequence, nest2_as_sequence):
- _recursive_assert_same_structure(n1, n2, check_types)
+ return _pywrap_tensorflow.SameNamedtuples(nest1, nest2)
def assert_same_structure(nest1, nest2, check_types=True):
TypeError: If the two structures differ in the type of sequence in any of
their substructures. Only possible if `check_types` is `True`.
"""
- len_nest1 = len(flatten(nest1)) if is_sequence(nest1) else 1
- len_nest2 = len(flatten(nest2)) if is_sequence(nest2) else 1
- if len_nest1 != len_nest2:
- raise ValueError("The two structures don't have the same number of "
- "elements.\n\nFirst structure (%i elements): %s\n\n"
- "Second structure (%i elements): %s"
- % (len_nest1, nest1, len_nest2, nest2))
- _recursive_assert_same_structure(nest1, nest2, check_types)
+ _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types)
def flatten_dict_items(dictionary):
from __future__ import print_function
import collections
+import time
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class NestTest(test.TestCase):
+ PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenAndPack(self):
structure = ((3, 4), 5, (6, 7, (9, 10), 8))
flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
self.assertEqual(
nest.pack_sequence_as(structure, flat), (("a", "b"), "c",
("d", "e", ("f", "g"), "h")))
- point = collections.namedtuple("Point", ["x", "y"])
- structure = (point(x=4, y=2), ((point(x=1, y=0),),))
+ structure = (NestTest.PointXY(x=4, y=2),
+ ((NestTest.PointXY(x=1, y=0),),))
flat = [4, 2, 1, 0]
self.assertEqual(nest.flatten(structure), flat)
restructured_from_flat = nest.pack_sequence_as(structure, flat)
with self.assertRaises(ValueError):
nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
+ @test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenDictOrder(self):
"""`flatten` orders dicts by key, including OrderedDicts."""
ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
ordered_reconstruction)
self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
+ Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenAndPack_withDicts(self):
# A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
- named_tuple = collections.namedtuple("A", ("b", "c"))
mess = [
"z",
- named_tuple(3, 4),
+ NestTest.Abc(3, 4),
{
"c": [
1,
structure_of_mess = [
14,
- named_tuple("a", True),
+ NestTest.Abc("a", True),
{
"c": [
0,
nest.pack_sequence_as(["hello", "world"],
["and", "goodbye", "again"])
+ @test_util.assert_no_new_pyobjects_executing_eagerly
def testIsSequence(self):
self.assertFalse(nest.is_sequence("1234"))
self.assertTrue(nest.is_sequence([1, 3, [4, 5]]))
ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"):
nest.flatten_dict_items(another_bad_dictionary)
+ # pylint does not correctly recognize these as class names and
+ # suggests to use variable style under_score naming.
+ # pylint: disable=invalid-name
+ Named0ab = collections.namedtuple("named_0", ("a", "b"))
+ Named1ab = collections.namedtuple("named_1", ("a", "b"))
+ SameNameab = collections.namedtuple("same_name", ("a", "b"))
+ SameNameab2 = collections.namedtuple("same_name", ("a", "b"))
+ SameNamexy = collections.namedtuple("same_name", ("x", "y"))
+ SameName1xy = collections.namedtuple("same_name_1", ("x", "y"))
+ SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y"))
+ NotSameName = collections.namedtuple("not_same_name", ("a", "b"))
+ # pylint: enable=invalid-name
+
+ class SameNamedType1(SameNameab):
+ pass
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
def testAssertSameStructure(self):
structure1 = (((1, 2), 3), 4, (5, 6))
structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
with self.assertRaisesRegexp(
ValueError,
- ("don't have the same number of elements\\.\n\n"
- "First structure \\(6 elements\\):.*?"
- "\n\nSecond structure \\(2 elements\\):")):
+ ("The two structures don't have the same nested structure\\.\n\n"
+ "First structure:.*?\n\n"
+ "Second structure:.*\n\n"
+ "More specifically: Substructure "
+ r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
+ 'substructure "type=str str=spam" is not')):
nest.assert_same_structure(structure1, structure_different_num_elements)
with self.assertRaisesRegexp(
ValueError,
- ("don't have the same number of elements\\.\n\n"
- "First structure \\(2 elements\\):.*?"
- "\n\nSecond structure \\(1 elements\\):")):
+ ("The two structures don't have the same nested structure\\.\n\n"
+ "First structure:.*?\n\n"
+ "Second structure:.*\n\n"
+ r'More specifically: Substructure "type=list str=\[0, 1\]" '
+ r'is a sequence, while substructure "type=ndarray str=\[0 1\]" '
+ "is not")):
nest.assert_same_structure([0, 1], np.array([0, 1]))
with self.assertRaisesRegexp(
ValueError,
- ("don't have the same number of elements\\.\n\n"
- "First structure \\(1 elements\\):.*"
- "\n\nSecond structure \\(2 elements\\):")):
+ ("The two structures don't have the same nested structure\\.\n\n"
+ "First structure:.*?\n\n"
+ "Second structure:.*\n\n"
+ r'More specifically: Substructure "type=list str=\[0, 1\]" '
+ 'is a sequence, while substructure "type=int str=0" '
+ "is not")):
nest.assert_same_structure(0, [0, 1])
self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1])
"First structure: .*?\n\nSecond structure: ")):
nest.assert_same_structure(structure1, structure_different_nesting)
- named_type_0 = collections.namedtuple("named_0", ("a", "b"))
- named_type_1 = collections.namedtuple("named_1", ("a", "b"))
self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
- named_type_0("a", "b"))
+ NestTest.Named0ab("a", "b"))
- nest.assert_same_structure(named_type_0(3, 4), named_type_0("a", "b"))
+ nest.assert_same_structure(NestTest.Named0ab(3, 4),
+ NestTest.Named0ab("a", "b"))
self.assertRaises(TypeError, nest.assert_same_structure,
- named_type_0(3, 4), named_type_1(3, 4))
+ NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4))
with self.assertRaisesRegexp(
ValueError,
("don't have the same nested structure\\.\n\n"
"First structure: .*?\n\nSecond structure: ")):
- nest.assert_same_structure(named_type_0(3, 4), named_type_0([3], 4))
+ nest.assert_same_structure(NestTest.Named0ab(3, 4),
+ NestTest.Named0ab([3], 4))
with self.assertRaisesRegexp(
ValueError,
"don't have the same set of keys"):
nest.assert_same_structure({"a": 1}, {"b": 1})
- same_name_type_0 = collections.namedtuple("same_name", ("a", "b"))
- same_name_type_1 = collections.namedtuple("same_name", ("a", "b"))
- nest.assert_same_structure(same_name_type_0(0, 1), same_name_type_1(2, 3))
+ nest.assert_same_structure(NestTest.SameNameab(0, 1),
+ NestTest.SameNameab2(2, 3))
# This assertion is expected to pass: two namedtuples with the same
# name and field names are considered to be identical.
- same_name_type_2 = collections.namedtuple("same_name_1", ("x", "y"))
- same_name_type_3 = collections.namedtuple("same_name_1", ("x", "y"))
nest.assert_same_structure(
- same_name_type_0(same_name_type_2(0, 1), 2),
- same_name_type_1(same_name_type_3(2, 3), 4))
+ NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2),
+ NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4))
expected_message = "The two structures don't have the same.*"
with self.assertRaisesRegexp(ValueError, expected_message):
- nest.assert_same_structure(same_name_type_0(0, same_name_type_1(1, 2)),
- same_name_type_1(same_name_type_0(0, 1), 2))
+ nest.assert_same_structure(
+ NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)),
+ NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2))
- same_name_type_1 = collections.namedtuple("not_same_name", ("a", "b"))
self.assertRaises(TypeError, nest.assert_same_structure,
- same_name_type_0(0, 1), same_name_type_1(2, 3))
+ NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3))
- same_name_type_1 = collections.namedtuple("same_name", ("x", "y"))
self.assertRaises(TypeError, nest.assert_same_structure,
- same_name_type_0(0, 1), same_name_type_1(2, 3))
+ NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3))
- class SameNamedType1(collections.namedtuple("same_name", ("a", "b"))):
- pass
self.assertRaises(TypeError, nest.assert_same_structure,
- same_name_type_0(0, 1), SameNamedType1(2, 3))
+ NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3))
+ EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
def testMapStructure(self):
structure1 = (((1, 2), 3), 4, (5, 6))
structure2 = (((7, 8), 9), 10, (11, 12))
self.assertEqual((), nest.map_structure(lambda x: x + 1, ()))
self.assertEqual([], nest.map_structure(lambda x: x + 1, []))
self.assertEqual({}, nest.map_structure(lambda x: x + 1, {}))
- empty_nt = collections.namedtuple("empty_nt", "")
- self.assertEqual(empty_nt(), nest.map_structure(lambda x: x + 1,
- empty_nt()))
+ self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1,
+ NestTest.EmptyNT()))
# This is checking actual equality of types, empty list != empty tuple
self.assertNotEqual((), nest.map_structure(lambda x: x + 1, []))
with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
+ ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
def testMapStructureWithStrings(self):
- ab_tuple = collections.namedtuple("ab_tuple", "a, b")
- inp_a = ab_tuple(a="foo", b=("bar", "baz"))
- inp_b = ab_tuple(a=2, b=(1, 3))
+ inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz"))
+ inp_b = NestTest.ABTuple(a=2, b=(1, 3))
out = nest.map_structure(lambda string, repeats: string * repeats,
inp_a,
inp_b)
self.assertEqual("bar", out.b[0])
self.assertEqual("bazbazbaz", out.b[1])
- nt = ab_tuple(a=("something", "something_else"),
- b="yet another thing")
+ nt = NestTest.ABTuple(a=("something", "something_else"),
+ b="yet another thing")
rev_nt = nest.map_structure(lambda x: x[::-1], nt)
# Check the output is the correct structure, and all strings are reversed.
nest.assert_same_structure(nt, rev_nt)
# This assertion is expected to pass: two namedtuples with the same
# name and field names are considered to be identical.
- same_name_type_0 = collections.namedtuple("same_name", ("a", "b"))
- same_name_type_1 = collections.namedtuple("same_name", ("a", "b"))
- inp_shallow = same_name_type_0(1, 2)
- inp_deep = same_name_type_1(1, [1, 2, 3])
+ inp_shallow = NestTest.SameNameab(1, 2)
+ inp_deep = NestTest.SameNameab2(1, [1, 2, 3])
nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False)
nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True)
[1, {"c": 2}, 3, (4, 5)])
# Namedtuples.
- ab_tuple = collections.namedtuple("ab_tuple", "a, b")
+ ab_tuple = NestTest.ABTuple
input_tree = ab_tuple(a=[0, 1], b=2)
shallow_tree = ab_tuple(a=0, b=1)
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
list(nest.flatten_with_joined_string_paths(inputs)), expected)
+class NestBenchmark(test.Benchmark):
+
+ def run_and_report(self, s1, s2, name):
+ burn_iter, test_iter = 100, 30000
+
+ for _ in xrange(burn_iter):
+ nest.assert_same_structure(s1, s2)
+
+ t0 = time.time()
+ for _ in xrange(test_iter):
+ nest.assert_same_structure(s1, s2)
+ t1 = time.time()
+
+ self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter,
+ name=name)
+
+ def benchmark_assert_structure(self):
+ s1 = (((1, 2), 3), 4, (5, 6))
+ s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
+ self.run_and_report(s1, s2, "assert_same_structure_6_elem")
+
+ s1 = (((1, 2), 3), 4, (5, 6)) * 10
+ s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10
+ self.run_and_report(s1, s2, "assert_same_structure_60_elem")
+
+
if __name__ == "__main__":
test.main()
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/python/lib/core/safe_ptr.h"
namespace tensorflow {
namespace swig {
bool WarnedThatSetIsNotSequence = false;
+bool IsString(PyObject* o) {
+ return PyBytes_Check(o) ||
+#if PY_MAJOR_VERSION < 3
+ PyString_Check(o) ||
+#endif
+ PyUnicode_Check(o);
+}
+
+// Equivalent to Python's 'o.__class__.__name__'
+// Note that '__class__' attribute is set only in new-style classes.
+// A lot of tensorflow code uses __class__ without checks, so it seems like
+// we only support new-style classes.
+StringPiece GetClassName(PyObject* o) {
+ // __class__ is equivalent to type() for new style classes.
+ // type() is equivalent to PyObject_Type()
+ // (https://docs.python.org/3.5/c-api/object.html#c.PyObject_Type)
+ // PyObject_Type() is equivalent to o->ob_type except for Py_INCREF, which
+ // we don't need here.
+ PyTypeObject* type = o->ob_type;
+
+ // __name__ is the value of `tp_name` after the last '.'
+ // (https://docs.python.org/2/c-api/typeobj.html#c.PyTypeObject.tp_name)
+ StringPiece name(type->tp_name);
+ size_t pos = name.rfind('.');
+ if (pos != StringPiece::npos) {
+ name.remove_prefix(pos + 1);
+ }
+ return name;
+}
+
+string PyObjectToString(PyObject* o) {
+ if (o == nullptr) {
+ return "<null object>";
+ }
+ PyObject* str = PyObject_Str(o);
+ if (str) {
+#if PY_MAJOR_VERSION < 3
+ string s(PyString_AS_STRING(str));
+#else
+ string s(PyUnicode_AsUTF8(str));
+#endif
+ Py_DECREF(str);
+ return tensorflow::strings::StrCat("type=", GetClassName(o), " str=", s);
+ } else {
+ return "<failed to execute str() on object>";
+ }
+}
+
+// Implements the same idea as tensorflow.util.nest._yield_value
+// During construction we check if the iterable is a dictionary.
+// If so, we construct a sequence from its sorted keys that will be used
+// for iteration.
+// If not, we construct a sequence directly from the iterable.
+// At each step, we get the next element from the sequence and use it
+// either as a key or return it directly.
+//
+// 'iterable' must not be modified while ValIterator is used.
+class ValIterator {
+ public:
+ explicit ValIterator(PyObject* iterable) : dict_(nullptr), index_(0) {
+ if (PyDict_Check(iterable)) {
+ dict_ = iterable;
+ // PyDict_Keys returns a list, which can be used with
+ // PySequence_Fast_GET_ITEM.
+ seq_ = PyDict_Keys(iterable);
+ // Iterate through dictionaries in a deterministic order by sorting the
+ // keys. Notice this means that we ignore the original order of
+ // `OrderedDict` instances. This is intentional, to avoid potential
+ // bugs caused by mixing ordered and plain dicts (e.g., flattening
+ // a dict but using a corresponding `OrderedDict` to pack it back).
+ PyList_Sort(seq_);
+ } else {
+ seq_ = PySequence_Fast(iterable, "");
+ }
+ size_ = PySequence_Fast_GET_SIZE(seq_);
+ }
+
+ ~ValIterator() { Py_DECREF(seq_); }
+
+ // Return a borrowed reference to the next element from iterable.
+ // Return nullptr when iteration is over.
+ PyObject* next() {
+ PyObject* element = nullptr;
+ if (index_ < size_) {
+ // Both PySequence_Fast_GET_ITEM and PyDict_GetItem return borrowed
+ // references.
+ element = PySequence_Fast_GET_ITEM(seq_, index_);
+ ++index_;
+ if (dict_ != nullptr) {
+ element = PyDict_GetItem(dict_, element);
+ if (element == nullptr) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Dictionary was modified during iteration over it");
+ return nullptr;
+ }
+ }
+ }
+ return element;
+ }
+
+ private:
+ PyObject* seq_;
+ PyObject* dict_;
+ Py_ssize_t size_;
+ Py_ssize_t index_;
+};
+
// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
// Returns 0 otherwise.
// Returns -1 if an error occurred.
"so consider avoiding using them.";
WarnedThatSetIsNotSequence = true;
}
- if (CollectionsSequenceType == nullptr) {
+ if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
PyErr_SetString(
PyExc_RuntimeError,
tensorflow::strings::StrCat(
}
int is_instance = PyObject_IsInstance(o, CollectionsSequenceType);
if (is_instance == -1) return -1;
- return static_cast<int>(is_instance != 0 && !PyBytes_Check(o) &&
-#if PY_MAJOR_VERSION < 3
- !PyString_Check(o) &&
-#endif
- !PyUnicode_Check(o));
+ return static_cast<int>(is_instance != 0 && !IsString(o));
}
bool FlattenHelper(PyObject* nested, PyObject* list) {
// while the method is running.
PyObject* key = PyList_GET_ITEM(keys, i);
PyObject* val = PyDict_GetItem(nested, key);
- if (Py_EnterRecursiveCall(" in Flatten")) {
+ if (Py_EnterRecursiveCall(" in flatten")) {
Py_DECREF(keys);
return false;
}
- FlattenHelper(val, list);
+ const bool success = FlattenHelper(val, list);
Py_LeaveRecursiveCall();
+ if (!success) {
+ Py_DECREF(keys);
+ return false;
+ }
}
Py_DECREF(keys);
return true;
PyObject* item;
PyObject* iterator = PyObject_GetIter(nested);
while ((item = PyIter_Next(iterator)) != nullptr) {
- FlattenHelper(item, list);
+ if (Py_EnterRecursiveCall(" in flatten")) {
+ Py_DECREF(iterator);
+ Py_DECREF(item);
+ return false;
+ }
+ bool success = FlattenHelper(item, list);
+ Py_LeaveRecursiveCall();
+ if (!success) {
+ Py_DECREF(iterator);
+ Py_DECREF(item);
+ return false;
+ }
Py_DECREF(item);
}
Py_DECREF(iterator);
return true;
}
+// Sets error using keys of 'dict1' and 'dict2'.
+// 'dict1' and 'dict2' are assumed to be Python dictionaries.
+void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
+ bool* is_type_error) {
+ PyObject* k1 = PyDict_Keys(dict1);
+ PyObject* k2 = PyDict_Keys(dict2);
+ *is_type_error = false;
+ *error_msg = tensorflow::strings::StrCat(
+ "The two dictionaries don't have the same set of keys. "
+ "First structure has keys ",
+ PyObjectToString(k1), ", while second structure has keys ",
+ PyObjectToString(k2));
+ Py_DECREF(k1);
+ Py_DECREF(k2);
+}
+
+// Returns true iff there were no "internal" errors. In other words,
+// errors that has nothing to do with structure checking.
+// If an "internal" error occured, the appropriate Python error will be
+// set and the caller can propage it directly to the user.
+//
+// Both `error_msg` and `is_type_error` must be non-null. `error_msg` must
+// be empty.
+// Leaves `error_msg` empty if structures matched. Else, fills `error_msg`
+// with appropriate error and sets `is_type_error` to true iff
+// the error to be raised should be TypeError.
+bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
+ string* error_msg, bool* is_type_error) {
+ DCHECK(error_msg);
+ DCHECK(is_type_error);
+ const bool is_seq1 = IsSequence(o1);
+ const bool is_seq2 = IsSequence(o2);
+ if (PyErr_Occurred()) return false;
+ if (is_seq1 != is_seq2) {
+ string seq_str = is_seq1 ? PyObjectToString(o1) : PyObjectToString(o2);
+ string non_seq_str = is_seq1 ? PyObjectToString(o2) : PyObjectToString(o1);
+ *is_type_error = false;
+ *error_msg = tensorflow::strings::StrCat(
+ "Substructure \"", seq_str, "\" is a sequence, while substructure \"",
+ non_seq_str, "\" is not");
+ return true;
+ }
+
+ // Got to scalars, so finished checking. Structures are the same.
+ if (!is_seq1) return true;
+
+ if (check_types) {
+ const PyTypeObject* type1 = o1->ob_type;
+ const PyTypeObject* type2 = o2->ob_type;
+
+ // We treat two different namedtuples with identical name and fields
+ // as having the same type.
+ const PyObject* o1_tuple = IsNamedtuple(o1, true);
+ if (o1_tuple == nullptr) return false;
+ const PyObject* o2_tuple = IsNamedtuple(o2, true);
+ if (o2_tuple == nullptr) {
+ Py_DECREF(o1_tuple);
+ return false;
+ }
+ bool both_tuples = o1_tuple == Py_True && o2_tuple == Py_True;
+ Py_DECREF(o1_tuple);
+ Py_DECREF(o2_tuple);
+
+ if (both_tuples) {
+ const PyObject* same_tuples = SameNamedtuples(o1, o2);
+ if (same_tuples == nullptr) return false;
+ bool not_same_tuples = same_tuples != Py_True;
+ Py_DECREF(same_tuples);
+ if (not_same_tuples) {
+ *is_type_error = true;
+ *error_msg = tensorflow::strings::StrCat(
+ "The two namedtuples don't have the same sequence type. "
+ "First structure ",
+ PyObjectToString(o1), " has type ", type1->tp_name,
+ ", while second structure ", PyObjectToString(o2), " has type ",
+ type2->tp_name);
+ return true;
+ }
+ } else if (type1 != type2) {
+ *is_type_error = true;
+ *error_msg = tensorflow::strings::StrCat(
+ "The two namedtuples don't have the same sequence type. "
+ "First structure ",
+ PyObjectToString(o1), " has type ", type1->tp_name,
+ ", while second structure ", PyObjectToString(o2), " has type ",
+ type2->tp_name);
+ return true;
+ }
+
+ if (PyDict_Check(o1)) {
+ if (PyDict_Size(o1) != PyDict_Size(o2)) {
+ SetDifferentKeysError(o1, o2, error_msg, is_type_error);
+ return true;
+ }
+
+ PyObject* key;
+ Py_ssize_t pos = 0;
+ while (PyDict_Next(o1, &pos, &key, nullptr)) {
+ if (PyDict_GetItem(o2, key) == nullptr) {
+ SetDifferentKeysError(o1, o2, error_msg, is_type_error);
+ return true;
+ }
+ }
+ }
+ }
+
+ ValIterator iter1(o1);
+ ValIterator iter2(o2);
+
+ while (true) {
+ PyObject* v1 = iter1.next();
+ PyObject* v2 = iter2.next();
+ if (v1 != nullptr && v2 != nullptr) {
+ if (Py_EnterRecursiveCall(" in assert_same_structure")) {
+ return false;
+ }
+ bool no_internal_errors = AssertSameStructureHelper(
+ v1, v2, check_types, error_msg, is_type_error);
+ Py_LeaveRecursiveCall();
+ if (!no_internal_errors) return false;
+ if (!error_msg->empty()) return true;
+ } else if (v1 == nullptr && v2 == nullptr) {
+ // Done with all recursive calls. Structure matched.
+ return true;
+ } else {
+ *is_type_error = false;
+ *error_msg = tensorflow::strings::StrCat(
+ "The two structures don't have the same number of elements. ",
+ "First structure: ", PyObjectToString(o1),
+ ". Second structure: ", PyObjectToString(o2));
+ return true;
+ }
+ }
+}
+
} // anonymous namespace
void RegisterSequenceClass(PyObject* sequence_class) {
return nullptr;
}
}
+
+PyObject* IsNamedtuple(PyObject* o, bool strict) {
+ // Must be subclass of tuple
+ if (!PyTuple_Check(o)) {
+ Py_RETURN_FALSE;
+ }
+
+ // If strict, o.__class__.__base__ must be tuple
+ if (strict) {
+ PyObject* klass = PyObject_GetAttrString(o, "__class__");
+ if (klass == nullptr) return nullptr;
+ PyObject* base = PyObject_GetAttrString(klass, "__base__");
+ Py_DECREF(klass);
+ if (base == nullptr) return nullptr;
+
+ const PyTypeObject* base_type = reinterpret_cast<PyTypeObject*>(base);
+ // built-in object types are singletons
+ bool tuple_base = base_type == &PyTuple_Type;
+ Py_DECREF(base);
+ if (!tuple_base) {
+ Py_RETURN_FALSE;
+ }
+ }
+
+ if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
+ PyErr_SetString(
+ PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "collections.Sequence type has not been set. "
+ "Please call RegisterSequenceClass before using this module")
+ .c_str());
+ return nullptr;
+ }
+
+ // o must have attribute '_fields' and every element in
+ // '_fields' must be a string.
+ int has_fields = PyObject_HasAttrString(o, "_fields");
+ if (!has_fields) {
+ Py_RETURN_FALSE;
+ }
+
+ Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields"));
+ int is_instance = PyObject_IsInstance(fields.get(), CollectionsSequenceType);
+ if (is_instance == 0) {
+ Py_RETURN_FALSE;
+ } else if (is_instance == -1) {
+ return nullptr;
+ }
+
+ Safe_PyObjectPtr seq = make_safe(PySequence_Fast(fields.get(), ""));
+ const Py_ssize_t s = PySequence_Fast_GET_SIZE(seq.get());
+ for (Py_ssize_t i = 0; i < s; ++i) {
+ // PySequence_Fast_GET_ITEM returns borrowed ref
+ PyObject* elem = PySequence_Fast_GET_ITEM(seq.get(), i);
+ if (!IsString(elem)) {
+ Py_RETURN_FALSE;
+ }
+ }
+
+ Py_RETURN_TRUE;
+}
+
+PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) {
+ PyObject* f1 = PyObject_GetAttrString(o1, "_fields");
+ PyObject* f2 = PyObject_GetAttrString(o2, "_fields");
+ if (f1 == nullptr || f2 == nullptr) {
+ Py_XDECREF(f1);
+ Py_XDECREF(f2);
+ PyErr_SetString(
+ PyExc_RuntimeError,
+ "Expected namedtuple-like objects (that have _fields attr)");
+ return nullptr;
+ }
+
+ if (PyObject_RichCompareBool(f1, f2, Py_NE)) {
+ Py_RETURN_FALSE;
+ }
+
+ if (GetClassName(o1).compare(GetClassName(o2)) == 0) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+}
+
+PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types) {
+ string error_msg;
+ bool is_type_error = false;
+ AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error);
+ if (!error_msg.empty()) {
+ PyErr_SetString(
+ is_type_error ? PyExc_TypeError : PyExc_ValueError,
+ tensorflow::strings::StrCat(
+ "The two structures don't have the same nested structure.\n\n",
+ "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
+ PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
+ .c_str());
+ return nullptr;
+ }
+ Py_RETURN_NONE;
+}
+
} // namespace swig
} // namespace tensorflow
// dict.
bool IsSequence(PyObject* o);
+// Implements the same interface as tensorflow.util.nest._is_namedtuple
+// Returns Py_True iff `instance` should be considered a `namedtuple`.
+//
+// Args:
+// instance: An instance of a Python object.
+// strict: If True, `instance` is considered to be a `namedtuple` only if
+// it is a "plain" namedtuple. For instance, a class inheriting
+// from a `namedtuple` will be considered to be a `namedtuple`
+// iff `strict=False`.
+//
+// Returns:
+// True if `instance` is a `namedtuple`.
+PyObject* IsNamedtuple(PyObject* o, bool strict);
+
+// Implements the same interface as tensorflow.util.nest._same_namedtuples
+// Returns Py_True iff the two namedtuples have the same name and fields.
+// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have
+// '_fields' attribute).
+PyObject* SameNamedtuples(PyObject* o1, PyObject* o2);
+
+// Asserts that two structures are nested in the same way.
+//
+// Note that namedtuples with identical name and fields are always considered
+// to have the same shallow structure (even with `check_types=True`).
+// For intance, this code will print `True`:
+//
+// ```python
+// def nt(a, b):
+// return collections.namedtuple('foo', 'a b')(a, b)
+// print(assert_same_structure(nt(0, 1), nt(2, 3)))
+// ```
+//
+// Args:
+// nest1: an arbitrarily nested structure.
+// nest2: an arbitrarily nested structure.
+// check_types: if `true`, types of sequences are checked as
+// well, including the keys of dictionaries. If set to `false`, for example
+// a list and a tuple of objects will look the same if they have the same
+// size. Note that namedtuples with identical name and fields are always
+// considered to have the same shallow structure.
+//
+// Raises:
+// ValueError: If the two structures do not have the same number of elements or
+// if the two structures are not nested in the same way.
+// TypeError: If the two structures differ in the type of sequence in any of
+// their substructures. Only possible if `check_types` is `True`.
+//
+// Returns:
+// Py_None on success, nullptr on error.
+PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types);
+
// Implements the same interface as tensorflow.util.nest.flatten
//
// Returns a flat list from a given nested structure.
%unignore tensorflow::swig::IsSequence;
%noexception tensorflow::swig::IsSequence;
+%unignore tensorflow::swig::IsNamedtuple;
+%noexception tensorflow::swig::IsNamedtuple;
+
+%unignore tensorflow::swig::SameNamedtuples;
+%noexception tensorflow::swig::SameNamedtuples;
+
+%unignore tensorflow::swig::AssertSameStructure;
+%noexception tensorflow::swig::AssertSameStructure;
+
%unignore tensorflow::swig::Flatten;
%noexception tensorflow::swig::Flatten;