raise TypeError("nest only supports dicts with sortable keys.")
+def _is_namedtuple(instance):
+ """Returns True iff `instance` is a `namedtuple`."""
+ return (
+ isinstance(instance, tuple) and
+ hasattr(instance, "_fields") and
+ isinstance(instance._fields, _collections.Sequence) and
+ all(isinstance(f, _six.string_types) for f in instance._fields))
+
+
def _sequence_like(instance, args):
"""Converts the sequence `args` to the same type as `instance`.
# corresponding `OrderedDict` to pack it back).
result = dict(zip(_sorted(instance), args))
return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
- elif (isinstance(instance, tuple) and
- hasattr(instance, "_fields") and
- isinstance(instance._fields, _collections.Sequence) and
- all(isinstance(f, _six.string_types) for f in instance._fields)):
- # This is a namedtuple
+ elif _is_namedtuple(instance):
return type(instance)(*args)
else:
# Not a namedtuple
return _sequence_like(structure, level_traverse)
+def yield_flat_paths(nest):
+ """Yields paths for some nested structure.
+
+ Paths are lists of objects which can be str-converted, which may include
+ integers or other types which are used as indices in a dict.
+
+ The flat list will be in the corresponding order as if you called
+ `snt.nest.flatten` on the structure. This is handy for naming Tensors such
+ the TF scope structure matches the tuple structure.
+
+ E.g. if we have a tuple `value = Foo(a=3, b=Bar(c=23, d=42))`
+
+ ```shell
+ >>> nest.flatten(value)
+ [3, 23, 42]
+ >>> list(nest.yield_flat_paths(value))
+ [('a',), ('b', 'c'), ('b', 'd')]
+ ```
+
+ ```shell
+ >>> list(nest.yield_flat_paths({'a': [3]}))
+ [('a', 0)]
+ >>> list(nest.yield_flat_paths({'a': 3}))
+ [('a',)]
+ ```
+
+ Args:
+ nest: the value to produce a flattened paths list for.
+
+ Yields:
+ Tuples containing index or key values which form the path to a specific
+ leaf value in the nested structure.
+ """
+
+ # The _maybe_add_final_path_element function is used below in order to avoid
+ # adding trailing slashes when the sub-element recursed into is a leaf.
+ if isinstance(nest, dict):
+ for key in _sorted(nest):
+ value = nest[key]
+ for sub_path in yield_flat_paths(value):
+ yield (key,) + sub_path
+ elif _is_namedtuple(nest):
+ for key in nest._fields:
+ value = getattr(nest, key)
+ for sub_path in yield_flat_paths(value):
+ yield (key,) + sub_path
+ elif isinstance(nest, _six.string_types):
+ yield ()
+ elif isinstance(nest, _collections.Sequence):
+ for idx, value in enumerate(nest):
+ for sub_path in yield_flat_paths(value):
+ yield (idx,) + sub_path
+ else:
+ yield ()
+
+
+def flatten_with_joined_string_paths(structure, separator="/"):
+ """Returns a list of (string path, data element) tuples.
+
+ The order of tuples produced matches that of `nest.flatten`. This allows you
+ to flatten a nested structure while keeping information about where in the
+ structure each data element was located. See `nest.yield_flat_paths`
+ for more information.
+
+ Args:
+ structure: the nested structure to flatten.
+ separator: string to separate levels of hierarchy in the results, defaults
+ to '/'.
+
+ Returns:
+ A list of (string, data element) tuples.
+ """
+ flat_paths = yield_flat_paths(structure)
+ def stringify_and_join(path_elements):
+ return separator.join(str(path_element) for path_element in path_elements)
+ flat_string_paths = [stringify_and_join(path) for path in flat_paths]
+ return list(zip(flat_string_paths, flatten(structure)))
+
+
_pywrap_tensorflow.RegisterSequenceClass(_collections.Sequence)
"flatten_up_to",
"map_structure_up_to",
"get_traverse_shallow_structure",
+ "yield_flat_paths",
+ "flatten_with_joined_string_paths",
]
remove_undocumented(__name__, _allowed_symbols)
TypeError, "didn't return a depth=1 structure of bools"):
nest.get_traverse_shallow_structure(lambda _: [1], [1])
+ def testYieldFlatStringPaths(self):
+ for inputs_expected in ({"inputs": [], "expected": []},
+ {"inputs": 3, "expected": [()]},
+ {"inputs": [3], "expected": [(0,)]},
+ {"inputs": {"a": 3}, "expected": [("a",)]},
+ {"inputs": {"a": {"b": 4}},
+ "expected": [("a", "b")]},
+ {"inputs": [{"a": 2}], "expected": [(0, "a")]},
+ {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]},
+ {"inputs": [{"a": [(23, 42)]}],
+ "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]},
+ {"inputs": [{"a": ([23], 42)}],
+ "expected": [(0, "a", 0, 0), (0, "a", 1)]},
+ {"inputs": {"a": {"a": 2}, "c": [[[4]]]},
+ "expected": [("a", "a"), ("c", 0, 0, 0)]},
+ {"inputs": {"0": [{"1": 23}]},
+ "expected": [("0", 0, "1")]}):
+ inputs = inputs_expected["inputs"]
+ expected = inputs_expected["expected"]
+ self.assertEqual(list(nest.yield_flat_paths(inputs)), expected)
+
+ def testFlattenWithStringPaths(self):
+ for inputs_expected in (
+ {"inputs": [], "expected": []},
+ {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]},
+ {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}):
+ inputs = inputs_expected["inputs"]
+ expected = inputs_expected["expected"]
+ self.assertEqual(
+ nest.flatten_with_joined_string_paths(inputs, separator="/"),
+ expected)
+
+ # Need a separate test for namedtuple as we can't declare tuple definitions
+ # in the @parameterized arguments.
+ def testFlattenNamedTuple(self):
+ # pylint: disable=invalid-name
+ Foo = collections.namedtuple("Foo", ["a", "b"])
+ Bar = collections.namedtuple("Bar", ["c", "d"])
+ # pylint: enable=invalid-name
+ test_cases = [
+ (Foo(a=3, b=Bar(c=23, d=42)),
+ [("a", 3), ("b/c", 23), ("b/d", 42)]),
+ (Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")),
+ [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]),
+ (Bar(c=42, d=43),
+ [("c", 42), ("d", 43)]),
+ (Bar(c=[42], d=43),
+ [("c/0", 42), ("d", 43)]),
+ ]
+ for inputs, expected in test_cases:
+ self.assertEqual(
+ list(nest.flatten_with_joined_string_paths(inputs)), expected)
+
if __name__ == "__main__":
test.main()