Adding support for analyzing assignment info for nested tuples.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 23 Mar 2018 22:33:33 +0000 (15:33 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sun, 25 Mar 2018 11:16:22 +0000 (04:16 -0700)
PiperOrigin-RevId: 190285584

tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py
tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py

index 5556a58..a969adb 100644 (file)
@@ -168,6 +168,15 @@ class TypeInfoResolver(transformer.Base):
                      anno.getanno(definition, 'element_type'))
     return node
 
+  def _process_tuple_assignment(self, source, t):
+    for i, e in enumerate(t.elts):
+      if isinstance(e, gast.Tuple):
+        self._process_tuple_assignment(source, e)
+      else:
+        self.scope.setval(
+            anno.getanno(e, anno.Basic.QN),
+            gast.Subscript(source, gast.Index(i), ctx=gast.Store()))
+
   def _process_variable_assignment(self, source, targets):
     if isinstance(source, gast.Call):
       func = source.func
@@ -183,10 +192,9 @@ class TypeInfoResolver(transformer.Base):
 
     for t in targets:
       if isinstance(t, gast.Tuple):
-        for i, e in enumerate(t.elts):
-          self.scope.setval(
-              anno.getanno(e, anno.Basic.QN),
-              gast.Subscript(source, gast.Index(i), ctx=gast.Store()))
+        # need to recurse on the case of assigning nested tuples,
+        # ex. a, (b, c) = f()
+        self._process_tuple_assignment(source, t)
       elif isinstance(t, (gast.Name, gast.Attribute)):
         self.scope.setval(anno.getanno(t, anno.Basic.QN), source)
       else:
index 0d9d5a8..8a89561 100644 (file)
@@ -196,6 +196,23 @@ class TypeInfoResolverTest(test.TestCase):
     f_ref = node.body[0].body[1].value
     self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo)
 
+  def test_nested_assignment(self):
+
+    def test_fn(foo):
+      a, (b, c) = foo
+      return a, b, c
+
+    node = self._parse_and_analyze(test_fn, {'foo': (1, 2, 3)})
+    lhs = node.body[0].body[1].value.elts
+    a = lhs[0]
+    b = lhs[1]
+    c = lhs[2]
+    # TODO(mdan): change these once we have the live values propagating
+    # correctly
+    self.assertFalse(anno.hasanno(a, 'live_val'))
+    self.assertFalse(anno.hasanno(b, 'live_val'))
+    self.assertFalse(anno.hasanno(c, 'live_val'))
+
 
 if __name__ == '__main__':
   test.main()