def _process_function_arg(self, arg_name):
str_name = str(arg_name)
+ type_holder = arg_name.ast()
+ self.scope.setval(arg_name, type_holder)
if len(self.enclosing_entities) == 1 and str_name in self.context.arg_types:
# Forge a node to hold the type information, so that method calls on
# it can resolve the type.
- type_holder = arg_name.ast()
type_string, type_obj = self.context.arg_types[str_name]
anno.setanno(type_holder, 'type', type_obj)
anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.')))
- self.scope.setval(arg_name, type_holder)
def visit_arg(self, node):
self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN))
anno.getanno(definition, 'element_type'))
return node
- def _process_variable_assignment(self, source, targets):
- # Special case: constructors.
- if isinstance(source, gast.Call):
- func = source.func
+ def _process_variable_assignment(self, target, value):
+ # Constructors
+ if isinstance(value, gast.Call):
+ func = value.func
if anno.hasanno(func, 'live_val'):
func_obj = anno.getanno(func, 'live_val')
if tf_inspect.isclass(func_obj):
- anno.setanno(source, 'is_constructor', True)
- anno.setanno(source, 'type', func_obj)
- anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn'))
+ anno.setanno(value, 'is_constructor', True)
+ anno.setanno(value, 'type', func_obj)
+ anno.setanno(value, 'type_fqn', anno.getanno(func, 'fqn'))
# TODO(mdan): Raise an error if constructor has side effects.
# We can have a whitelist of no-side-effects constructors.
# We can also step inside the constructor and further analyze.
- # Multiple targets mean multiple assignment.
- for target in targets:
- # Tuple target means unpacking.
- if isinstance(target, (gast.Tuple, gast.List)):
- for i, target_item in enumerate(target.elts):
- # Two cases here:
- # 1. Static unpacking, e.g. a, b = c, d
- # 2. Dynamic unpacking, e.g. a, b = c
- # The former case is optimized away.
- if isinstance(source, (gast.Tuple, gast.List)):
- source_item = source.elts[i]
- else:
- source_item = gast.Subscript(source, gast.Index(i), ctx=None)
- self._process_variable_assignment(source_item, (target_item,))
- elif isinstance(target, (gast.Name, gast.Attribute)):
- target_symbol = anno.getanno(target, anno.Basic.QN)
- self.scope.setval(target_symbol, source)
- else:
- raise ValueError('assignment target has unknown type: %s' % target)
+ if isinstance(target, (gast.Name, gast.Attribute)):
+ target_symbol = anno.getanno(target, anno.Basic.QN)
+ self.scope.setval(target_symbol, value)
+ elif isinstance(target, gast.Subscript):
+ pass
+ else:
+ raise ValueError('assignment target has unknown type: %s' % target)
def visit_With(self, node):
- for wi in node.items:
- if wi.optional_vars is not None:
- self._process_variable_assignment(wi.context_expr, (wi.optional_vars,))
+ for item in node.items:
+ if item.optional_vars is not None:
+ self.apply_to_single_assignments((item.optional_vars,),
+ item.context_expr,
+ self._process_variable_assignment)
self.generic_visit(node)
return node
def visit_Assign(self, node):
self.generic_visit(node)
- self._process_variable_assignment(node.value, node.targets)
+ self.apply_to_single_assignments(
+ node.targets, node.value, self._process_variable_assignment)
return node
def visit_Call(self, node):
results.append(replacement)
return results
+ # TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
+ def apply_to_single_assignments(self, targets, values, apply_fn):
+ """Applies a fuction to each individual assignment.
+
+ This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
+ It tries to break down the unpacking if possible. In effect, it has the same
+ effect as passing the assigned values in SSA form to apply_fn.
+
+ Examples:
+
+ The following will result in apply_fn(a, c), apply_fn(b, d):
+
+ a, b = c, d
+
+ The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):
+
+ a, b = c
+
+ The following will result in apply_fn(a, (b, c)):
+
+ a = b, c
+
+ It uses the visitor pattern to allow subclasses to process single
+ assignments individually.
+
+ Args:
+ targets: list, tuple of or individual AST node. Should be used with the
+ targets field of an ast.Assign node.
+ values: an AST node.
+ apply_fn: a function of a single argument, which will be called with the
+ respective nodes of each single assignment. The signaure is
+ apply_fn(target, value), no return value.
+ """
+ if not isinstance(targets, (list, tuple)):
+ targets = (targets,)
+ for target in targets:
+ if isinstance(target, (gast.Tuple, gast.List)):
+ for i in range(len(target.elts)):
+ target_el = target.elts[i]
+ if isinstance(values, (gast.Tuple, gast.List)):
+ value_el = values.elts[i]
+ else:
+ value_el = gast.Subscript(values, gast.Index(i), ctx=gast.Store())
+ self.apply_to_single_assignments(target_el, value_el, apply_fn)
+ else:
+ # TODO(mdan): Look into allowing to rewrite the AST here.
+ apply_fn(target, values)
+
def visit(self, node):
source_code = self.context.source_code
source_file = self.context.source_file