shallow_tree: an arbitrarily nested structure.
input_tree: an arbitrarily nested structure.
check_types: if `True` (default) the sequence types of `shallow_tree` and
- `input_tree` have to be the same.
+ `input_tree` have to be the same. Note that even with check_types==True,
+ this function will consider two different namedtuple classes with the same
+ name and _fields attribute to be the same class.
Raises:
TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
"Input has type: %s." % type(input_tree))
if check_types and not isinstance(input_tree, type(shallow_tree)):
- raise TypeError(
- "The two structures don't have the same sequence type. Input "
- "structure has type %s, while shallow structure has type %s."
- % (type(input_tree), type(shallow_tree)))
+ # Duck-typing means that nest should be fine with two different
+ # namedtuples with identical name and fields.
+ shallow_is_namedtuple = _is_namedtuple(shallow_tree, False)
+ input_is_namedtuple = _is_namedtuple(input_tree, False)
+ if shallow_is_namedtuple and input_is_namedtuple:
+ if not _same_namedtuples(shallow_tree, input_tree):
+ raise TypeError(
+ "The two namedtuples don't have the same sequence type. Input "
+ "structure has type %s, while shallow structure has type %s."
+ % (type(input_tree), type(shallow_tree)))
+ else:
+ raise TypeError(
+ "The two structures don't have the same sequence type. Input "
+ "structure has type %s, while shallow structure has type %s."
+ % (type(input_tree), type(shallow_tree)))
if len(input_tree) != len(shallow_tree):
raise ValueError(
inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
nest.assert_shallow_structure(inp_ab, inp_ba)
+ # This assertion is expected to pass: two namedtuples with the same
+ # name and field names are considered to be identical.
+ same_name_type_0 = collections.namedtuple("same_name", ("a", "b"))
+ same_name_type_1 = collections.namedtuple("same_name", ("a", "b"))
+ inp_shallow = same_name_type_0(1, 2)
+ inp_deep = same_name_type_1(1, [1, 2, 3])
+ nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False)
+ nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True)
+
def testFlattenUpTo(self):
# Shallow tree ends at scalar.
input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]