Allow assignment to subscripts in static analysis.
authorDan Moldovan <mdan@google.com>
Tue, 29 May 2018 15:45:28 +0000 (08:45 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 29 May 2018 15:47:58 +0000 (08:47 -0700)
Move the handling of syntactic unpackings to a generic helper function since the pattern is used in multiple places. Update the type info analyzer to correctly process function arguments.

PiperOrigin-RevId: 198401368

tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
tensorflow/contrib/autograph/pyct/transformer.py
tensorflow/contrib/autograph/pyct/transformer_test.py

index c00946f..d6555dc 100644 (file)
@@ -136,14 +136,14 @@ class TypeInfoResolver(transformer.Base):
 
   def _process_function_arg(self, arg_name):
     str_name = str(arg_name)
+    type_holder = arg_name.ast()
+    self.scope.setval(arg_name, type_holder)
     if len(self.enclosing_entities) == 1 and str_name in self.context.arg_types:
       # Forge a node to hold the type information, so that method calls on
       # it can resolve the type.
-      type_holder = arg_name.ast()
       type_string, type_obj = self.context.arg_types[str_name]
       anno.setanno(type_holder, 'type', type_obj)
       anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.')))
-      self.scope.setval(arg_name, type_holder)
 
   def visit_arg(self, node):
     self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN))
@@ -167,50 +167,41 @@ class TypeInfoResolver(transformer.Base):
                      anno.getanno(definition, 'element_type'))
     return node
 
-  def _process_variable_assignment(self, source, targets):
-    # Special case: constructors.
-    if isinstance(source, gast.Call):
-      func = source.func
+  def _process_variable_assignment(self, target, value):
+    # Constructors
+    if isinstance(value, gast.Call):
+      func = value.func
       if anno.hasanno(func, 'live_val'):
         func_obj = anno.getanno(func, 'live_val')
         if tf_inspect.isclass(func_obj):
-          anno.setanno(source, 'is_constructor', True)
-          anno.setanno(source, 'type', func_obj)
-          anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn'))
+          anno.setanno(value, 'is_constructor', True)
+          anno.setanno(value, 'type', func_obj)
+          anno.setanno(value, 'type_fqn', anno.getanno(func, 'fqn'))
           # TODO(mdan): Raise an error if constructor has side effects.
           # We can have a whitelist of no-side-effects constructors.
           # We can also step inside the constructor and further analyze.
 
-    # Multiple targets mean multiple assignment.
-    for target in targets:
-      # Tuple target means unpacking.
-      if isinstance(target, (gast.Tuple, gast.List)):
-        for i, target_item in enumerate(target.elts):
-          # Two cases here:
-          #   1. Static unpacking, e.g. a, b = c, d
-          #   2. Dynamic unpacking, e.g. a, b = c
-          # The former case is optimized away.
-          if isinstance(source, (gast.Tuple, gast.List)):
-            source_item = source.elts[i]
-          else:
-            source_item = gast.Subscript(source, gast.Index(i), ctx=None)
-          self._process_variable_assignment(source_item, (target_item,))
-      elif isinstance(target, (gast.Name, gast.Attribute)):
-        target_symbol = anno.getanno(target, anno.Basic.QN)
-        self.scope.setval(target_symbol, source)
-      else:
-        raise ValueError('assignment target has unknown type: %s' % target)
+    if isinstance(target, (gast.Name, gast.Attribute)):
+      target_symbol = anno.getanno(target, anno.Basic.QN)
+      self.scope.setval(target_symbol, value)
+    elif isinstance(target, gast.Subscript):
+      pass
+    else:
+      raise ValueError('assignment target has unknown type: %s' % target)
 
   def visit_With(self, node):
-    for wi in node.items:
-      if wi.optional_vars is not None:
-        self._process_variable_assignment(wi.context_expr, (wi.optional_vars,))
+    for item in node.items:
+      if item.optional_vars is not None:
+        self.apply_to_single_assignments((item.optional_vars,),
+                                         item.context_expr,
+                                         self._process_variable_assignment)
     self.generic_visit(node)
     return node
 
   def visit_Assign(self, node):
     self.generic_visit(node)
-    self._process_variable_assignment(node.value, node.targets)
+    self.apply_to_single_assignments(
+        node.targets, node.value, self._process_variable_assignment)
     return node
 
   def visit_Call(self, node):
index 46b7701..95cbf5c 100644 (file)
@@ -196,6 +196,19 @@ class TypeInfoResolverTest(test.TestCase):
     f_ref = node.body[0].body[1].value
     self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo)
 
+  def test_type_annotation_args(self):
+
+    class Foo(object):
+      pass
+
+    def test_fn(f):
+      utils.set_element_type(f, Foo)
+      return f
+
+    node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils})
+    f_ref = node.body[0].body[1].value
+    self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo)
+
   def test_nested_unpacking(self):
 
     class Foo(object):
index 4db6cc0..4c65edb 100644 (file)
@@ -103,6 +103,54 @@ class Base(gast.NodeTransformer):
           results.append(replacement)
     return results
 
+  # TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
+  def apply_to_single_assignments(self, targets, values, apply_fn):
+    """Applies a fuction to each individual assignment.
+
+    This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
+    It tries to break down the unpacking if possible. In effect, it has the same
+    effect as passing the assigned values in SSA form to apply_fn.
+
+    Examples:
+
+    The following will result in apply_fn(a, c), apply_fn(b, d):
+
+        a, b = c, d
+
+    The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):
+
+        a, b = c
+
+    The following will result in apply_fn(a, (b, c)):
+
+        a = b, c
+
+    It uses the visitor pattern to allow subclasses to process single
+    assignments individually.
+
+    Args:
+      targets: list, tuple of or individual AST node. Should be used with the
+          targets field of an ast.Assign node.
+      values: an AST node.
+      apply_fn: a function of a single argument, which will be called with the
+          respective nodes of each single assignment. The signaure is
+          apply_fn(target, value), no return value.
+    """
+    if not isinstance(targets, (list, tuple)):
+      targets = (targets,)
+    for target in targets:
+      if isinstance(target, (gast.Tuple, gast.List)):
+        for i in range(len(target.elts)):
+          target_el = target.elts[i]
+          if isinstance(values, (gast.Tuple, gast.List)):
+            value_el = values.elts[i]
+          else:
+            value_el = gast.Subscript(values, gast.Index(i), ctx=gast.Store())
+          self.apply_to_single_assignments(target_el, value_el, apply_fn)
+      else:
+        # TODO(mdan): Look into allowing to rewrite the AST here.
+        apply_fn(target, values)
+
   def visit(self, node):
     source_code = self.context.source_code
     source_file = self.context.source_file
index f96b0dc..1f1adf4 100644 (file)
@@ -94,7 +94,7 @@ class TransformerTest(test.TestCase):
                       inner_function, lambda_node),
                      anno.getanno(lambda_expr, 'enclosing_entities'))
 
-  def test_statement_info_stack(self):
+  def test_local_scope_info_stack(self):
 
     class TestTransformer(transformer.Base):
 
@@ -142,7 +142,7 @@ class TransformerTest(test.TestCase):
     self.assertFalse(anno.hasanno(while_node, 'string'))
     self.assertEqual('1', anno.getanno(while_node, 'test'))
 
-  def test_statement_info_stack_checks_integrity(self):
+  def test_local_scope_info_stack_checks_integrity(self):
 
     class TestTransformer(transformer.Base):