[tf nest] Add additional key yielder.
author: Eugene Brevdo <ebrevdo@google.com>
Thu, 21 Dec 2017 07:02:11 +0000 (23:02 -0800)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Thu, 21 Dec 2017 07:09:08 +0000 (23:09 -0800)
This is just copying a utility function created by Malcolm Reynolds.

PiperOrigin-RevId: 179775504

tensorflow/python/util/nest.py
tensorflow/python/util/nest_test.py

index 5c066e2bef1eb557b81b4996a4848fb18318ab4e..4ce871de72fb43420e25bfa7cd13002b09f83f18 100644 (file)
@@ -47,6 +47,15 @@ def _sorted(dict_):
     raise TypeError("nest only supports dicts with sortable keys.")
 
 
+def _is_namedtuple(instance):
+  """Returns True iff `instance` looks like a `namedtuple` (duck-typed)."""
+  return (  # i.e. a tuple whose `_fields` is a sequence of strings
+      isinstance(instance, tuple) and
+      hasattr(instance, "_fields") and
+      isinstance(instance._fields, _collections.Sequence) and
+      all(isinstance(f, _six.string_types) for f in instance._fields))
+
+
 def _sequence_like(instance, args):
   """Converts the sequence `args` to the same type as `instance`.
 
@@ -66,11 +75,7 @@ def _sequence_like(instance, args):
     # corresponding `OrderedDict` to pack it back).
     result = dict(zip(_sorted(instance), args))
     return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
-  elif (isinstance(instance, tuple) and
-        hasattr(instance, "_fields") and
-        isinstance(instance._fields, _collections.Sequence) and
-        all(isinstance(f, _six.string_types) for f in instance._fields)):
-    # This is a namedtuple
+  elif _is_namedtuple(instance):
     return type(instance)(*args)
   else:
     # Not a namedtuple
@@ -677,6 +682,85 @@ def get_traverse_shallow_structure(traverse_fn, structure):
   return _sequence_like(structure, level_traverse)
 
 
+def yield_flat_paths(nest):
+  """Yields paths for some nested structure.
+
+  Paths are tuples of objects which can be str-converted: dict keys, namedtuple
+  field names, or integer indices into a sequence.
+
+  The paths are yielded in the same order as the leaves returned by calling
+  `nest.flatten` on the structure. This is handy for naming Tensors such that
+  the TF scope structure matches the tuple structure.
+
+  E.g. if we have a tuple `value = Foo(a=3, b=Bar(c=23, d=42))`
+
+  ```shell
+  >>> nest.flatten(value)
+  [3, 23, 42]
+  >>> list(nest.yield_flat_paths(value))
+  [('a',), ('b', 'c'), ('b', 'd')]
+  ```
+
+  ```shell
+  >>> list(nest.yield_flat_paths({'a': [3]}))
+  [('a', 0)]
+  >>> list(nest.yield_flat_paths({'a': 3}))
+  [('a',)]
+  ```
+
+  Args:
+    nest: the value to produce a flattened paths list for.
+
+  Yields:
+    Tuples containing index or key values which form the path to a specific
+      leaf value in the nested structure.
+  """
+
+  # A leaf yields the empty path; containers prepend their key, field name or
+  # index to each sub-path, so an empty container yields no paths at all.
+  if isinstance(nest, dict):
+    for key in _sorted(nest):
+      value = nest[key]
+      for sub_path in yield_flat_paths(value):
+        yield (key,) + sub_path
+  elif _is_namedtuple(nest):
+    for key in nest._fields:
+      value = getattr(nest, key)
+      for sub_path in yield_flat_paths(value):
+        yield (key,) + sub_path
+  elif isinstance(nest, _six.string_types):  # leaf, despite being a Sequence
+    yield ()
+  elif isinstance(nest, _collections.Sequence):
+    for idx, value in enumerate(nest):
+      for sub_path in yield_flat_paths(value):
+        yield (idx,) + sub_path
+  else:  # any other value is a leaf
+    yield ()
+
+
+def flatten_with_joined_string_paths(structure, separator="/"):
+  """Returns a list of (string path, data element) tuples.
+
+  The order of tuples produced matches that of `nest.flatten`. This allows you
+  to flatten a nested structure while keeping information about where in the
+  structure each data element was located. Each path element is converted
+  with `str` before joining; see `nest.yield_flat_paths` for the path format.
+
+  Args:
+    structure: the nested structure to flatten.
+    separator: string to separate levels of hierarchy in the results, defaults
+      to '/'.
+
+  Returns:
+    A list of (string, data element) tuples.
+  """
+  flat_paths = yield_flat_paths(structure)  # generator of path tuples
+  def stringify_and_join(path_elements):
+    return separator.join(str(path_element) for path_element in path_elements)
+  flat_string_paths = [stringify_and_join(path) for path in flat_paths]
+  return list(zip(flat_string_paths, flatten(structure)))
+
+
 _pywrap_tensorflow.RegisterSequenceClass(_collections.Sequence)
 
 
@@ -691,6 +775,8 @@ _allowed_symbols = [
     "flatten_up_to",
     "map_structure_up_to",
     "get_traverse_shallow_structure",
+    "yield_flat_paths",
+    "flatten_with_joined_string_paths",
 ]
 
 remove_undocumented(__name__, _allowed_symbols)
index 3d9e9f96849c1b7415892ec9341947565ed89664..4906649f013da38f6b18f1645958aa4b244a9d05 100644 (file)
@@ -584,6 +584,59 @@ class NestTest(test.TestCase):
         TypeError, "didn't return a depth=1 structure of bools"):
       nest.get_traverse_shallow_structure(lambda _: [1], [1])
 
+  def testYieldFlatStringPaths(self):
+    for inputs_expected in ({"inputs": [], "expected": []},  # no leaves
+                            {"inputs": 3, "expected": [()]},  # bare leaf
+                            {"inputs": [3], "expected": [(0,)]},
+                            {"inputs": {"a": 3}, "expected": [("a",)]},
+                            {"inputs": {"a": {"b": 4}},
+                             "expected": [("a", "b")]},
+                            {"inputs": [{"a": 2}], "expected": [(0, "a")]},
+                            {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]},
+                            {"inputs": [{"a": [(23, 42)]}],
+                             "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]},
+                            {"inputs": [{"a": ([23], 42)}],
+                             "expected": [(0, "a", 0, 0), (0, "a", 1)]},
+                            {"inputs": {"a": {"a": 2}, "c": [[[4]]]},
+                             "expected": [("a", "a"), ("c", 0, 0, 0)]},
+                            {"inputs": {"0": [{"1": 23}]},  # str key vs index
+                             "expected": [("0", 0, "1")]}):
+      inputs = inputs_expected["inputs"]
+      expected = inputs_expected["expected"]
+      self.assertEqual(list(nest.yield_flat_paths(inputs)), expected)
+
+  def testFlattenWithStringPaths(self):
+    for inputs_expected in (  # expected: (joined path, leaf) pairs
+        {"inputs": [], "expected": []},  # nothing to flatten
+        {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]},
+        {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}):
+      inputs = inputs_expected["inputs"]
+      expected = inputs_expected["expected"]
+      self.assertEqual(
+          nest.flatten_with_joined_string_paths(inputs, separator="/"),
+          expected)
+
+  # Namedtuples are tested separately because the tuple types have to be
+  # defined inside the test before the cases that use them can be built.
+  def testFlattenNamedTuple(self):
+    # pylint: disable=invalid-name
+    Foo = collections.namedtuple("Foo", ["a", "b"])
+    Bar = collections.namedtuple("Bar", ["c", "d"])
+    # pylint: enable=invalid-name
+    test_cases = [  # (structure, expected (joined path, leaf) pairs)
+        (Foo(a=3, b=Bar(c=23, d=42)),  # namedtuple nested in a namedtuple
+         [("a", 3), ("b/c", 23), ("b/d", 42)]),
+        (Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")),
+         [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]),
+        (Bar(c=42, d=43),
+         [("c", 42), ("d", 43)]),
+        (Bar(c=[42], d=43),  # list field adds an index segment
+         [("c/0", 42), ("d", 43)]),
+    ]
+    for inputs, expected in test_cases:
+      self.assertEqual(
+          list(nest.flatten_with_joined_string_paths(inputs)), expected)
+
 
 if __name__ == "__main__":
   test.main()