From 127f1b21dccf4eb46b6cf80657dfa340ed1c1ede Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Wed, 20 Dec 2017 23:02:11 -0800 Subject: [PATCH] [tf nest] Add additional key yielder. This is just copying a utility function created by Malcolm Reynolds. PiperOrigin-RevId: 179775504 --- tensorflow/python/util/nest.py | 96 +++++++++++++++++++++++++++-- tensorflow/python/util/nest_test.py | 53 ++++++++++++++++ 2 files changed, 144 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index 5c066e2bef..4ce871de72 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -47,6 +47,15 @@ def _sorted(dict_): raise TypeError("nest only supports dicts with sortable keys.") +def _is_namedtuple(instance): + """Returns True iff `instance` is a `namedtuple`.""" + return ( + isinstance(instance, tuple) and + hasattr(instance, "_fields") and + isinstance(instance._fields, _collections.Sequence) and + all(isinstance(f, _six.string_types) for f in instance._fields)) + + def _sequence_like(instance, args): """Converts the sequence `args` to the same type as `instance`. @@ -66,11 +75,7 @@ def _sequence_like(instance, args): # corresponding `OrderedDict` to pack it back). result = dict(zip(_sorted(instance), args)) return type(instance)((key, result[key]) for key in _six.iterkeys(instance)) - elif (isinstance(instance, tuple) and - hasattr(instance, "_fields") and - isinstance(instance._fields, _collections.Sequence) and - all(isinstance(f, _six.string_types) for f in instance._fields)): - # This is a namedtuple + elif _is_namedtuple(instance): return type(instance)(*args) else: # Not a namedtuple @@ -677,6 +682,85 @@ def get_traverse_shallow_structure(traverse_fn, structure): return _sequence_like(structure, level_traverse) +def yield_flat_paths(nest): + """Yields paths for some nested structure. 
+ + Paths are tuples of objects which can be str-converted, which may include + integers or other types which are used as indices in a dict. + + The flat list will be in the corresponding order as if you called + `nest.flatten` on the structure. This is handy for naming Tensors such + that the TF scope structure matches the tuple structure. + + E.g. if we have a tuple `value = Foo(a=3, b=Bar(c=23, d=42))` + + ```shell + >>> nest.flatten(value) + [3, 23, 42] + >>> list(nest.yield_flat_paths(value)) + [('a',), ('b', 'c'), ('b', 'd')] + ``` + + ```shell + >>> list(nest.yield_flat_paths({'a': [3]})) + [('a', 0)] + >>> list(nest.yield_flat_paths({'a': 3})) + [('a',)] + ``` + + Args: + nest: the value to produce a flattened paths list for. + + Yields: + Tuples containing index or key values which form the path to a specific + leaf value in the nested structure. + """ + + # Leaves (strings included) yield the empty path `()`, so the parent's key or + # index becomes the last path element and no trailing element is added. + if isinstance(nest, dict): + for key in _sorted(nest): + value = nest[key] + for sub_path in yield_flat_paths(value): + yield (key,) + sub_path + elif _is_namedtuple(nest): + for key in nest._fields: + value = getattr(nest, key) + for sub_path in yield_flat_paths(value): + yield (key,) + sub_path + elif isinstance(nest, _six.string_types): + yield () + elif isinstance(nest, _collections.Sequence): + for idx, value in enumerate(nest): + for sub_path in yield_flat_paths(value): + yield (idx,) + sub_path + else: + yield () + + +def flatten_with_joined_string_paths(structure, separator="/"): + """Returns a list of (string path, data element) tuples. + + The order of tuples produced matches that of `nest.flatten`. This allows you + to flatten a nested structure while keeping information about where in the + structure each data element was located. See `nest.yield_flat_paths` + for more information. + + Args: + structure: the nested structure to flatten. 
+ separator: string to separate levels of hierarchy in the results, defaults + to '/'. + + Returns: + A list of (string, data element) tuples. + """ + flat_paths = yield_flat_paths(structure) + def stringify_and_join(path_elements): + return separator.join(str(path_element) for path_element in path_elements) + flat_string_paths = [stringify_and_join(path) for path in flat_paths] + return list(zip(flat_string_paths, flatten(structure))) + + _pywrap_tensorflow.RegisterSequenceClass(_collections.Sequence) @@ -691,6 +775,8 @@ _allowed_symbols = [ "flatten_up_to", "map_structure_up_to", "get_traverse_shallow_structure", + "yield_flat_paths", + "flatten_with_joined_string_paths", ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py index 3d9e9f9684..4906649f01 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -584,6 +584,59 @@ class NestTest(test.TestCase): TypeError, "didn't return a depth=1 structure of bools"): nest.get_traverse_shallow_structure(lambda _: [1], [1]) + def testYieldFlatStringPaths(self): + for inputs_expected in ({"inputs": [], "expected": []}, + {"inputs": 3, "expected": [()]}, + {"inputs": [3], "expected": [(0,)]}, + {"inputs": {"a": 3}, "expected": [("a",)]}, + {"inputs": {"a": {"b": 4}}, + "expected": [("a", "b")]}, + {"inputs": [{"a": 2}], "expected": [(0, "a")]}, + {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]}, + {"inputs": [{"a": [(23, 42)]}], + "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]}, + {"inputs": [{"a": ([23], 42)}], + "expected": [(0, "a", 0, 0), (0, "a", 1)]}, + {"inputs": {"a": {"a": 2}, "c": [[[4]]]}, + "expected": [("a", "a"), ("c", 0, 0, 0)]}, + {"inputs": {"0": [{"1": 23}]}, + "expected": [("0", 0, "1")]}): + inputs = inputs_expected["inputs"] + expected = inputs_expected["expected"] + self.assertEqual(list(nest.yield_flat_paths(inputs)), expected) + + def testFlattenWithStringPaths(self): + for 
inputs_expected in ( + {"inputs": [], "expected": []}, + {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]}, + {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}): + inputs = inputs_expected["inputs"] + expected = inputs_expected["expected"] + self.assertEqual( + nest.flatten_with_joined_string_paths(inputs, separator="/"), + expected) + + # Need a separate test for namedtuple as we can't declare tuple definitions + # in the @parameterized arguments. + def testFlattenNamedTuple(self): + # pylint: disable=invalid-name + Foo = collections.namedtuple("Foo", ["a", "b"]) + Bar = collections.namedtuple("Bar", ["c", "d"]) + # pylint: enable=invalid-name + test_cases = [ + (Foo(a=3, b=Bar(c=23, d=42)), + [("a", 3), ("b/c", 23), ("b/d", 42)]), + (Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")), + [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]), + (Bar(c=42, d=43), + [("c", 42), ("d", 43)]), + (Bar(c=[42], d=43), + [("c/0", 42), ("d", 43)]), + ] + for inputs, expected in test_cases: + self.assertEqual( + list(nest.flatten_with_joined_string_paths(inputs)), expected) + if __name__ == "__main__": test.main() -- 2.34.1