visibility = ["//visibility:public"],
deps = [
"//tensorflow/contrib/py2tf/impl",
+ "//tensorflow/contrib/py2tf/pyct",
"//tensorflow/contrib/py2tf/utils",
"@gast_archive//:gast",
"@six_archive//:six",
from tensorflow.contrib.py2tf.impl.api import graph_ready
from tensorflow.contrib.py2tf.impl.api import to_code
from tensorflow.contrib.py2tf.impl.api import to_graph
+from tensorflow.contrib.py2tf.pyct.transformer import PyFlowParseError
from tensorflow.python.util.all_util import remove_undocumented
-_allowed_symbols = ['to_graph', 'to_code', 'convert', 'graph_ready', 'utils']
+_allowed_symbols = [
+ 'to_graph', 'to_code', 'convert', 'graph_ready', 'utils', 'PyFlowParseError'
+]
remove_undocumented(__name__, _allowed_symbols)
"decorators.py",
"for_canonicalization.py",
"logical_expressions.py",
- "print_functions.py",
"side_effect_guards.py",
],
srcs_version = "PY2AND3",
)
py_test(
- name = "print_functions_test",
- srcs = ["print_functions_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":test_lib",
- "//tensorflow/contrib/py2tf/pyct",
- "//tensorflow/python:client_testlib",
- "@gast_archive//:gast",
- ],
-)
-
-py_test(
name = "side_effect_guards_test",
srcs = ["side_effect_guards_test.py"],
srcs_version = "PY2AND3",
from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import templates
+from tensorflow.contrib.py2tf.pyct import transformer
+from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
-class BreakCanonicalizationTransformer(gast.NodeTransformer):
+class BreakCanonicalizationTransformer(transformer.Base):
"""Canonicalizes continue statements into additional conditionals."""
- def __init__(self, namer):
- self.namer = namer
+ def __init__(self, context):
+ super(BreakCanonicalizationTransformer, self).__init__(context)
# This is a stack structure, to correctly process nested loops.
self.break_uses = []
def visit_While(self, node):
self.generic_visit(node.test)
- scope = anno.getanno(node, 'body_scope')
+ scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
- break_var = self.namer.new_symbol('break_requested', scope.referenced)
+ break_var = self.context.namer.new_symbol('break_requested',
+ scope.referenced)
self.break_uses.append([False, break_var])
node.body = self._manual_visit_list(node.body)
if self.break_uses[-1][0]:
def visit_For(self, node):
self.generic_visit(node.target)
self.generic_visit(node.iter)
- scope = anno.getanno(node, 'body_scope')
+ scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
- break_var = self.namer.new_symbol('break_requested', scope.referenced)
+ break_var = self.context.namer.new_symbol('break_requested',
+ scope.referenced)
self.break_uses.append([False, break_var])
node.body = self._manual_visit_list(node.body)
if self.break_uses[-1][0]:
return self._create_break_trigger()
-def transform(node, namer):
- transformer = BreakCanonicalizationTransformer(namer)
- node = transformer.visit(node)
- return node
+def transform(node, context):
+ return BreakCanonicalizationTransformer(context).visit(node)
v.append(x)
return v
- node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
- node = break_canonicalization.transform(node, TestNamer())
+ node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = break_canonicalization.transform(node, self.ctx)
result = compiler.ast_to_object(node)
self.assertEqual(test_fn(0), result.test_fn(0))
v.append(x)
return v
- node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
- node = break_canonicalization.transform(node, TestNamer())
+ node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = break_canonicalization.transform(node, self.ctx)
result = compiler.ast_to_object(node)
# The break is incompletely canonicalized. Everything is in place, but
v.append(x)
return v, u, w
- node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
- node = break_canonicalization.transform(node, TestNamer())
+ node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = break_canonicalization.transform(node, self.ctx)
result = compiler.ast_to_object(node)
self.assertEqual(test_fn(0), result.test_fn(0))
import gast
from tensorflow.contrib.py2tf.pyct import templates
+from tensorflow.contrib.py2tf.pyct import transformer
-class BuiltinFunctionTransformer(gast.NodeTransformer):
+class BuiltinFunctionTransformer(transformer.Base):
"""Transforms Print nodes to Call so they can be handled as functions."""
- # TODO(mdan): Bring print_functions in here.
+ def __init__(self, context):
+ super(BuiltinFunctionTransformer, self).__init__(context)
def _convert_len(self, node):
template = """
return self._convert_len(node)
return node
+ def visit_Print(self, node):
+ self.generic_visit(node)
+ template = """
+ fname(args)
+ """
+ return templates.replace(template, fname='print', args=node.values)
+
# pylint:enable=invalid-name
-def transform(node):
- transformer = BuiltinFunctionTransformer()
- node = transformer.visit(node)
- return node
+def transform(node, context):
+ return BuiltinFunctionTransformer(context).visit(node)
from __future__ import division
from __future__ import print_function
+import gast
+
from tensorflow.contrib.py2tf.converters import builtin_functions
from tensorflow.contrib.py2tf.converters import converter_test_base
from tensorflow.contrib.py2tf.pyct import compiler
return len(a)
node = self.parse_and_analyze(test_fn, {'len': len})
- node = builtin_functions.transform(node)
+ node = builtin_functions.transform(node, self.ctx)
result = compiler.ast_to_object(node)
setattr(result, 'tf', array_ops)
sess.run(
result.test_fn(constant_op.constant([0, 0, 0]))))
+ def test_print(self):
+
+ def test_fn(a):
+ print(a)
+
+ node = self.parse_and_analyze(test_fn, {'print': print})
+ node = builtin_functions.transform(node, self.ctx)
+ result = compiler.ast_to_object(node)
+
+ result.test_fn('a')
+ self.assertTrue(isinstance(node.body[0].body[0].value, gast.Call))
+
if __name__ == '__main__':
test.main()
from tensorflow.contrib.py2tf.pyct import parser
from tensorflow.contrib.py2tf.pyct import templates
from tensorflow.contrib.py2tf.pyct import transformer
+from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
from tensorflow.python.util import tf_inspect
# The renaming process will transform it into a regular function.
# TODO(mdan): Is this complete? How does it work with nested members?
node.args = [node.func.value] + node.args
- node.func = gast.Name(new_name, gast.Load(), None)
+ node.func = templates.replace('func_name', func_name=new_name)[0]
return node
def _wrap_to_py_func_no_return(self, node):
- args_scope = anno.getanno(node, 'args_scope')
+ args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
# TODO(mdan): Properly handle varargs, kwargs, etc.
template = """
def wrapper(args):
template,
call=node.func,
wrapper=self.context.namer.compiled_function_name(node.func.id)[0],
- args=tuple(gast.Name(n, gast.Load(), None) for n in args_scope.used))
- anno.setanno(call_expr.value, 'args_scope', args_scope)
+ args=tuple(args_scope.used))
+ anno.setanno(call_expr.value, NodeAnno.ARGS_SCOPE, args_scope)
# TODO(mdan): Rename this annotation to 'graph_ready'
- anno.setanno(wrapper_def, 'skip_processing', True)
+ anno.setanno(wrapper_def, anno.Basic.SKIP_PROCESSING, True)
return (wrapper_def, call_expr)
from __future__ import division
from __future__ import print_function
-import gast
-
from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import templates
+from tensorflow.contrib.py2tf.pyct import transformer
+from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
-class ContinueCanonicalizationTransformer(gast.NodeTransformer):
+class ContinueCanonicalizationTransformer(transformer.Base):
"""Canonicalizes continue statements into additional conditionals."""
- def __init__(self, namer):
- self.namer = namer
+ def __init__(self, context):
+ super(ContinueCanonicalizationTransformer, self).__init__(context)
# This is a stack structure, to correctly process nested loops.
self.continuation_uses = []
return reorganized_nodes
def _process_loop_block(self, block, scope):
- cont_var = self.namer.new_symbol('cont_requested', scope.referenced)
+ cont_var = self.context.namer.new_symbol('cont_requested', scope.referenced)
self.continuation_uses.append([False, cont_var])
block = self._visit_and_reindent_if_necessary(block)
if self.continuation_uses[-1][0]:
def visit_While(self, node):
self.generic_visit(node.test)
node.body = self._process_loop_block(node.body,
- anno.getanno(node, 'body_scope'))
+ anno.getanno(node,
+ NodeAnno.BODY_SCOPE))
for n in node.orelse:
self.generic_visit(n)
return node
self.generic_visit(node.target)
self.generic_visit(node.iter)
node.body = self._process_loop_block(node.body,
- anno.getanno(node, 'body_scope'))
+ anno.getanno(node,
+ NodeAnno.BODY_SCOPE))
for n in node.orelse:
self.generic_visit(n)
return node
def transform(node, namer):
- transformer = ContinueCanonicalizationTransformer(namer)
- node = transformer.visit(node)
- return node
+ return ContinueCanonicalizationTransformer(namer).visit(node)
v.append(x)
return v
- node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
- node = continue_canonicalization.transform(node, TestNamer())
+ node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = continue_canonicalization.transform(node, self.ctx)
result = compiler.ast_to_object(node)
self.assertEqual(test_fn(0), result.test_fn(0))
v.append(x)
return v
- node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
- node = continue_canonicalization.transform(node, TestNamer())
+ node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = continue_canonicalization.transform(node, self.ctx)
result = compiler.ast_to_object(node)
self.assertEqual(test_fn([]), result.test_fn([]))
v.append(x)
return v, u, w
- node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
- node = continue_canonicalization.transform(node, TestNamer())
+ node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = continue_canonicalization.transform(node, self.ctx)
result = compiler.ast_to_object(node)
self.assertEqual(test_fn(0), result.test_fn(0))
from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import templates
+from tensorflow.contrib.py2tf.pyct import transformer
+from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
class SymbolNamer(object):
class SymbolRenamer(gast.NodeTransformer):
+ """Transformer that can rename symbols to a simple names."""
def __init__(self, name_map):
self.name_map = name_map
- def visit_Name(self, node):
- if node.id in self.name_map:
- node.id = self.name_map[node.id]
+ def _process(self, node):
+ qn = anno.getanno(node, anno.Basic.QN)
+ if qn in self.name_map:
+ return gast.Name(self.name_map[qn], node.ctx, None)
return node
+ def visit_Name(self, node):
+ return self._process(node)
+
+ def visit_Attribute(self, node):
+ return self._process(node)
+
-class ControlFlowTransformer(gast.NodeTransformer):
+class ControlFlowTransformer(transformer.Base):
"""Transforms control flow structures like loops an conditionals."""
- def __init__(self, namer):
- self.namer = namer
+ def __init__(self, context):
+ super(ControlFlowTransformer, self).__init__(context)
# pylint:disable=invalid-name
def visit_If(self, node):
self.generic_visit(node)
- body_scope = anno.getanno(node, 'body_scope')
- orelse_scope = anno.getanno(node, 'orelse_scope')
+ body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
+ orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)
if body_scope.created - orelse_scope.created:
raise ValueError(
(body_scope.created | orelse_scope.created))
aliased_orig_names = tuple(need_alias)
aliased_new_names = tuple(
- self.namer.new_symbol(s, all_referenced) for s in aliased_orig_names)
+ self.context.namer.new_symbol(s.ssf(), all_referenced)
+ for s in aliased_orig_names)
alias_map = dict(zip(aliased_orig_names, aliased_new_names))
node_body = node.body
node_body = [SymbolRenamer(alias_map).visit(n) for n in node_body]
node_orelse = [SymbolRenamer(alias_map).visit(n) for n in node_orelse]
if len(all_modified) == 1:
- results = gast.Name(all_modified[0], None, None)
+ results = all_modified[0]
else:
- results = gast.Tuple(
- tuple(gast.Name(s, None, None) for s in all_modified), None)
-
- template = """
- def body_name():
- aliased_new_names, = aliased_orig_names,
- body
- return (all_results,)
- def orelse_name():
- aliased_new_names, = aliased_orig_names,
- orelse
- return (all_results,)
- results = tf.cond(test, body_name, orelse_name)
- """
- body_name = self.namer.new_symbol('if_true', all_referenced)
- return templates.replace(
- template,
- test=node.test,
- body_name=body_name,
- body=node_body,
- orelse_name=self.namer.new_symbol('if_false', all_referenced),
- orelse=node_orelse,
- aliased_orig_names=tuple(aliased_orig_names),
- aliased_new_names=tuple(aliased_new_names),
- all_results=tuple(alias_map[s] if s in aliased_orig_names else s
- for s in all_modified),
- results=results)
+ results = gast.Tuple([s.ast() for s in all_modified], None)
+
+ if aliased_orig_names:
+ template = """
+ def body_name():
+ aliased_new_names, = aliased_orig_names,
+ body
+ return (all_results,)
+ def orelse_name():
+ aliased_new_names, = aliased_orig_names,
+ orelse
+ return (all_results,)
+ results = tf.cond(test, body_name, orelse_name)
+ """
+ body_name = self.context.namer.new_symbol('if_true', all_referenced)
+ return templates.replace(
+ template,
+ test=node.test,
+ body_name=body_name,
+ body=node_body,
+ orelse_name=self.context.namer.new_symbol('if_false', all_referenced),
+ orelse=node_orelse,
+ aliased_orig_names=tuple(aliased_orig_names),
+ aliased_new_names=tuple(aliased_new_names),
+ all_results=tuple(alias_map[s] if s in aliased_orig_names else s
+ for s in all_modified),
+ results=results)
+ else:
+ template = """
+ def body_name():
+ body
+ return (all_results,)
+ def orelse_name():
+ orelse
+ return (all_results,)
+ results = tf.cond(test, body_name, orelse_name)
+ """
+ body_name = self.context.namer.new_symbol('if_true', all_referenced)
+ return templates.replace(
+ template,
+ test=node.test,
+ body_name=body_name,
+ body=node_body,
+ orelse_name=self.context.namer.new_symbol('if_false', all_referenced),
+ orelse=node_orelse,
+ all_results=tuple(s for s in all_modified),
+ results=results)
def visit_While(self, node):
self.generic_visit(node)
- body_scope = anno.getanno(node, 'body_scope')
- body_closure = tuple(body_scope.modified - body_scope.created)
-
- if len(body_closure) == 1:
- state = body_closure[0]
+ body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
+ body_closure = body_scope.modified - body_scope.created
+ all_referenced = body_scope.referenced
+
+ state = list(body_closure)
+ state_ssf = [
+ self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
+ ]
+ ssf_map = {
+ name: ssf
+ for name, ssf in zip(state, state_ssf)
+ if str(name) != ssf
+ }
+
+ if len(state) == 1:
+ state = state[0]
+ state_ssf = state_ssf[0]
state_ast_tuple = state
else:
- state = tuple(body_closure)
- state_ast_tuple = gast.Tuple(
- tuple(gast.Name(n, None, None) for n in state), None)
+ state_ast_tuple = gast.Tuple([n.ast() for n in state], None)
+
+ node_body = node.body
+ node_body = [SymbolRenamer(ssf_map).visit(n) for n in node_body]
+
+ test = node.test
+ test = SymbolRenamer(ssf_map).visit(test)
+
template = """
- def test_name(state):
+ def test_name(state_ssf):
return test
- def body_name(state):
+ def body_name(state_ssf):
body
- return state,
+ return state_ssf,
state_ast_tuple = tf.while_loop(test_name, body_name, [state])
"""
node = templates.replace(
template,
state=state,
+ state_ssf=state_ssf,
state_ast_tuple=state_ast_tuple,
- test_name=self.namer.new_symbol('loop_test', body_scope.referenced),
- test=node.test,
- body_name=self.namer.new_symbol('loop_body', body_scope.referenced),
- body=node.body)
+ test_name=self.context.namer.new_symbol('loop_test',
+ body_scope.referenced),
+ test=test,
+ body_name=self.context.namer.new_symbol('loop_body',
+ body_scope.referenced),
+ body=node_body)
return node
# pylint:enable=invalid-name
-def transform(node, namer):
- transformer = ControlFlowTransformer(namer)
- node = transformer.visit(node)
+def transform(node, context):
+ t = ControlFlowTransformer(context)
+ node = t.visit(node)
return node
i += 1
return s, i, n
- node = self.parse_and_analyze(test_fn, {})
- node = control_flow.transform(node, TestNamer())
+ node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = control_flow.transform(node, self.ctx)
result = compiler.ast_to_object(node)
setattr(result, 'tf', control_flow_ops)
n -= 1
return n
- node = self.parse_and_analyze(test_fn, {})
- node = control_flow.transform(node, TestNamer())
+ node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = control_flow.transform(node, self.ctx)
result = compiler.ast_to_object(node)
setattr(result, 'tf', control_flow_ops)
b = 2 * n
return a, b
- node = self.parse_and_analyze(test_fn, {})
- node = control_flow.transform(node, TestNamer())
+ node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = control_flow.transform(node, self.ctx)
result = compiler.ast_to_object(node)
setattr(result, 'tf', control_flow_ops)
n = -n
return n
- node = self.parse_and_analyze(test_fn, {})
- node = control_flow.transform(node, TestNamer())
+ node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = control_flow.transform(node, self.ctx)
result = compiler.ast_to_object(node)
setattr(result, 'tf', control_flow_ops)
from tensorflow.contrib.py2tf.pyct import context
from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.contrib.py2tf.pyct.static_analysis import access
+from tensorflow.contrib.py2tf.pyct import qual_names
+from tensorflow.contrib.py2tf.pyct.static_analysis import activity
from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
from tensorflow.python.platform import test
arg_values=None,
arg_types=arg_types,
recursive=recursive)
- node = access.resolve(node, ctx)
+ node = qual_names.resolve(node)
+ node = activity.resolve(node, ctx)
node = live_values.resolve(node, ctx, {})
if include_type_analysis:
node = type_info.resolve(node, ctx)
from __future__ import division
from __future__ import print_function
-import gast
-
from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import templates
+from tensorflow.contrib.py2tf.pyct import transformer
+from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
-class ForLoopCanonicalizationTransformer(gast.NodeTransformer):
+class ForLoopCanonicalizationTransformer(transformer.Base):
"""Canonicalizes for loops (e.g. into while loops)."""
- def __init__(self, namer):
- self.namer = namer
+ def __init__(self, context):
+ super(ForLoopCanonicalizationTransformer, self).__init__(context)
def visit_For(self, node):
self.generic_visit(node)
- body_scope = anno.getanno(node, 'body_scope')
-
- # TODO(mdan): Distinguish between `for i in n` and `for i in range(n)`
- # Or maybe we should replace range with tf.range?
+ body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
if anno.hasanno(node, 'extra_cond'):
template = """
loop_iter=node.iter,
target=node.target,
body=node.body,
- i=self.namer.new_symbol('i', body_scope.referenced),
- n=self.namer.new_symbol('n', body_scope.referenced),
+ i=self.context.namer.new_symbol('i', body_scope.referenced),
+ n=self.context.namer.new_symbol('n', body_scope.referenced),
extra_cond=anno.getanno(node, 'extra_cond'))
else:
template = """
body # pylint:disable=pointless-statement
i += 1
"""
- return templates.replace(
+ repl = templates.replace(
template,
loop_iter=node.iter,
target=node.target,
body=node.body,
- i=self.namer.new_symbol('i', body_scope.referenced),
- n=self.namer.new_symbol('n', body_scope.referenced))
+ i=self.context.namer.new_symbol('i', body_scope.referenced),
+ n=self.context.namer.new_symbol('n', body_scope.referenced))
+ return repl
def visit_Continue(self, node):
assert False, 'continue statement should be desugared at this point'
assert False, 'break statement should be desugared at this point'
-def transform(node, namer):
- transformer = ForLoopCanonicalizationTransformer(namer)
- node = transformer.visit(node)
- return node
+def transform(node, context):
+ return ForLoopCanonicalizationTransformer(context).visit(node)
s += e
return s
- node = self.parse_and_analyze(test_fn, {})
- node = for_canonicalization.transform(node, TestNamer())
+ node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = for_canonicalization.transform(node, self.ctx)
result = compiler.ast_to_object(node)
l = [1, 2, 3]
+++ /dev/null
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Compatibility support. Converts Print nodes to function calls."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import gast
-
-from tensorflow.contrib.py2tf.pyct import anno
-
-
-class PrintFunctionTransformer(gast.NodeTransformer):
- """Transforms Print nodes to Call so they can be handled as functions."""
-
- # pylint:disable=invalid-name
-
- def visit_Print(self, node):
- self.generic_visit(node)
- for n in node.values:
- n.ctx = gast.Param()
- call_node = gast.Call(
- func=gast.Name('print', gast.Load(), None),
- args=node.values,
- keywords=[])
- anno.setanno(call_node.func, 'live_val', print)
- anno.setanno(call_node.func, 'fqn', 'print')
- anno.setanno(call_node, 'args_scope', anno.getanno(node, 'args_scope'))
- node = gast.Expr(call_node)
- return node
-
- # pylint:enable=invalid-name
-
-
-def transform(node):
- transformer = PrintFunctionTransformer()
- node = transformer.visit(node)
- return node
from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import templates
+from tensorflow.contrib.py2tf.pyct import transformer
+from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
class SymbolNamer(object):
raise NotImplementedError()
-class SideEffectGuardTransformer(gast.NodeTransformer):
+class SideEffectGuardTransformer(transformer.Base):
"""Adds control dependencies to functions with side effects."""
- def __init__(self, namer):
- self.namer = namer
+ def __init__(self, context):
+ super(SideEffectGuardTransformer, self).__init__(context)
self.indent_next = False
self.next_indent_owner = None
return new_nodes
def visit_FunctionDef(self, node):
- if anno.hasanno(node, 'skip_processing'):
- return node
node.body = self._visit_and_reindent(node.body)
return node
# First, attempt to gate future evaluation of args. If that's not
# possible, gate all remaining statements (and that may fail too, see
# _visit_and_reindent.
- args_scope = anno.getanno(node.value, 'args_scope')
+ args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
guarded_args = tuple(args_scope.used & (args_scope.parent.modified
| args_scope.parent.returned))
if guarded_args:
# pylint:enable=invalid-name
-def transform(node, namer):
- transformer = SideEffectGuardTransformer(namer)
- return transformer.visit(node)
+def transform(node, context):
+ return SideEffectGuardTransformer(context).visit(node)
state_ops.assign(a, a + 1)
return a
- node = self.parse_and_analyze(test_fn, {'state_ops': state_ops})
- node = side_effect_guards.transform(node, TestNamer())
+ node = self.parse_and_analyze(
+ test_fn, {'state_ops': state_ops}, namer=TestNamer())
+ node = side_effect_guards.transform(node, self.ctx)
result = compiler.ast_to_object(node)
setattr(result, 'state_ops', state_ops)
setattr(result, 'py2tf_utils', utils)
from tensorflow.contrib.py2tf.converters import decorators
from tensorflow.contrib.py2tf.converters import for_canonicalization
from tensorflow.contrib.py2tf.converters import logical_expressions
-from tensorflow.contrib.py2tf.converters import print_functions
from tensorflow.contrib.py2tf.converters import side_effect_guards
from tensorflow.contrib.py2tf.impl import config
from tensorflow.contrib.py2tf.impl import naming
from tensorflow.contrib.py2tf.pyct import context
from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.contrib.py2tf.pyct.static_analysis import access
+from tensorflow.contrib.py2tf.pyct import qual_names
+from tensorflow.contrib.py2tf.pyct.static_analysis import activity
from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
from tensorflow.python.util import tf_inspect
def _static_analysis_pass(node, ctx):
- node = access.resolve(node, ctx)
+ node = qual_names.resolve(node)
+ node = activity.resolve(node, ctx, None)
node = live_values.resolve(node, ctx, config.PYTHON_LITERALS)
node = type_info.resolve(node, ctx)
return node
# TODO(mdan): Factor out common elements.
# These include:
- # * keeping track of symbols that have been created
- # * marking nodes (e.g. py_func wrappers) to suppress further processing
# * code move between blocks
- # * insertion of new global references
# * visiting blocks in transformers
# Certain steps, especially canonicalization, insert new symbols into the
# to re-run the analysis.
node = _static_analysis_pass(node, ctx)
+ # Past this point, line numbers are no longer accurate so we ignore the
+ # source.
+ # TODO(mdan): Is it feasible to reconstruct intermediate source code?
+ ctx.source_code = None
node = decorators.transform(node, nocompile_decorators)
- node = break_canonicalization.transform(node, ctx.namer)
+ node = break_canonicalization.transform(node, ctx)
node = asserts.transform(node, ctx)
# Note: sequencing continue canonicalization before for loop one avoids
# dealing with the extra loop increment operation that the for
# canonicalization creates.
- node = continue_canonicalization.transform(node, ctx.namer)
+ node = continue_canonicalization.transform(node, ctx)
ctx.namespace['len'] = len
node = _static_analysis_pass(node, ctx)
- node = for_canonicalization.transform(node, ctx.namer)
+ node = for_canonicalization.transform(node, ctx)
# for_canonicalization may insert new global references.
- node = builtin_functions.transform(node)
+ node = builtin_functions.transform(node, ctx)
# builtin_functions may insert new global references.
ctx.namespace['print'] = print
node = _static_analysis_pass(node, ctx)
- node = print_functions.transform(node)
node = call_trees.transform(node, ctx, config.DEFAULT_UNCOMPILED_MODULES,
nocompile_decorators)
- node = control_flow.transform(node, ctx.namer)
+ node = control_flow.transform(node, ctx)
+
+ # control_flow may create new symbols and change scopes.
+ node = _static_analysis_pass(node, ctx)
node = logical_expressions.transform(node)
- node = side_effect_guards.transform(node, ctx.namer)
+ node = side_effect_guards.transform(node, ctx)
return node
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.py2tf.pyct import qual_names
+
class Namer(object):
"""Implementation of the namer interfaces required by various converters.
def new_symbol(self, name_root, reserved_locals):
"""See control_flow.SymbolNamer.new_symbol."""
+ # reserved_locals may contain QNs.
+ all_reserved_locals = set()
+ for s in reserved_locals:
+ if isinstance(s, qual_names.QN):
+ all_reserved_locals.update(s.qn)
+ elif isinstance(s, str):
+ all_reserved_locals.add(s)
+ else:
+ raise ValueError('Unexpected symbol type "%s"' % type(s))
+
new_name = name_root
n = 0
- while (new_name in self.global_namespace
- or new_name in reserved_locals
- or new_name in self.generated_names):
+ while (new_name in self.global_namespace or
+ new_name in all_reserved_locals or new_name in self.generated_names):
n += 1
new_name = '%s_%d' % (name_root, n)
"anno.py",
"compiler.py",
"context.py",
+ "copier.py",
"parser.py",
"pretty_printer.py",
+ "qual_names.py",
"templates.py",
"transformer.py",
],
)
py_test(
+ name = "copier_test",
+ srcs = ["copier_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":pyct",
+ "//tensorflow/python:client_testlib",
+ "@gast_archive//:gast",
+ ],
+)
+
+py_test(
name = "parser_test",
srcs = ["parser_test.py"],
srcs_version = "PY2AND3",
from __future__ import division
from __future__ import print_function
+from enum import Enum
+
+
+class NoValue(Enum):
+
+ def __repr__(self):
+ return self.name
+
+
+class Basic(NoValue):
+ """Container for annotation keys.
+
+ The enum values are used strictly for documentation purposes.
+ """
+
+ QN = 'Qualified name, as it appeared in the code.'
+ SKIP_PROCESSING = (
+ 'This node should be preserved as is and not processed any further.')
+
def getanno(node, key, field_name='___pyct_anno'):
return getattr(node, field_name)[key]
--- /dev/null
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Copy an AST tree, discarding annotations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import anno
+
+
+class CleanCopier(gast.NodeVisitor):
+ """Copy AST nodes.
+
+ The copied nodes will ignore almost all fields that prefixed by '__'.
+ Exceptions make some annotations.
+ """
+
+ # TODO(mdan): Parametrize which annotations get carried over.
+
+ def generic_visit(self, node):
+ new_fields = {}
+ for f in node._fields:
+ if f.startswith('__'):
+ continue
+ if not hasattr(node, f):
+ continue
+ v = getattr(node, f)
+ if isinstance(v, list):
+ v = [self.generic_visit(n) for n in v]
+ elif isinstance(v, tuple):
+ v = tuple(self.generic_visit(n) for n in v)
+ elif isinstance(v, (gast.AST, ast.AST)):
+ v = self.generic_visit(v)
+ else:
+ # Assume everything else is a value type.
+ pass
+ new_fields[f] = v
+ new_node = type(node)(**new_fields)
+ if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
+ anno.setanno(new_node, anno.Basic.SKIP_PROCESSING, True)
+ return new_node
+
+
+def copy_clean(node):
+ copier = CleanCopier()
+ if isinstance(node, list):
+ return [copier.visit(n) for n in node]
+ elif isinstance(node, tuple):
+ return tuple(copier.visit(n) for n in node)
+ else:
+ return copier.visit(node)
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for print_functions module."""
+"""Tests for copier module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import gast
+import ast
-from tensorflow.contrib.py2tf.converters import converter_test_base
-from tensorflow.contrib.py2tf.converters import print_functions
-from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import copier
from tensorflow.python.platform import test
-class PrintFunctionsTest(converter_test_base.TestCase):
-
- def test_transform(self):
-
- def test_fn(a):
- print(a)
-
- node = self.parse_and_analyze(test_fn, {'print': print})
- node = print_functions.transform(node)
- result = compiler.ast_to_object(node)
-
- result.test_fn('a')
- self.assertTrue(isinstance(node.body[0].body[0].value, gast.Call))
+class CopierTest(test.TestCase):
+
+ def test_copy_clean(self):
+ ret = ast.Return(
+ ast.BinOp(
+ op=ast.Add(),
+ left=ast.Name(id='a', ctx=ast.Load()),
+ right=ast.Num(1)))
+ setattr(ret, '__foo', 'bar')
+ node = ast.FunctionDef(
+ name='f',
+ args=ast.arguments(
+ args=[ast.Name(id='a', ctx=ast.Param())],
+ vararg=None,
+ kwarg=None,
+ defaults=[]),
+ body=[ret],
+ decorator_list=[],
+ returns=None)
+ new_node = copier.copy_clean(node)
+ self.assertFalse(node is new_node)
+ self.assertFalse(ret is new_node.body[0])
+ self.assertFalse(hasattr(new_node.body[0], '__foo'))
if __name__ == '__main__':
class PrettyPrinter(gast.NodeVisitor):
"""Print AST nodes."""
- def __init__(self):
+ def __init__(self, color):
self.indent_lvl = 0
self.result = ''
+ self.color = color
+
+ def _color(self, string, color, attrs=None):
+ if self.color:
+ return termcolor.colored(string, color, attrs=attrs)
+ return string
def _type(self, node):
- return termcolor.colored(node.__class__.__name__, None, attrs=['bold'])
+ return self._color(node.__class__.__name__, None, ['bold'])
def _field(self, name):
- return termcolor.colored(name, 'blue')
+ return self._color(name, 'blue')
def _value(self, name):
- return termcolor.colored(name, 'magenta')
+ return self._color(name, 'magenta')
def _warning(self, name):
- return termcolor.colored(name, 'red')
+ return self._color(name, 'red')
def _indent(self):
- return termcolor.colored('| ' * self.indent_lvl, None, attrs=['dark'])
+ return self._color('| ' * self.indent_lvl, None, ['dark'])
def _print(self, s):
self.result += s
self._print('%s]' % (self._indent()))
else:
self._print('%s%s=[]' % (self._indent(), self._field(f)))
+ elif isinstance(v, tuple):
+ if v:
+ self._print('%s%s=(' % (self._indent(), self._field(f)))
+ self.indent_lvl += 1
+ for n in v:
+ self.generic_visit(n)
+ self.indent_lvl -= 1
+ self._print('%s)' % (self._indent()))
+ else:
+ self._print('%s%s=()' % (self._indent(), self._field(f)))
elif isinstance(v, gast.AST):
self.generic_visit(v, f)
elif isinstance(v, str):
self.indent_lvl -= 1
-def fmt(node):
- printer = PrettyPrinter()
+def fmt(node, color=True):
+ printer = PrettyPrinter(color)
if isinstance(node, (list, tuple)):
for n in node:
printer.visit(n)
from tensorflow.python.platform import test
-def f(x):
- return x + 1
-
-
class PrettyPrinterTest(test.TestCase):
def test_format(self):
--- /dev/null
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for manipulating qualified names.
+
+A qualified name is a uniform way to refer to simple (e.g. 'foo') and composite
+(e.g. 'foo.bar') syntactic symbols.
+
+This is *not* related to the __qualname__ attribute used by inspect, which
+refers to scopes.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import anno
+
+
+class QN(object):
+ """Represents a qualified name.
+
+ """
+
+ def __init__(self, base, attr=None):
+ if attr:
+ if not isinstance(base, QN):
+ raise ValueError('For attribute QNs, base must be a QN.')
+ self._parent = base
+ self.qn = base.qn + (attr,)
+ else:
+ self._parent = None
+ self.qn = tuple(base.split('.'))
+
+ def is_composite(self):
+ return len(self.qn) > 1
+
+ @property
+ def parent(self):
+ if self._parent is None:
+ raise ValueError('Cannot get parent of simple name "%s".' % self.qn[0])
+ return self._parent
+
+ def __hash__(self):
+ return hash(self.qn)
+
+ def __eq__(self, other):
+ return self.qn == other.qn
+
+ def __str__(self):
+ return '.'.join(self.qn)
+
+ def __repr__(self):
+ return str(self)
+
+ def ssf(self):
+ """Simple symbol form."""
+ return '_'.join(self.qn)
+
+ def ast(self):
+ # The caller must adjust the context appropriately.
+ if self.is_composite():
+ return gast.Attribute(self.parent.ast(), self.qn[-1], None)
+ return gast.Name(self.qn[0], None, None)
+
+
+class QnResolver(gast.NodeTransformer):
+ """Annotates nodes with QN information.
+
+ Note: Not using NodeAnnos to avoid circular dependencies.
+ """
+
+ def visit_Name(self, node):
+ self.generic_visit(node)
+ anno.setanno(node, anno.Basic.QN, QN(node.id))
+ return node
+
+ def visit_Attribute(self, node):
+ self.generic_visit(node)
+ anno.setanno(node, anno.Basic.QN,
+ QN(anno.getanno(node.value, anno.Basic.QN), node.attr))
+ return node
+
+
+def resolve(node):
+ return QnResolver().visit(node)
py_library(
name = "static_analysis",
srcs = [
- "access.py",
+ "activity.py",
+ "annos.py",
"live_values.py",
"type_info.py",
],
)
py_test(
- name = "access_test",
- srcs = ["access_test.py"],
+ name = "activity_test",
+ srcs = ["activity_test.py"],
srcs_version = "PY2AND3",
deps = [
":static_analysis",
+++ /dev/null
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for access module."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import gast
-
-from tensorflow.contrib.py2tf.pyct import anno
-from tensorflow.contrib.py2tf.pyct import context
-from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.contrib.py2tf.pyct.static_analysis import access
-from tensorflow.python.platform import test
-
-
-class ScopeTest(test.TestCase):
-
- def test_basic(self):
- scope = access.Scope(None)
- self.assertFalse(scope.has('foo'))
-
- scope.mark_read('foo')
- self.assertFalse(scope.has('foo'))
-
- scope.mark_write('foo')
- self.assertTrue(scope.has('foo'))
-
- scope.mark_read('bar')
- self.assertFalse(scope.has('bar'))
-
- def test_copy(self):
- scope = access.Scope(None)
- scope.mark_write('foo')
-
- other = access.Scope(None)
- other.copy_from(scope)
-
- self.assertTrue('foo' in other.created)
-
- scope.mark_write('bar')
- scope.copy_from(other)
-
- self.assertFalse('bar' in scope.created)
-
- scope.mark_write('bar')
- scope.merge_from(other)
-
- self.assertTrue('bar' in scope.created)
- self.assertFalse('bar' in other.created)
-
- def test_nesting(self):
- scope = access.Scope(None)
- scope.mark_write('foo')
- scope.mark_read('bar')
-
- child = access.Scope(scope)
- self.assertTrue(child.has('foo'))
- self.assertTrue(scope.has('foo'))
-
- child.mark_write('bar')
- self.assertTrue(child.has('bar'))
- self.assertFalse(scope.has('bar'))
-
- def test_referenced(self):
- scope = access.Scope(None)
- scope.mark_read('a')
-
- child = access.Scope(scope)
- child.mark_read('b')
-
- child2 = access.Scope(child, isolated=False)
- child2.mark_read('c')
-
- self.assertTrue('c' in child2.referenced)
- self.assertTrue('b' in child2.referenced)
- self.assertFalse('a' in child2.referenced)
-
- self.assertTrue('c' in child.referenced)
- self.assertTrue('b' in child.referenced)
- self.assertFalse('a' in child.referenced)
-
-
-class AccessResolverTest(test.TestCase):
-
- def _parse_and_analyze(self, test_fn):
- node, source = parser.parse_entity(test_fn)
- ctx = context.EntityContext(
- namer=None,
- source_code=source,
- source_file=None,
- namespace={},
- arg_values=None,
- arg_types=None,
- recursive=True)
- node = access.resolve(node, ctx)
- return node
-
- def test_local_markers(self):
-
- def test_fn(a): # pylint:disable=unused-argument
- b = c # pylint:disable=undefined-variable
- while b > 0:
- b -= 1
- return b
-
- node = self._parse_and_analyze(test_fn)
- self.assertFalse(anno.getanno(node.body[0].body[0].value,
- 'is_local')) # c in b = c
- self.assertTrue(anno.getanno(node.body[0].body[1].test.left,
- 'is_local')) # b in b > 0
- self.assertTrue(anno.getanno(node.body[0].body[2].value,
- 'is_local')) # b in return b
-
- def assertScopeIs(self, scope, used, modified, created):
- self.assertItemsEqual(used, scope.used)
- self.assertItemsEqual(modified, scope.modified)
- self.assertItemsEqual(created, scope.created)
-
- def test_print_statement(self):
-
- def test_fn(a):
- b = 0
- c = 1
- print(a, b)
- return c
-
- node = self._parse_and_analyze(test_fn)
- print_node = node.body[0].body[2]
- if isinstance(print_node, gast.Print):
- # Python 2
- print_args_scope = anno.getanno(print_node, 'args_scope')
- else:
- # Python 3
- assert isinstance(print_node, gast.Expr)
- # The call node should be the one being annotated.
- print_node = print_node.value
- print_args_scope = anno.getanno(print_node, 'args_scope')
- # We basically need to detect which variables are captured by the call
- # arguments.
- self.assertScopeIs(print_args_scope, ('a', 'b'), (), ())
-
- def test_call(self):
-
- def test_fn(a):
- b = 0
- c = 1
- foo(a, b) # pylint:disable=undefined-variable
- return c
-
- node = self._parse_and_analyze(test_fn)
- call_node = node.body[0].body[2].value
- # We basically need to detect which variables are captured by the call
- # arguments.
- self.assertScopeIs(
- anno.getanno(call_node, 'args_scope'), ('a', 'b'), (), ())
-
- def test_while(self):
-
- def test_fn(a):
- b = a
- while b > 0:
- c = b
- b -= 1
- return b, c
-
- node = self._parse_and_analyze(test_fn)
- while_node = node.body[0].body[1]
- self.assertScopeIs(
- anno.getanno(while_node, 'body_scope'), ('b',), ('b', 'c'), ('c',))
- self.assertScopeIs(
- anno.getanno(while_node, 'body_parent_scope'), ('a', 'b', 'c'),
- ('b', 'c'), ('a', 'b', 'c'))
-
- def test_for(self):
-
- def test_fn(a):
- b = a
- for _ in a:
- c = b
- b -= 1
- return b, c
-
- node = self._parse_and_analyze(test_fn)
- for_node = node.body[0].body[1]
- self.assertScopeIs(
- anno.getanno(for_node, 'body_scope'), ('b',), ('b', 'c'), ('c',))
- self.assertScopeIs(
- anno.getanno(for_node, 'body_parent_scope'), ('a', 'b', 'c'),
- ('b', 'c', '_'), ('a', 'b', 'c', '_'))
-
- def test_if(self):
-
- def test_fn(x):
- if x > 0:
- x = -x
- y = 2 * x
- z = -y
- else:
- x = 2 * x
- y = -x
- u = -y
- return z, u
-
- node = self._parse_and_analyze(test_fn)
- if_node = node.body[0].body[0]
- self.assertScopeIs(
- anno.getanno(if_node, 'body_scope'), ('x', 'y'), ('x', 'y', 'z'),
- ('y', 'z'))
- # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
- self.assertScopeIs(
- anno.getanno(if_node, 'body_parent_scope'), ('x', 'z', 'u'),
- ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
- self.assertScopeIs(
- anno.getanno(if_node, 'orelse_scope'), ('x', 'y'), ('x', 'y', 'u'),
- ('y', 'u'))
- self.assertScopeIs(
- anno.getanno(if_node, 'body_parent_scope'), ('x', 'z', 'u'),
- ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
-
-
-if __name__ == '__main__':
- test.main()
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Access information (reads, writes) resolution."""
+"""Activity analysis."""
from __future__ import absolute_import
from __future__ import division
from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import transformer
+from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
# TODO(mdan): Add support for PY3 (e.g. Param vs arg).
self.params.add(name)
def mark_creation(self, name):
+ if name.is_composite():
+ parent = name.parent
+ if self.has(parent):
+ # This is considered mutation of the parent, not creation.
+ # TODO(mdan): Is that really so?
+ return
+ else:
+ raise ValueError('Unknown symbol "%s".' % parent)
self.created.add(name)
def mark_write(self, name):
self.parent.mark_returned(name)
-class AccessResolver(transformer.Base):
+class ActivityAnalizer(transformer.Base):
"""Annotates nodes with local scope information. See Scope."""
- def __init__(self, context):
- super(AccessResolver, self).__init__(context)
- self.scope = Scope(None)
+ def __init__(self, context, parent_scope):
+ super(ActivityAnalizer, self).__init__(context)
+ self.scope = Scope(parent_scope)
self._in_return_statement = False
- def visit_Name(self, node):
- # TODO(mdan): This is insufficient for object fields, e.g. hp.learning_rate.
- self.generic_visit(node)
+ def _track_symbol(self, node):
+ qn = anno.getanno(node, anno.Basic.QN)
+
if isinstance(node.ctx, gast.Store):
- self.scope.mark_write(node.id)
+ self.scope.mark_write(qn)
elif isinstance(node.ctx, gast.Load):
- anno.setanno(node, 'is_local', self.scope.has(node.id))
- self.scope.mark_read(node.id)
+ self.scope.mark_read(qn)
elif isinstance(node.ctx, gast.Param):
# Param contexts appear in function defs, so they have the meaning of
# defining a variable.
# TODO(mdan): This bay be incorrect with nested functions.
# For nested functions, we'll have to add the notion of hiding args from
# the parent scope, not writing to them.
- self.scope.mark_creation(node.id)
- self.scope.mark_param(node.id)
+ self.scope.mark_creation(qn)
+ self.scope.mark_param(qn)
else:
- raise ValueError('Unknown context %s for node %s.' % (type(node.ctx),
- node.id))
- anno.setanno(node, 'is_modified_since_entry',
- self.scope.is_modified_since_entry(node.id))
- anno.setanno(node, 'is_param', self.scope.is_param(node.id))
+ raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn))
+
+ anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn))
+ anno.setanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY,
+ self.scope.is_modified_since_entry(qn))
+ anno.setanno(node, NodeAnno.IS_PARAM, self.scope.is_param(qn))
if self._in_return_statement:
- self.scope.mark_returned(node.id)
+ self.scope.mark_returned(qn)
+
+ def visit_Name(self, node):
+ self.generic_visit(node)
+ self._track_symbol(node)
+ return node
+
+ def visit_Attribute(self, node):
+ self.generic_visit(node)
+ self._track_symbol(node)
return node
def visit_Print(self, node):
self.scope = args_scope
for n in node.values:
self.visit(n)
- anno.setanno(node, 'args_scope', args_scope)
+ anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope)
self.scope = current_scope
return node
# TODO(mdan): Account starargs, kwargs
for n in node.keywords:
self.visit(n)
- anno.setanno(node, 'args_scope', args_scope)
+ anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope)
self.scope = current_scope
self.visit(node.func)
return node
self.scope = block_scope
for n in block:
self.visit(n)
- anno.setanno(node, '%s_scope' % scope_name, block_scope)
+ anno.setanno(node, scope_name, block_scope)
self.scope = current_scope
return node
before_parent = Scope(None)
before_parent.copy_from(self.scope)
after_children = []
- for child, name in children:
+ for child, scope_name in children:
self.scope.copy_from(before_parent)
- parent = self._process_block_node(parent, child, name)
+ parent = self._process_block_node(parent, child, scope_name)
after_child = Scope(None)
after_child.copy_from(self.scope)
after_children.append(after_child)
for after_child in after_children:
self.scope.merge_from(after_child)
- for child, name in children:
- # TODO(mdan): We don't need this - we have the parent link from scope.
- anno.setanno(parent, '%s_parent_scope' % name, self.scope)
return parent
def visit_If(self, node):
self.visit(node.test)
- node = self._process_parallel_blocks(
- node, ((node.body, 'body'), (node.orelse, 'orelse')))
+ node = self._process_parallel_blocks(node,
+ ((node.body, NodeAnno.BODY_SCOPE),
+ (node.orelse, NodeAnno.ORELSE_SCOPE)))
return node
def visit_For(self, node):
self.visit(node.target)
self.visit(node.iter)
- node = self._process_parallel_blocks(
- node, ((node.body, 'body'), (node.orelse, 'orelse')))
+ node = self._process_parallel_blocks(node,
+ ((node.body, NodeAnno.BODY_SCOPE),
+ (node.orelse, NodeAnno.ORELSE_SCOPE)))
return node
def visit_While(self, node):
self.visit(node.test)
- node = self._process_parallel_blocks(
- node, ((node.body, 'body'), (node.orelse, 'orelse')))
+ node = self._process_parallel_blocks(node,
+ ((node.body, NodeAnno.BODY_SCOPE),
+ (node.orelse, NodeAnno.ORELSE_SCOPE)))
return node
def visit_Return(self, node):
return node
-def resolve(node, context):
- return AccessResolver(context).visit(node)
+def resolve(node, context, parent_scope=None):
+ return ActivityAnalizer(context, parent_scope).visit(node)
--- /dev/null
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for activity module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import context
+from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct import qual_names
+from tensorflow.contrib.py2tf.pyct.qual_names import QN
+from tensorflow.contrib.py2tf.pyct.static_analysis import activity
+from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.platform import test
+
+
+class ScopeTest(test.TestCase):
+
+ def test_basic(self):
+ scope = activity.Scope(None)
+ self.assertFalse(scope.has(QN('foo')))
+
+ scope.mark_read(QN('foo'))
+ self.assertFalse(scope.has(QN('foo')))
+
+ scope.mark_write(QN('foo'))
+ self.assertTrue(scope.has(QN('foo')))
+
+ scope.mark_read(QN('bar'))
+ self.assertFalse(scope.has(QN('bar')))
+
+ def test_copy(self):
+ scope = activity.Scope(None)
+ scope.mark_write(QN('foo'))
+
+ other = activity.Scope(None)
+ other.copy_from(scope)
+
+ self.assertTrue(QN('foo') in other.created)
+
+ scope.mark_write(QN('bar'))
+ scope.copy_from(other)
+
+ self.assertFalse(QN('bar') in scope.created)
+
+ scope.mark_write(QN('bar'))
+ scope.merge_from(other)
+
+ self.assertTrue(QN('bar') in scope.created)
+ self.assertFalse(QN('bar') in other.created)
+
+ def test_nesting(self):
+ scope = activity.Scope(None)
+ scope.mark_write(QN('foo'))
+ scope.mark_read(QN('bar'))
+
+ child = activity.Scope(scope)
+ self.assertTrue(child.has(QN('foo')))
+ self.assertTrue(scope.has(QN('foo')))
+
+ child.mark_write(QN('bar'))
+ self.assertTrue(child.has(QN('bar')))
+ self.assertFalse(scope.has(QN('bar')))
+
+ def test_referenced(self):
+ scope = activity.Scope(None)
+ scope.mark_read(QN('a'))
+
+ child = activity.Scope(scope)
+ child.mark_read(QN('b'))
+
+ child2 = activity.Scope(child, isolated=False)
+ child2.mark_read(QN('c'))
+
+ self.assertTrue(QN('c') in child2.referenced)
+ self.assertTrue(QN('b') in child2.referenced)
+ self.assertFalse(QN('a') in child2.referenced)
+
+ self.assertTrue(QN('c') in child.referenced)
+ self.assertTrue(QN('b') in child.referenced)
+ self.assertFalse(QN('a') in child.referenced)
+
+
+class ActivityAnalizerTest(test.TestCase):
+
+ def _parse_and_analyze(self, test_fn):
+ node, source = parser.parse_entity(test_fn)
+ ctx = context.EntityContext(
+ namer=None,
+ source_code=source,
+ source_file=None,
+ namespace={},
+ arg_values=None,
+ arg_types=None,
+ recursive=True)
+ node = qual_names.resolve(node)
+ node = activity.resolve(node, ctx)
+ return node
+
+ def test_local_markers(self):
+
+ def test_fn(a): # pylint:disable=unused-argument
+ b = c # pylint:disable=undefined-variable
+ while b > 0:
+ b -= 1
+ return b
+
+ node = self._parse_and_analyze(test_fn)
+ self.assertFalse(
+ anno.getanno(node.body[0].body[0].value,
+ NodeAnno.IS_LOCAL)) # c in b = c
+ self.assertTrue(
+ anno.getanno(node.body[0].body[1].test.left,
+ NodeAnno.IS_LOCAL)) # b in b > 0
+ self.assertTrue(
+ anno.getanno(node.body[0].body[2].value,
+ NodeAnno.IS_LOCAL)) # b in return b
+
+ def assertScopeIs(self, scope, used, modified, created):
+ self.assertItemsEqual(used, tuple(str(s) for s in scope.used))
+ self.assertItemsEqual(modified, tuple(str(s) for s in scope.modified))
+ self.assertItemsEqual(created, tuple(str(s) for s in scope.created))
+
+ def test_print_statement(self):
+
+ def test_fn(a):
+ b = 0
+ c = 1
+ print(a, b)
+ return c
+
+ node = self._parse_and_analyze(test_fn)
+ print_node = node.body[0].body[2]
+ if isinstance(print_node, gast.Print):
+ # Python 2
+ print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE)
+ else:
+ # Python 3
+ assert isinstance(print_node, gast.Expr)
+ # The call node should be the one being annotated.
+ print_node = print_node.value
+ print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE)
+ # We basically need to detect which variables are captured by the call
+ # arguments.
+ self.assertScopeIs(print_args_scope, ('a', 'b'), (), ())
+
+ def test_call(self):
+
+ def test_fn(a):
+ b = 0
+ c = 1
+ foo(a, b) # pylint:disable=undefined-variable
+ return c
+
+ node = self._parse_and_analyze(test_fn)
+ call_node = node.body[0].body[2].value
+ # We basically need to detect which variables are captured by the call
+ # arguments.
+ self.assertScopeIs(
+ anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'b'), (), ())
+
+ def test_while(self):
+
+ def test_fn(a):
+ b = a
+ while b > 0:
+ c = b
+ b -= 1
+ return b, c
+
+ node = self._parse_and_analyze(test_fn)
+ while_node = node.body[0].body[1]
+ self.assertScopeIs(
+ anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'),
+ ('c',))
+ self.assertScopeIs(
+ anno.getanno(while_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'),
+ ('b', 'c'), ('a', 'b', 'c'))
+
+ def test_for(self):
+
+ def test_fn(a):
+ b = a
+ for _ in a:
+ c = b
+ b -= 1
+ return b, c
+
+ node = self._parse_and_analyze(test_fn)
+ for_node = node.body[0].body[1]
+ self.assertScopeIs(
+ anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), ('c',))
+ self.assertScopeIs(
+ anno.getanno(for_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'),
+ ('b', 'c', '_'), ('a', 'b', 'c', '_'))
+
+ def test_if(self):
+
+ def test_fn(x):
+ if x > 0:
+ x = -x
+ y = 2 * x
+ z = -y
+ else:
+ x = 2 * x
+ y = -x
+ u = -y
+ return z, u
+
+ node = self._parse_and_analyze(test_fn)
+ if_node = node.body[0].body[0]
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z'),
+ ('y', 'z'))
+ # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, ('x', 'z', 'u'),
+ ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('x', 'y'),
+ ('x', 'y', 'u'), ('y', 'u'))
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, ('x', 'z', 'u'),
+ ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
+
+ def test_call_with_composite_names(self):
+
+ def foo(*_):
+ pass
+
+ def test_fn(a):
+ foo(a.b, a.c)
+ if a > 0:
+ a.b = 2
+ else:
+ d = 2
+ d.e = a.c
+ f = d.e + 1
+ a.c = f
+
+ node = self._parse_and_analyze(test_fn)
+ call_node = node.body[0].body[0].value
+ self.assertScopeIs(
+ anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'a.b', 'a.c'), (),
+ ())
+ if_node = node.body[0].body[1]
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a',), ('a.b',), ())
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE),
+ ('a', 'a.c', 'd', 'd.e', 'f'), ('a.c', 'd', 'd.e', 'f'), ('d', 'f'))
+
+
+if __name__ == '__main__':
+ test.main()
--- /dev/null
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Annotations used by the static analizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from enum import Enum
+
+
+class NoValue(Enum):
+
+ def __repr__(self):
+ return self.name
+
+
+class NodeAnno(NoValue):
+ """Additionnal annotations used by the static analyzer.
+
+ These are in addition to the basic annotations declared in anno.py.
+ """
+
+ # Symbols
+
+ IS_LOCAL = 'Symbol is local to the function scope being analized.'
+ IS_PARAM = 'Symbol is a parameter to the function being analized.'
+ IS_MODIFIED_SINCE_ENTRY = (
+ 'Symbol has been explicitly replaced in the current function scope.')
+
+ # Scopes
+ ARGS_SCOPE = 'The scope for the argument list of a function call.'
+ BODY_SCOPE = (
+ 'The scope for the main body of a statement (True branch for if '
+ 'statements, main body for loops).')
+ ORELSE_SCOPE = (
+ 'The scope for the orelse body of a statement (False branch for if '
+ 'statements, orelse body for loops).')
Live values are extracted from the known execution context.
-Requires annotations generated by AccessResolver.
+Requires activity analysis annotations.
"""
from __future__ import absolute_import
from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import transformer
+from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
class LiveValueResolver(transformer.Base):
def visit_Name(self, node):
self.generic_visit(node)
if isinstance(node.ctx, gast.Load):
- assert anno.hasanno(node, 'is_local'), node
- symbol_is_local = anno.getanno(node, 'is_local')
- assert anno.hasanno(node, 'is_modified_since_entry'), node
- symbol_is_modified = anno.getanno(node, 'is_modified_since_entry')
- assert anno.hasanno(node, 'is_param'), node
- symbol_is_param = anno.getanno(node, 'is_param')
+ assert anno.hasanno(node, NodeAnno.IS_LOCAL), node
+ symbol_is_local = anno.getanno(node, NodeAnno.IS_LOCAL)
+ assert anno.hasanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY), node
+ symbol_is_modified = anno.getanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY)
+ assert anno.hasanno(node, NodeAnno.IS_PARAM), node
+ symbol_is_param = anno.getanno(node, NodeAnno.IS_PARAM)
if not symbol_is_local and not symbol_is_param:
if node.id in self.literals:
anno.setanno(node, 'live_val', obj)
anno.setanno(node, 'fqn', (obj.__name__,))
else:
- raise ValueError('Could not resolve symbol "%s".' % node.id)
+ pass
+ # TODO(mdan): Should we raise an error here?
+ # Can encounter this when:
+ # * a symbol truly lacks reference
+ # * a symbol is new, like the new name of a function we just renamed.
else:
pass
# TODO(mdan): Attempt to trace its value through the local chain.
elif isinstance(node.value, gast.Name):
stem_name = node.value
# All nonlocal symbols should be fully resolved.
- assert anno.hasanno(stem_name, 'is_local'), stem_name
+ assert anno.hasanno(stem_name, NodeAnno.IS_LOCAL), stem_name
# TODO(mdan): Figure out what to do when calling attribute on local object
# Maybe just leave as-is?
return node
from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import context
from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.contrib.py2tf.pyct.static_analysis import access
+from tensorflow.contrib.py2tf.pyct import qual_names
+from tensorflow.contrib.py2tf.pyct.static_analysis import activity
from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
from tensorflow.python.framework import constant_op
arg_values=None,
arg_types=arg_types,
recursive=True)
- node = access.resolve(node, ctx)
+ node = qual_names.resolve(node)
+ node = activity.resolve(node, ctx)
node = live_values.resolve(node, ctx, literals)
node = type_info.resolve(node, ctx)
node = live_values.resolve(node, ctx, literals)
return node
def _process_function_arg(self, arg_name):
- if self.function_level == 1 and arg_name in self.context.arg_types:
+ str_name = str(arg_name)
+ if self.function_level == 1 and str_name in self.context.arg_types:
# Forge a node to hold the type information, so that method calls on
# it can resolve the type.
- type_holder = gast.Name(arg_name, gast.Load(), None)
- type_string, type_obj = self.context.arg_types[arg_name]
+ type_holder = arg_name.ast()
+ type_string, type_obj = self.context.arg_types[str_name]
anno.setanno(type_holder, 'type', type_obj)
anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.')))
self.scope.setval(arg_name, type_holder)
def visit_arg(self, node):
- self._process_function_arg(node.arg)
+ self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN))
return node
def visit_Name(self, node):
self.generic_visit(node)
+ qn = anno.getanno(node, anno.Basic.QN)
if isinstance(node.ctx, gast.Param):
- self._process_function_arg(node.id)
- elif isinstance(node.ctx, gast.Load) and self.scope.hasval(node.id):
+ self._process_function_arg(qn)
+ elif isinstance(node.ctx, gast.Load) and self.scope.hasval(qn):
# E.g. if we had
# a = b
# then for future references to `a` we should have traced_source = `b`
- traced_source = self.scope.getval(node.id)
+ traced_source = self.scope.getval(qn)
if anno.hasanno(traced_source, 'type'):
anno.setanno(node, 'type', anno.getanno(traced_source, 'type'))
anno.setanno(node, 'type_fqn', anno.getanno(traced_source, 'type_fqn'))
for t in targets:
if isinstance(t, gast.Tuple):
for i, e in enumerate(t.elts):
- self.scope.setval(e.id,
- gast.Subscript(
- source, gast.Index(i), ctx=gast.Store()))
- elif isinstance(t, gast.Name):
- self.scope.setval(t.id, source)
- elif isinstance(t, gast.Attribute):
- if not (isinstance(t.value, gast.Name) and t.value.id == 'self'):
- raise ValueError(
- 'Dont know how to handle assignment to attributes of objects'
- ' other than "self": [%s].%s' % (t.value, t.attr))
+ self.scope.setval(
+ anno.getanno(e, anno.Basic.QN),
+ gast.Subscript(source, gast.Index(i), ctx=gast.Store()))
+ elif isinstance(t, (gast.Name, gast.Attribute)):
+ self.scope.setval(anno.getanno(t, anno.Basic.QN), source)
else:
raise ValueError('Dont know how to handle assignment to %s' % t)
from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import context
from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.contrib.py2tf.pyct.static_analysis import access
+from tensorflow.contrib.py2tf.pyct import qual_names
+from tensorflow.contrib.py2tf.pyct.static_analysis import activity
from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
from tensorflow.python.client import session
arg_values=None,
arg_types=arg_types,
recursive=True)
- node = access.resolve(node, ctx)
+ node = qual_names.resolve(node)
+ node = activity.resolve(node, ctx)
node = live_values.resolve(node, ctx, {})
node = type_info.resolve(node, ctx)
node = live_values.resolve(node, ctx, {})
from __future__ import print_function
import ast
-import copy
import textwrap
import gast
+from tensorflow.contrib.py2tf.pyct import copier
from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct import qual_names
class ReplaceTransformer(gast.NodeTransformer):
that these placeholders will be replaced by.
"""
self.replacements = replacements
+ self.in_replacements = False
# TODO(mdan): Make a more detailed pass and clean up if needed.
node.name = repl.id
return node
- def visit_Name(self, node):
- if node.id in self.replacements:
- # TODO(mdan): Sanitize the nodes by erasing scope-dependent annotations.
- new_nodes = copy.copy(self.replacements[node.id])
- if isinstance(new_nodes, gast.AST):
- new_nodes = [new_nodes]
- # Preserve the target context.
- for n in new_nodes:
- if isinstance(n, gast.Tuple):
- for e in n.elts:
- e.ctx = node.ctx
- n.ctx = node.ctx
- if len(new_nodes) == 1:
- new_nodes, = new_nodes
- return new_nodes
+ def _set_inner_child_context(self, node, ctx):
+ if isinstance(node, gast.Attribute):
+ self._set_inner_child_context(node.value, ctx)
+ node.ctx = gast.Load()
+ elif isinstance(node, gast.Name):
+ node.ctx = ctx
else:
+ raise ValueError('unexpected node type "%s"' % node)
+
+ def visit_Name(self, node):
+ if node.id not in self.replacements:
return node
+ new_nodes = copier.copy_clean(self.replacements[node.id])
+ if isinstance(new_nodes, gast.AST):
+ new_nodes = [new_nodes]
+
+ # Preserve the target context.
+ for n in new_nodes:
+ if isinstance(n, gast.Tuple):
+ for e in n.elts:
+ self._set_inner_child_context(e, node.ctx)
+ if isinstance(n, gast.Attribute):
+ # For attributes, the inner Name node receives the context, while the
+ # outer ones have it set to Load.
+ self._set_inner_child_context(n, node.ctx)
+ else:
+ n.ctx = node.ctx
+
+ if len(new_nodes) == 1:
+ new_nodes, = new_nodes
+
+ return new_nodes
+
-def _strings_to_names(n):
+def _convert_to_ast(n):
+ """Convert from a known data type to AST."""
if isinstance(n, str):
# Note: the node will receive the ctx value from the template, see
# ReplaceTransformer.visit_Name.
return gast.Name(id=n, ctx=None, annotation=None)
+ if isinstance(n, qual_names.QN):
+ return n.ast()
if isinstance(n, list):
- return [_strings_to_names(e) for e in n]
+ return [_convert_to_ast(e) for e in n]
if isinstance(n, tuple):
- return tuple(_strings_to_names(e) for e in n)
+ return tuple(_convert_to_ast(e) for e in n)
return n
raise ValueError('Expected string template, got %s' % type(template))
tree = parser.parse_str(textwrap.dedent(template))
for k in replacements:
- replacements[k] = _strings_to_names(replacements[k])
- return ReplaceTransformer(replacements).visit(tree).body
+ replacements[k] = _convert_to_ast(replacements[k])
+ results = ReplaceTransformer(replacements).visit(tree).body
+ if isinstance(results, list):
+ return [qual_names.resolve(r) for r in results]
+ return qual_names.resolve(results)
class TemplatesTest(test.TestCase):
+ def test_replace_tuple(self):
+ template = """
+ def test_fn(a, c):
+ return b,
+ """
+
+ node = templates.replace(template, b=('a', 'c'))[0]
+ result = compiler.ast_to_object(node)
+ self.assertEquals((2, 3), result.test_fn(2, 3))
+
def test_replace_variable(self):
template = """
def test_fn(a):
import gast
import six
+from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import pretty_printer
self.context = context
def visit(self, node):
+ source_code = self.context.source_code
+ source_file = self.context.source_file
try:
- source_code = self.context.source_code
- source_file = self.context.source_file
if source_code and hasattr(node, 'lineno'):
self._lineno = node.lineno
self._col_offset = node.col_offset
+ if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
+ return node
return super(Base, self).visit(node)
- except (ValueError, AttributeError, NotImplementedError) as e:
- msg = '%s: %s\nOccurred at node:\n%s' % (e.__class__.__name__, str(e),
- pretty_printer.fmt(node))
+ except (ValueError, AttributeError, KeyError, NotImplementedError,
+ AssertionError) as e:
+ msg = '%s: %s\nOccurred at node:\n%s' % (
+ e.__class__.__name__, str(e), pretty_printer.fmt(node, color=False))
if source_code:
line = source_code.splitlines()[self._lineno - 1]
else: