From: Dan Moldovan Date: Tue, 29 May 2018 15:45:28 +0000 (-0700) Subject: Allow assignment to subscripts in static analysis. X-Git-Tag: upstream/v1.9.0_rc1~38^2~4^2~20 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9f38ecf3bd6c6e96bf3bb56f1e37f6aff180c21e;p=platform%2Fupstream%2Ftensorflow.git Allow assignment to subscripts in static analysis. Move the handling of syntactic unpackings to a generic helper function since the pattern is used in multiple places. Update the type info analyzer to correctly process function arguments. PiperOrigin-RevId: 198401368 --- diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py index c00946f..d6555dc 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py @@ -136,14 +136,14 @@ class TypeInfoResolver(transformer.Base): def _process_function_arg(self, arg_name): str_name = str(arg_name) + type_holder = arg_name.ast() + self.scope.setval(arg_name, type_holder) if len(self.enclosing_entities) == 1 and str_name in self.context.arg_types: # Forge a node to hold the type information, so that method calls on # it can resolve the type. - type_holder = arg_name.ast() type_string, type_obj = self.context.arg_types[str_name] anno.setanno(type_holder, 'type', type_obj) anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.'))) - self.scope.setval(arg_name, type_holder) def visit_arg(self, node): self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN)) @@ -167,50 +167,41 @@ class TypeInfoResolver(transformer.Base): anno.getanno(definition, 'element_type')) return node - def _process_variable_assignment(self, source, targets): - # Special case: constructors. - if isinstance(source, gast.Call): - func = source.func + def _process_variable_assignment(self, target, value): + # Constructors + if isinstance(value, gast.Call): + func = value.func if anno.hasanno(func, 'live_val'): func_obj = anno.getanno(func, 'live_val') if tf_inspect.isclass(func_obj): - anno.setanno(source, 'is_constructor', True) - anno.setanno(source, 'type', func_obj) - anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn')) + anno.setanno(value, 'is_constructor', True) + anno.setanno(value, 'type', func_obj) + anno.setanno(value, 'type_fqn', anno.getanno(func, 'fqn')) # TODO(mdan): Raise an error if constructor has side effects. # We can have a whitelist of no-side-effects constructors. # We can also step inside the constructor and further analyze. - # Multiple targets mean multiple assignment. - for target in targets: - # Tuple target means unpacking. - if isinstance(target, (gast.Tuple, gast.List)): - for i, target_item in enumerate(target.elts): - # Two cases here: - # 1. Static unpacking, e.g. a, b = c, d - # 2. Dynamic unpacking, e.g. a, b = c - # The former case is optimized away. - if isinstance(source, (gast.Tuple, gast.List)): - source_item = source.elts[i] - else: - source_item = gast.Subscript(source, gast.Index(i), ctx=None) - self._process_variable_assignment(source_item, (target_item,)) - elif isinstance(target, (gast.Name, gast.Attribute)): - target_symbol = anno.getanno(target, anno.Basic.QN) - self.scope.setval(target_symbol, source) - else: - raise ValueError('assignment target has unknown type: %s' % target) + if isinstance(target, (gast.Name, gast.Attribute)): + target_symbol = anno.getanno(target, anno.Basic.QN) + self.scope.setval(target_symbol, value) + elif isinstance(target, gast.Subscript): + pass + else: + raise ValueError('assignment target has unknown type: %s' % target) def visit_With(self, node): - for wi in node.items: - if wi.optional_vars is not None: - self._process_variable_assignment(wi.context_expr, (wi.optional_vars,)) + for item in node.items: + if item.optional_vars is not None: + self.apply_to_single_assignments((item.optional_vars,), + item.context_expr, + self._process_variable_assignment) self.generic_visit(node) return node def visit_Assign(self, node): self.generic_visit(node) - self._process_variable_assignment(node.value, node.targets) + self.apply_to_single_assignments( + node.targets, node.value, self._process_variable_assignment) return node def visit_Call(self, node): diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py index 46b7701..95cbf5c 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py @@ -196,6 +196,19 @@ class TypeInfoResolverTest(test.TestCase): f_ref = node.body[0].body[1].value self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo) + def test_type_annotation_args(self): + + class Foo(object): + pass + + def test_fn(f): + utils.set_element_type(f, Foo) + return f + + node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils}) + f_ref = node.body[0].body[1].value + self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo) + def test_nested_unpacking(self): class Foo(object): diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py index 4db6cc0..4c65edb 100644 --- a/tensorflow/contrib/autograph/pyct/transformer.py +++ b/tensorflow/contrib/autograph/pyct/transformer.py @@ -103,6 +103,54 @@ class Base(gast.NodeTransformer): results.append(replacement) return results + # TODO(mdan): Once we have error tracing, we may be able to just go to SSA. + def apply_to_single_assignments(self, targets, values, apply_fn): + """Applies a fuction to each individual assignment. + + This function can process a possibly-unpacked (e.g. a, b = c, d) assignment. + It tries to break down the unpacking if possible. In effect, it has the same + effect as passing the assigned values in SSA form to apply_fn. + + Examples: + + The following will result in apply_fn(a, c), apply_fn(b, d): + + a, b = c, d + + The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]): + + a, b = c + + The following will result in apply_fn(a, (b, c)): + + a = b, c + + It uses the visitor pattern to allow subclasses to process single + assignments individually. + + Args: + targets: list, tuple of or individual AST node. Should be used with the + targets field of an ast.Assign node. + values: an AST node. + apply_fn: a function of a single argument, which will be called with the + respective nodes of each single assignment. The signaure is + apply_fn(target, value), no return value. + """ + if not isinstance(targets, (list, tuple)): + targets = (targets,) + for target in targets: + if isinstance(target, (gast.Tuple, gast.List)): + for i in range(len(target.elts)): + target_el = target.elts[i] + if isinstance(values, (gast.Tuple, gast.List)): + value_el = values.elts[i] + else: + value_el = gast.Subscript(values, gast.Index(i), ctx=gast.Store()) + self.apply_to_single_assignments(target_el, value_el, apply_fn) + else: + # TODO(mdan): Look into allowing to rewrite the AST here. + apply_fn(target, values) + def visit(self, node): source_code = self.context.source_code source_file = self.context.source_file diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/contrib/autograph/pyct/transformer_test.py index f96b0dc..1f1adf4 100644 --- a/tensorflow/contrib/autograph/pyct/transformer_test.py +++ b/tensorflow/contrib/autograph/pyct/transformer_test.py @@ -94,7 +94,7 @@ class TransformerTest(test.TestCase): inner_function, lambda_node), anno.getanno(lambda_expr, 'enclosing_entities')) - def test_statement_info_stack(self): + def test_local_scope_info_stack(self): class TestTransformer(transformer.Base): @@ -142,7 +142,7 @@ class TransformerTest(test.TestCase): self.assertFalse(anno.hasanno(while_node, 'string')) self.assertEqual('1', anno.getanno(while_node, 'test')) - def test_statement_info_stack_checks_integrity(self): + def test_local_scope_info_stack_checks_integrity(self): class TestTransformer(transformer.Base):