"//tensorflow/contrib/py2tf/pyct/static_analysis",
"//tensorflow/contrib/py2tf/utils",
"@gast_archive//:gast",
+ "@six_archive//:six",
],
)
)
py_test(
- name = "call_trees_test",
- srcs = ["call_trees_test.py"],
+ name = "builtin_functions_test",
+ srcs = ["builtin_functions_test.py"],
srcs_version = "PY2AND3",
deps = [
":test_lib",
)
py_test(
- name = "decorators_test",
- srcs = ["decorators_test.py"],
+ name = "call_trees_test",
+ srcs = ["call_trees_test.py"],
srcs_version = "PY2AND3",
deps = [
":test_lib",
)
py_test(
- name = "builtin_functions_test",
- srcs = ["builtin_functions_test.py"],
+ name = "decorators_test",
+ srcs = ["decorators_test.py"],
srcs_version = "PY2AND3",
deps = [
":test_lib",
# Note: The lone tf.Assert call will be wrapped with control_dependencies
# by side_effect_guards.
template = """
- tf.Assert(test, [tf.constant(msg)])
+ tf.Assert(test, [msg])
"""
if node.msg is None:
from __future__ import print_function
from tensorflow.contrib.py2tf.converters import break_canonicalization
-from tensorflow.contrib.py2tf.converters import control_flow
from tensorflow.contrib.py2tf.converters import converter_test_base
-from tensorflow.contrib.py2tf.pyct import compiler
from tensorflow.python.platform import test
-class TestNamer(control_flow.SymbolNamer):
-
- def new_symbol(self, name_root, _):
- return name_root
-
-
class BreakCanonicalizationTest(converter_test_base.TestCase):
def test_basic_break(self):
v.append(x)
return v
- node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = self.parse_and_analyze(test_fn, {})
node = break_canonicalization.transform(node, self.ctx)
- result = compiler.ast_to_object(node)
- self.assertEqual(test_fn(0), result.test_fn(0))
- self.assertEqual(test_fn(1), result.test_fn(1))
- self.assertEqual(test_fn(2), result.test_fn(2))
- self.assertEqual(test_fn(3), result.test_fn(3))
- self.assertEqual(test_fn(4), result.test_fn(4))
+ with self.compiled(node) as result:
+ self.assertEqual(test_fn(0), result.test_fn(0))
+ self.assertEqual(test_fn(1), result.test_fn(1))
+ self.assertEqual(test_fn(2), result.test_fn(2))
+ self.assertEqual(test_fn(3), result.test_fn(3))
+ self.assertEqual(test_fn(4), result.test_fn(4))
def test_basic_break_for_loop(self):
v.append(x)
return v
- node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = self.parse_and_analyze(test_fn, {})
node = break_canonicalization.transform(node, self.ctx)
- result = compiler.ast_to_object(node)
- # The break is incompletely canonicalized. Everything is in place, but
- # the loop does not break.
- self.assertEqual(test_equiv_fn([]), result.test_fn([]))
- self.assertEqual(test_equiv_fn([1]), result.test_fn([1]))
- self.assertEqual(test_equiv_fn([2]), result.test_fn([2]))
- self.assertEqual(test_equiv_fn([1, 2, 3, 4]), result.test_fn([1, 2, 3, 4]))
+ with self.compiled(node) as result:
+ # The break is incompletely canonicalized. Everything is in place, but
+ # the loop does not break.
+ self.assertEqual(test_equiv_fn([]), result.test_fn([]))
+ self.assertEqual(test_equiv_fn([1]), result.test_fn([1]))
+ self.assertEqual(test_equiv_fn([2]), result.test_fn([2]))
+ self.assertEqual(
+ test_equiv_fn([1, 2, 3, 4]), result.test_fn([1, 2, 3, 4]))
def test_continue_deeply_nested(self):
v.append(x)
return v, u, w
- node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = self.parse_and_analyze(test_fn, {})
node = break_canonicalization.transform(node, self.ctx)
- result = compiler.ast_to_object(node)
- self.assertEqual(test_fn(0), result.test_fn(0))
- self.assertEqual(test_fn(1), result.test_fn(1))
- self.assertEqual(test_fn(2), result.test_fn(2))
- self.assertEqual(test_fn(3), result.test_fn(3))
- self.assertEqual(test_fn(4), result.test_fn(4))
+ with self.compiled(node) as result:
+ self.assertEqual(test_fn(0), result.test_fn(0))
+ self.assertEqual(test_fn(1), result.test_fn(1))
+ self.assertEqual(test_fn(2), result.test_fn(2))
+ self.assertEqual(test_fn(3), result.test_fn(3))
+ self.assertEqual(test_fn(4), result.test_fn(4))
if __name__ == '__main__':
class BuiltinFunctionTransformer(transformer.Base):
- """Transforms Print nodes to Call so they can be handled as functions."""
+ """Handles builtin functions and canonicalizes old-style print statement.
+
+ This transformer only covers functions that are translated into a
+ TF equivalent, like `len`.
+ Note that the `print` statement is converted to a function call here, but
+ wrapping the print function to a `py_func` is done by `call_trees` as a
+ generic uncompilable function wrap.
+ """
+
+ # TODO(mdan): Handle print entirely in here.
+ # Fully handling print here makes sense especially since we're considering
+ # using tf.Print instead.
def __init__(self, context):
super(BuiltinFunctionTransformer, self).__init__(context)
def visit_Print(self, node):
self.generic_visit(node)
+ args = node.values
+ # Following is the case when calling print(a, b)
+ if len(args) == 1 and isinstance(args[0], gast.Tuple):
+ args = args[0].elts
template = """
fname(args)
"""
- return templates.replace(template, fname='print', args=node.values)
+ return templates.replace(template, fname='print', args=args)
# pylint:enable=invalid-name
from __future__ import division
from __future__ import print_function
-import gast
+import sys
+
+import six
from tensorflow.contrib.py2tf.converters import builtin_functions
from tensorflow.contrib.py2tf.converters import converter_test_base
-from tensorflow.contrib.py2tf.pyct import compiler
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
node = self.parse_and_analyze(test_fn, {'len': len})
node = builtin_functions.transform(node, self.ctx)
- result = compiler.ast_to_object(node)
- setattr(result, 'tf', array_ops)
- with self.test_session() as sess:
- self.assertEqual(3,
- sess.run(
- result.test_fn(constant_op.constant([0, 0, 0]))))
+ with self.compiled(node, array_ops.shape) as result:
+ with self.test_session() as sess:
+ self.assertEqual(3,
+ sess.run(
+ result.test_fn(constant_op.constant([0, 0, 0]))))
def test_print(self):
node = self.parse_and_analyze(test_fn, {'print': print})
node = builtin_functions.transform(node, self.ctx)
- result = compiler.ast_to_object(node)
- result.test_fn('a')
- self.assertTrue(isinstance(node.body[0].body[0].value, gast.Call))
+ with self.compiled(node) as result:
+ try:
+ out_capturer = six.StringIO()
+ sys.stdout = out_capturer
+ result.test_fn('a')
+ self.assertEqual(out_capturer.getvalue(), 'a\n')
+ finally:
+ sys.stdout = sys.__stdout__
+
+ def test_print_tuple(self):
+
+ def test_fn(a, b, c):
+ print(a, b, c)
+
+ node = self.parse_and_analyze(test_fn, {'print': print})
+ node = builtin_functions.transform(node, self.ctx)
+
+ with self.compiled(node) as result:
+ try:
+ out_capturer = six.StringIO()
+ sys.stdout = out_capturer
+ result.test_fn('a', 1, [2, 3])
+ # It appears that the print output looks odd only under Python 2.
+ if six.PY2:
+ self.assertEqual(out_capturer.getvalue(), "('a', 1, [2, 3])\n")
+ else:
+ self.assertEqual(out_capturer.getvalue(), 'a 1 [2, 3]\n')
+ finally:
+ sys.stdout = sys.__stdout__
if __name__ == '__main__':
from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct import qual_names
from tensorflow.contrib.py2tf.pyct import templates
from tensorflow.contrib.py2tf.pyct import transformer
from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
return node
def _wrap_to_py_func_no_return(self, node):
+ func_qn = anno.getanno(node.func, anno.Basic.QN)
args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
+ wrapper_name = self.context.namer.new_symbol(func_qn.ssf(),
+ args_scope.referenced)
+ wrapper_args = []
+ for arg in node.args:
+ if anno.hasanno(arg, anno.Basic.QN):
+ arg_qn = anno.getanno(arg, anno.Basic.QN)
+ else:
+ arg_qn = qual_names.QN('arg')
+ wrapper_args.append(
+ self.context.namer.new_symbol(arg_qn.ssf(), args_scope.referenced))
# TODO(mdan): Properly handle varargs, kwargs, etc.
+ # TODO(mdan): This is best handled as a dynamic dispatch.
+ # That way we can separate tensors from non-tensor args.
template = """
- def wrapper(args):
- call(args)
+ def wrapper(wrapper_args):
+ call(wrapper_args)
return 1
- tf.py_func(wrapper, [args], [tf.int64])
+ tf.py_func(wrapper, original_args, [tf.int64])
"""
wrapper_def, call_expr = templates.replace(
template,
call=node.func,
- wrapper=self.context.namer.compiled_function_name(node.func.id)[0],
- args=tuple(args_scope.used))
- anno.setanno(call_expr.value, NodeAnno.ARGS_SCOPE, args_scope)
- # TODO(mdan): Rename this annotation to 'graph_ready'
+ wrapper=wrapper_name,
+ original_args=gast.List(elts=node.args, ctx=None),
+ wrapper_args=wrapper_args)
anno.setanno(wrapper_def, anno.Basic.SKIP_PROCESSING, True)
return (wrapper_def, call_expr)
from tensorflow.contrib.py2tf.converters import call_trees
from tensorflow.contrib.py2tf.converters import converter_test_base
-from tensorflow.contrib.py2tf.pyct import compiler
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class TestNamer(call_trees.FunctionNamer):
-
- def compiled_function_name(self,
- original_fqn,
- live_entity=None,
- owner_type=None):
- if owner_type is not None:
- return None, False
- return ('renamed_%s' % '_'.join(original_fqn)), True
-
-
class CallTreesTest(converter_test_base.TestCase):
def test_basic(self):
def test_fn_2(a):
return test_fn_1(a) + 1
- node = self.parse_and_analyze(
- test_fn_2, {'test_fn_1': test_fn_1}, namer=TestNamer())
+ node = self.parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1})
node = call_trees.transform(node, self.ctx, (), ())
- result = compiler.ast_to_object(node)
- # Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 manually.
- setattr(result, 'renamed_test_fn_1', renamed_test_fn_1)
- self.assertEquals(3, result.test_fn_2(1))
+ with self.compiled(node) as result:
+ # Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1
+ # manually.
+ result.renamed_test_fn_1 = renamed_test_fn_1
+ self.assertEquals(3, result.test_fn_2(1))
def test_simple_methods(self):
node = self.parse_and_analyze(
TestClass.test_fn_2, {'TestClass': TestClass},
- namer=TestNamer(),
arg_types={'self': (TestClass.__name__, TestClass)})
node = call_trees.transform(node, self.ctx, (), ())
- result = compiler.ast_to_object(node)
- tc = TestClass()
- self.assertEquals(3, result.test_fn_2(tc, 1))
+ with self.compiled(node) as result:
+ tc = TestClass()
+ self.assertEquals(3, result.test_fn_2(tc, 1))
def test_uncompiled_modules(self):
a = math_ops.add(a, constant_op.constant(1))
return a
- node = self.parse_and_analyze(
- test_fn, {
- 'math_ops': math_ops,
- 'constant_op': constant_op
- },
- namer=TestNamer())
+ node = self.parse_and_analyze(test_fn, {
+ 'math_ops': math_ops,
+ 'constant_op': constant_op
+ })
node = call_trees.transform(node, self.ctx,
set(((math_ops.__name__,),
(constant_op.__name__,))), ())
- result = compiler.ast_to_object(node)
- setattr(result, 'math_ops', math_ops)
- setattr(result, 'constant_op', constant_op)
-
- with self.test_session() as sess:
- # Not renamed, because the converter doesn't rename the definition itself.
- # (the caller is responsible for that).
- result_tensor = result.test_fn(constant_op.constant(1))
- result_val = sess.run(result_tensor)
- self.assertEquals(3, result_val)
+ with self.compiled(node) as result:
+ result.math_ops = math_ops
+ result.constant_op = constant_op
+ with self.test_session() as sess:
+ # Not renamed, because the converter doesn't rename the definition
+ # itself (the caller is responsible for that).
+ result_tensor = result.test_fn(constant_op.constant(1))
+ self.assertEquals(3, sess.run(result_tensor))
if __name__ == '__main__':
from __future__ import print_function
from tensorflow.contrib.py2tf.converters import continue_canonicalization
-from tensorflow.contrib.py2tf.converters import control_flow
from tensorflow.contrib.py2tf.converters import converter_test_base
-from tensorflow.contrib.py2tf.pyct import compiler
from tensorflow.python.platform import test
-class TestNamer(control_flow.SymbolNamer):
-
- def new_symbol(self, name_root, _):
- return name_root
-
-
class ContinueCanonicalizationTest(converter_test_base.TestCase):
def test_basic_continue(self):
v.append(x)
return v
- node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = self.parse_and_analyze(test_fn, {})
node = continue_canonicalization.transform(node, self.ctx)
- result = compiler.ast_to_object(node)
- self.assertEqual(test_fn(0), result.test_fn(0))
- self.assertEqual(test_fn(1), result.test_fn(1))
- self.assertEqual(test_fn(2), result.test_fn(2))
- self.assertEqual(test_fn(3), result.test_fn(3))
- self.assertEqual(test_fn(4), result.test_fn(4))
+ with self.compiled(node) as result:
+ self.assertEqual(test_fn(0), result.test_fn(0))
+ self.assertEqual(test_fn(1), result.test_fn(1))
+ self.assertEqual(test_fn(2), result.test_fn(2))
+ self.assertEqual(test_fn(3), result.test_fn(3))
+ self.assertEqual(test_fn(4), result.test_fn(4))
def test_basic_continue_for_loop(self):
v.append(x)
return v
- node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = self.parse_and_analyze(test_fn, {})
node = continue_canonicalization.transform(node, self.ctx)
- result = compiler.ast_to_object(node)
- self.assertEqual(test_fn([]), result.test_fn([]))
- self.assertEqual(test_fn([1]), result.test_fn([1]))
- self.assertEqual(test_fn([2]), result.test_fn([2]))
- self.assertEqual(test_fn([1, 2, 3]), result.test_fn([1, 2, 3]))
+ with self.compiled(node) as result:
+ self.assertEqual(test_fn([]), result.test_fn([]))
+ self.assertEqual(test_fn([1]), result.test_fn([1]))
+ self.assertEqual(test_fn([2]), result.test_fn([2]))
+ self.assertEqual(test_fn([1, 2, 3]), result.test_fn([1, 2, 3]))
def test_continue_deeply_nested(self):
v.append(x)
return v, u, w
- node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = self.parse_and_analyze(test_fn, {})
node = continue_canonicalization.transform(node, self.ctx)
- result = compiler.ast_to_object(node)
- self.assertEqual(test_fn(0), result.test_fn(0))
- self.assertEqual(test_fn(1), result.test_fn(1))
- self.assertEqual(test_fn(2), result.test_fn(2))
- self.assertEqual(test_fn(3), result.test_fn(3))
- self.assertEqual(test_fn(4), result.test_fn(4))
+ with self.compiled(node) as result:
+ self.assertEqual(test_fn(0), result.test_fn(0))
+ self.assertEqual(test_fn(1), result.test_fn(1))
+ self.assertEqual(test_fn(2), result.test_fn(2))
+ self.assertEqual(test_fn(3), result.test_fn(3))
+ self.assertEqual(test_fn(4), result.test_fn(4))
if __name__ == '__main__':
def visit_For(self, node):
assert False, 'for statement should have been canonicalized at this point'
+ def _create_cond_branch(self, body_name, aliased_orig_names,
+ aliased_new_names, body, returns):
+ if aliased_orig_names:
+ template = """
+ def body_name():
+ aliased_new_names, = aliased_orig_names,
+ body
+ return (returns,)
+ """
+ return templates.replace(
+ template,
+ body_name=body_name,
+ body=body,
+ aliased_orig_names=aliased_orig_names,
+ aliased_new_names=aliased_new_names,
+ returns=returns)
+ else:
+ template = """
+ def body_name():
+ body
+ return (returns,)
+ """
+ return templates.replace(
+ template, body_name=body_name, body=body, returns=returns)
+
+ def _create_cond_expr(self, results, test, body_name, orelse_name):
+ if results is not None:
+ template = """
+ results = py2tf_utils.run_cond(test, body_name, orelse_name)
+ """
+ return templates.replace(
+ template,
+ test=test,
+ results=results,
+ body_name=body_name,
+ orelse_name=orelse_name)
+ else:
+ template = """
+ py2tf_utils.run_cond(test, body_name, orelse_name)
+ """
+ return templates.replace(
+ template, test=test, body_name=body_name, orelse_name=orelse_name)
+
def visit_If(self, node):
self.generic_visit(node)
raise ValueError(
'The else branch creates new symbols that the if branch does not.')
- all_modified = tuple(body_scope.modified | orelse_scope.modified)
+ modified = tuple(body_scope.modified | orelse_scope.modified)
all_referenced = body_scope.referenced | orelse_scope.referenced
# Alias the closure variables inside the conditional functions
node_body = ast_util.rename_symbols(node.body, alias_map)
node_orelse = ast_util.rename_symbols(node.orelse, alias_map)
- if len(all_modified) == 1:
- results = all_modified[0]
+ if not modified:
+ # When the cond would return no value, we leave the cond called without
+ # results. That in turn should trigger the side effect guards. The
+ # branch functions will return a dummy value that ensures cond
+ # actually has some return value as well.
+ results = None
+ elif len(modified) == 1:
+ results = modified[0]
else:
- results = gast.Tuple([s.ast() for s in all_modified], None)
+ results = gast.Tuple([s.ast() for s in modified], None)
- if aliased_orig_names:
- template = """
- def body_name():
- aliased_new_names, = aliased_orig_names,
- body
- return (all_results,)
- def orelse_name():
- aliased_new_names, = aliased_orig_names,
- orelse
- return (all_results,)
- results = py2tf_utils.run_cond(test, body_name, orelse_name)
- """
- body_name = self.context.namer.new_symbol('if_true', all_referenced)
- return templates.replace(
- template,
- test=node.test,
- body_name=body_name,
- body=node_body,
- orelse_name=self.context.namer.new_symbol('if_false', all_referenced),
- orelse=node_orelse,
- aliased_orig_names=tuple(aliased_orig_names),
- aliased_new_names=tuple(aliased_new_names),
- all_results=tuple(alias_map[s] if s in aliased_orig_names else s
- for s in all_modified),
- results=results)
+ body_name = self.context.namer.new_symbol('if_true', all_referenced)
+ orelse_name = self.context.namer.new_symbol('if_false', all_referenced)
+ if modified:
+ body_returns = tuple(
+ alias_map[s] if s in aliased_orig_names else s for s in modified)
else:
- template = """
- def body_name():
- body
- return (all_results,)
- def orelse_name():
- orelse
- return (all_results,)
- results = py2tf_utils.run_cond(test, body_name, orelse_name)
- """
- body_name = self.context.namer.new_symbol('if_true', all_referenced)
- return templates.replace(
- template,
- test=node.test,
- body_name=body_name,
- body=node_body,
- orelse_name=self.context.namer.new_symbol('if_false', all_referenced),
- orelse=node_orelse,
- all_results=tuple(s for s in all_modified),
- results=results)
+ body_returns = templates.replace('tf.ones(())')[0].value
+
+ body_def = self._create_cond_branch(
+ body_name,
+ aliased_orig_names=tuple(aliased_orig_names),
+ aliased_new_names=tuple(aliased_new_names),
+ body=node_body,
+ returns=body_returns)
+ orelse_def = self._create_cond_branch(
+ orelse_name,
+ aliased_orig_names=tuple(aliased_orig_names),
+ aliased_new_names=tuple(aliased_new_names),
+ body=node_orelse,
+ returns=body_returns)
+ cond_expr = self._create_cond_expr(results, node.test, body_name,
+ orelse_name)
+
+ return body_def + orelse_def + cond_expr
def visit_While(self, node):
self.generic_visit(node)
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.py2tf import utils
from tensorflow.contrib.py2tf.converters import control_flow
from tensorflow.contrib.py2tf.converters import converter_test_base
-from tensorflow.contrib.py2tf.pyct import compiler
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
-class TestNamer(control_flow.SymbolNamer):
-
- def new_symbol(self, name_root, used):
- i = 0
- while True:
- name = '%s%d' % (name_root, i)
- if name not in used:
- return name
- i += 1
-
-
class ControlFlowTest(converter_test_base.TestCase):
def test_simple_while(self):
i += 1
return s, i, n
- node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = self.parse_and_analyze(test_fn, {})
node = control_flow.transform(node, self.ctx)
- result = compiler.ast_to_object(node)
- setattr(result, 'tf', control_flow_ops)
- setattr(result, 'py2tf_utils', utils)
- with self.test_session() as sess:
- self.assertEqual((10, 5, 5),
- sess.run(result.test_fn(constant_op.constant(5))))
+ with self.compiled(node, control_flow_ops.while_loop) as result:
+ with self.test_session() as sess:
+ self.assertEqual((10, 5, 5),
+ sess.run(result.test_fn(constant_op.constant(5))))
def test_while_single_var(self):
n -= 1
return n
- node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = self.parse_and_analyze(test_fn, {})
node = control_flow.transform(node, self.ctx)
- result = compiler.ast_to_object(node)
- setattr(result, 'tf', control_flow_ops)
- setattr(result, 'py2tf_utils', utils)
- with self.test_session() as sess:
- self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5))))
+ with self.compiled(node, control_flow_ops.while_loop) as result:
+ with self.test_session() as sess:
+ self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5))))
def test_simple_if(self):
b = 2 * n
return a, b
- node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = self.parse_and_analyze(test_fn, {})
node = control_flow.transform(node, self.ctx)
- result = compiler.ast_to_object(node)
- setattr(result, 'tf', control_flow_ops)
- setattr(result, 'py2tf_utils', utils)
- with self.test_session() as sess:
- self.assertEqual((-1, 0), sess.run(
- result.test_fn(constant_op.constant(1))))
- self.assertEqual((0, -2),
- sess.run(result.test_fn(constant_op.constant(-1))))
+ with self.compiled(node, control_flow_ops.cond) as result:
+ with self.test_session() as sess:
+ self.assertEqual((-1, 0),
+ sess.run(result.test_fn(constant_op.constant(1))))
+ self.assertEqual((0, -2),
+ sess.run(result.test_fn(constant_op.constant(-1))))
def test_if_single_var(self):
n = -n
return n
- node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = self.parse_and_analyze(test_fn, {})
node = control_flow.transform(node, self.ctx)
- result = compiler.ast_to_object(node)
- setattr(result, 'tf', control_flow_ops)
- setattr(result, 'py2tf_utils', utils)
- with self.test_session() as sess:
- self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1))))
+ with self.compiled(node, control_flow_ops.cond) as result:
+ with self.test_session() as sess:
+ self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1))))
if __name__ == '__main__':
from __future__ import division
from __future__ import print_function
+import contextlib
import imp
+from tensorflow.contrib.py2tf import utils
+from tensorflow.contrib.py2tf.pyct import compiler
from tensorflow.contrib.py2tf.pyct import context
from tensorflow.contrib.py2tf.pyct import parser
from tensorflow.contrib.py2tf.pyct import qual_names
from tensorflow.python.platform import test
+class FakeNamer(object):
+
+ def new_symbol(self, name_root, used):
+ i = 0
+ while True:
+ name = '%s%d' % (name_root, i)
+ if name not in used:
+ return name
+ i += 1
+
+ def compiled_function_name(self,
+ original_fqn,
+ live_entity=None,
+ owner_type=None):
+ del live_entity
+ if owner_type is not None:
+ return None, False
+ return ('renamed_%s' % '_'.join(original_fqn)), True
+
+
class TestCase(test.TestCase):
"""Base class for unit tests in this module. Contains relevant utilities."""
+ @contextlib.contextmanager
+ def compiled(self, node, *symbols):
+ source = '<compile failed>'
+ try:
+ result, source = compiler.ast_to_object(node)
+ result.tf = self.make_fake_tf(*symbols)
+ result.py2tf_utils = utils
+ yield result
+ except Exception: # pylint:disable=broad-except
+ print('Offending compiled code:\n%s' % source)
+ raise
+
def make_fake_tf(self, *symbols):
fake_tf = imp.new_module('fake_tf')
for s in symbols:
recursive=True):
node, source = parser.parse_entity(test_fn)
ctx = context.EntityContext(
- namer=namer,
+ namer=namer or FakeNamer(),
source_code=source,
source_file=None,
namespace=namespace,
node = node.body[0].body[0]
node = decorators.transform(node, remove_decorators=())
- result = compiler.ast_to_object(
+ # Since the decorator is not removed, we need to include its source
+ # code. We cannot do it after the fact because decorators are executed
+ # on load.
+ result, _ = compiler.ast_to_object(
node,
source_prefix=textwrap.dedent(tf_inspect.getsource(function_decorator)))
self.assertEqual(2, result.test_fn(1))
node = decorators.transform(node, remove_decorators=(function_decorator,))
- result = compiler.ast_to_object(node)
- self.assertEqual(1, result.test_fn(1))
+ with self.compiled(node) as result:
+ self.assertEqual(1, result.test_fn(1))
def test_simple_decorator(self):
node = node.body[0].body[0]
node = decorators.transform(node, remove_decorators=())
- result = compiler.ast_to_object(
+ # Since the decorator is not removed, we need to include its source
+ # code. We cannot do it after the fact because decorators are executed
+ # on load.
+ result, _ = compiler.ast_to_object(
node,
source_prefix=textwrap.dedent(tf_inspect.getsource(simple_decorator)))
self.assertEqual(2, result.test_fn(1))
node = decorators.transform(node, remove_decorators=(simple_decorator,))
- result = compiler.ast_to_object(node)
- self.assertEqual(1, result.test_fn(1))
+ with self.compiled(node) as result:
+ self.assertEqual(1, result.test_fn(1))
if __name__ == '__main__':
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.py2tf.converters import control_flow
from tensorflow.contrib.py2tf.converters import converter_test_base
from tensorflow.contrib.py2tf.converters import for_canonicalization
-from tensorflow.contrib.py2tf.pyct import compiler
from tensorflow.python.platform import test
-class TestNamer(control_flow.SymbolNamer):
-
- def new_symbol(self, name_root, _):
- return name_root
-
-
class ControlFlowTest(converter_test_base.TestCase):
def test_basic_for(self):
s += e
return s
- node = self.parse_and_analyze(test_fn, {}, namer=TestNamer())
+ node = self.parse_and_analyze(test_fn, {})
node = for_canonicalization.transform(node, self.ctx)
- result = compiler.ast_to_object(node)
- l = [1, 2, 3]
- self.assertEqual(test_fn(l), result.test_fn(l))
- l = []
- self.assertEqual(test_fn(l), result.test_fn(l))
+ with self.compiled(node) as result:
+ l = [1, 2, 3]
+ self.assertEqual(test_fn(l), result.test_fn(l))
+ l = []
+ self.assertEqual(test_fn(l), result.test_fn(l))
if __name__ == '__main__':
from tensorflow.contrib.py2tf.converters import converter_test_base
from tensorflow.contrib.py2tf.converters import logical_expressions
-from tensorflow.contrib.py2tf.pyct import compiler
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
node = self.parse_and_analyze(test_fn, {})
node = logical_expressions.transform(node)
- result = compiler.ast_to_object(node)
- setattr(result, 'tf', math_ops)
- with self.test_session() as sess:
- self.assertTrue(sess.run(result.test_fn(1, 1)))
- self.assertFalse(sess.run(result.test_fn(1, 2)))
+ with self.compiled(node, math_ops.equal) as result:
+ with self.test_session() as sess:
+ self.assertTrue(sess.run(result.test_fn(1, 1)))
+ self.assertFalse(sess.run(result.test_fn(1, 2)))
def test_bool_ops(self):
node = self.parse_and_analyze(test_fn, {})
node = logical_expressions.transform(node)
- result = compiler.ast_to_object(node)
- setattr(result, 'tf', math_ops)
- with self.test_session() as sess:
- self.assertTrue(sess.run(result.test_fn(True, False, True)))
+ with self.compiled(node, math_ops.logical_or,
+ math_ops.logical_and) as result:
+ with self.test_session() as sess:
+ self.assertTrue(sess.run(result.test_fn(True, False, True)))
if __name__ == '__main__':
# _visit_and_reindent.
args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
# NOTE: We can't guard object attributes because they may not be writable.
- guarded_args = tuple(
- s for s in args_scope.used if not s.is_composite())
+ # In addition, avoid renaming well-known names.
+ # TODO(mdan): Move these names into config.
+ unguarded_names = (qual_names.QN('self'), qual_names.QN('tf'))
+ guarded_args = tuple(s for s in args_scope.used
+ if not s.is_composite() and s not in unguarded_names)
# TODO(mdan): Include all arguments which depended on guarded_args too.
# For example, the following will still cause a race:
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.py2tf import utils
from tensorflow.contrib.py2tf.converters import converter_test_base
from tensorflow.contrib.py2tf.converters import side_effect_guards
-from tensorflow.contrib.py2tf.pyct import compiler
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-class TestNamer(side_effect_guards.SymbolNamer):
-
- def new_symbol(self, name_root, _):
- return 'renamed_%s' % name_root
-
-
class SideEffectGuardsTest(converter_test_base.TestCase):
- def _transform_and_compile(self, test_fn):
- ns = {
- 'control_flow_ops': control_flow_ops,
- 'constant_op': constant_op,
- 'gen_math_ops': gen_math_ops,
- 'ops': ops,
- 'state_ops': state_ops,
- }
- node = self.parse_and_analyze(
- test_fn, ns,
- namer=TestNamer())
- node = side_effect_guards.transform(node, self.ctx)
- result = compiler.ast_to_object(node)
- self.attach_namespace(result, **ns)
- result.tf = self.make_fake_tf(array_ops.identity, control_flow_ops.Assert,
- gen_math_ops.greater,
- ops.control_dependencies, ops.Tensor)
- result.py2tf_utils = utils
- return result.test_fn, node
-
def test_side_effect_on_return_only_variable(self):
+ tf = None
+
def test_fn(a):
- state_ops.assign(a, a + 1)
+ tf.assign(a, a + 1)
return a
- tf_test_fn, node = self._transform_and_compile(test_fn)
+ node = self.parse_and_analyze(test_fn, {})
+ node = side_effect_guards.transform(node, self.ctx)
- self.assertEqual(len(node.body[0].body), 1)
- with self.test_session() as sess:
- v = variables.Variable(2)
- sess.run(v.initializer)
- # NOTE: We don't expect the assignment to execute in this case, because
- # variables cannot be reliably guarded.
- self.assertEqual(2, sess.run(tf_test_fn(v)))
+ with self.compiled(node, state_ops.assign, ops.control_dependencies,
+ ops.Tensor) as result:
+ self.assertEqual(len(node.body[0].body), 1)
+ with self.test_session() as sess:
+ v = variables.Variable(2)
+ sess.run(v.initializer)
+ # NOTE: We don't expect the assignment to execute in this case, because
+ # variables cannot be reliably guarded.
+ self.assertEqual(2, sess.run(result.test_fn(v)))
def test_side_effect_on_used_variable(self):
+ tf = None
+
def test_fn(a):
- state_ops.assign(a, a + 1)
+ tf.assign(a, a + 1)
return a + 1
- tf_test_fn, node = self._transform_and_compile(test_fn)
+ node = self.parse_and_analyze(test_fn, {})
+ node = side_effect_guards.transform(node, self.ctx)
- self.assertEqual(len(node.body[0].body), 1)
- with self.test_session() as sess:
- v = variables.Variable(2)
- sess.run(v.initializer)
- # NOTE: Unlike test_side_effect_on_return_only_variable, the variable was
- # used in the local scope and so we could catch the assign's side effect.
- self.assertEqual(4, sess.run(tf_test_fn(v)))
+ with self.compiled(node, state_ops.assign, ops.control_dependencies,
+ ops.Tensor) as result:
+ self.assertEqual(len(node.body[0].body), 1)
+ with self.test_session() as sess:
+ v = variables.Variable(2)
+ sess.run(v.initializer)
+ # NOTE: Unlike test_side_effect_on_return_only_variable, the variable
+ # was used in the local scope and so we could catch the assign's side
+ # effect.
+ self.assertEqual(4, sess.run(result.test_fn(v)))
def test_side_effect_on_tensor(self):
+ tf = None
+
def test_fn(a):
- control_flow_ops.Assert(gen_math_ops.greater(a, 0), ['expected in throw'])
+ tf.Assert(a > 0, ['expected in throw'])
return a
- tf_test_fn, node = self._transform_and_compile(test_fn)
+ node = self.parse_and_analyze(test_fn, {})
+ node = side_effect_guards.transform(node, self.ctx)
- self.assertEqual(len(node.body[0].body), 1)
- with self.test_session() as sess:
- # NOTE: In this case we can also capture the side effect because the
- # argument is a tensor ans we can wrap it inside an identity.
- with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
- 'expected in throw'):
- sess.run(tf_test_fn(constant_op.constant(-1)))
+ with self.compiled(node, array_ops.identity, control_flow_ops.Assert,
+ ops.control_dependencies, ops.Tensor) as result:
+ self.assertEqual(len(node.body[0].body), 1)
+ with self.test_session() as sess:
+ # NOTE: In this case we can also capture the side effect because the
+ # argument is a tensor ans we can wrap it inside an identity.
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ 'expected in throw'):
+ sess.run(result.test_fn(constant_op.constant(-1)))
def test_multiline_block(self):
+ tf = None
+
def test_fn(a):
- state_ops.assign(a, a + 1)
+ tf.assign(a, a + 1)
b = a + 1
- state_ops.assign(a, b + 1)
+ tf.assign(a, b + 1)
c = b + 1
d = c + 1
return d
- tf_test_fn, node = self._transform_and_compile(test_fn)
+ node = self.parse_and_analyze(test_fn, {})
+ node = side_effect_guards.transform(node, self.ctx)
- self.assertEqual(len(node.body[0].body), 1)
- with self.test_session() as sess:
- v = variables.Variable(2)
- sess.run(v.initializer)
- self.assertEqual(6, sess.run(tf_test_fn(v)))
+ with self.compiled(node, array_ops.identity, state_ops.assign,
+ ops.control_dependencies, ops.Tensor) as result:
+ self.assertEqual(len(node.body[0].body), 1)
+ with self.test_session() as sess:
+ v = variables.Variable(2)
+ sess.run(v.initializer)
+ self.assertEqual(6, sess.run(result.test_fn(v)))
def test_multiline_nested_block(self):
+ tf = None
+
def test_fn(a):
- with ops.name_scope('foo'):
- state_ops.assign(a, a + 1)
+ with tf.name_scope('foo'):
+ tf.assign(a, a + 1)
b = a + 1
- # state_ops.assign(a, b + 1)
c = b + 1
d = c + 1
return d
- tf_test_fn, node = self._transform_and_compile(test_fn)
+ node = self.parse_and_analyze(test_fn, {})
+ node = side_effect_guards.transform(node, self.ctx)
- self.assertEqual(len(node.body[0].body[0].body), 1)
- with self.test_session() as sess:
- v = variables.Variable(2)
- sess.run(v.initializer)
- self.assertEqual(6, sess.run(tf_test_fn(v)))
+ with self.compiled(node, array_ops.identity, state_ops.assign,
+ ops.control_dependencies, ops.name_scope,
+ ops.Tensor) as result:
+ self.assertEqual(len(node.body[0].body[0].body), 1)
+ with self.test_session() as sess:
+ v = variables.Variable(2)
+ sess.run(v.initializer)
+ self.assertEqual(6, sess.run(result.test_fn(v)))
def test_multiline_block_unsafe(self):
+ tf = None
+
def test_fn(a):
- state_ops.assign(a, a + 1)
+ tf.assign(a, a + 1)
b = a + 1
- state_ops.assign(a, a + 1)
+ tf.assign(a, a + 1)
c = b + 1
d = c + 1
return d
- tf_test_fn, node = self._transform_and_compile(test_fn)
+ node = self.parse_and_analyze(test_fn, {})
+ node = side_effect_guards.transform(node, self.ctx)
- self.assertEqual(len(node.body[0].body), 1)
- with self.test_session() as sess:
- v = variables.Variable(2)
- sess.run(v.initializer)
- # NOTE: This intentionally highlights the flakiness. The test should be
- # tightened down once that is solved.
- self.assertTrue(sess.run(tf_test_fn(v)) in (6, 7))
+ with self.compiled(node, array_ops.identity, state_ops.assign,
+ ops.control_dependencies, ops.Tensor) as result:
+ self.assertEqual(len(node.body[0].body), 1)
+ with self.test_session() as sess:
+ v = variables.Variable(2)
+ sess.run(v.initializer)
+ # NOTE: This intentionally highlights the flakiness. The test should be
+ # tightened down once that is solved.
+ self.assertTrue(sess.run(result.test_fn(v)) in (6, 7))
if __name__ == '__main__':
from tensorflow.contrib.py2tf.impl import conversion
from tensorflow.contrib.py2tf.pyct import compiler
from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect
# TODO(mdan): Properly document the type hints.
return convert(arg_value_hints)(f)(*args, **kwargs)
-def convert(recursive=False, arg_types=None):
+def convert(recursive=False, verbose=False, arg_types=None):
"""Decorator that compiles a function to graph mode.
The decorator is dynamic - invoking compilation whenever the decorated
Args:
recursive: Whether to recusrively convert any functions that the decorator
function may call.
+ verbose: Whether to output the compiled code in the logs.
arg_types: See to_graph.
Returns:
wrapped = to_graph(
f,
recursive=recursive,
+ verbose=verbose,
arg_values=arg_values,
arg_types=arg_types,
partial_types=partial_types)
def to_graph(e,
recursive=True,
+ verbose=False,
arg_values=None,
arg_types=None,
partial_types=None):
e: A Python entity.
recursive: Whether to recusrively convert any functions that the decorator
function may call.
+ verbose: Whether to output the compiled code in the logs.
arg_values: A dict containing value hints for symbols like function
parameters.
arg_types: A dict containing type hints for symbols like function
module.body.append(parser.parse_str(import_line))
for dep in conversion_map.dependency_cache.values():
module.body.append(dep)
- compiled_node = compiler.ast_to_object(module)
+ compiled_node, compiled_src = compiler.ast_to_object(module)
# The compiled code should see everything the entry function saw.
# TODO(mdan): This might not work well if the call tree spans modules?
if tf_inspect.isfunction(e):
compiled_node.__dict__.update(six.get_function_globals(e))
-
compiled_fn = getattr(compiled_node, name)
+
+ if verbose:
+ logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)
+
return compiled_fn
else:
raise ValueError('Unexpected symbol type "%s"' % type(s))
+ pieces = name_root.split('_')
+ if pieces[-1].isdigit():
+ name_root = '_'.join(pieces[:-1])
+ n = int(pieces[-1])
+ else:
+ n = 0
new_name = name_root
- n = 0
+
while (new_name in self.global_namespace or
new_name in all_reserved_locals or new_name in self.generated_names):
n += 1
],
)
+py_test(
+ name = "qual_names_test",
+ srcs = ["qual_names_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":pyct",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
py_test(
name = "templates_test",
srcs = ["templates_test.py"],
f.write(source_prefix)
f.write('\n')
f.write(source)
- return imp.load_source(module_name, f.name)
+ return imp.load_source(module_name, f.name), source
targets=[gast.Name('a', gast.Store(), None)],
value=gast.Str('c'))
])
+
self.assertEqual(
textwrap.dedent("""
if 1:
decorator_list=[],
returns=None)
- mod = compiler.ast_to_object(node)
+ module, source = compiler.ast_to_object(node)
- self.assertEqual(2, mod.f(1))
- with open(mod.__file__, 'r') as temp_output:
+ expected_source = """
+ def f(a):
+ return a + 1
+ """
+ self.assertEqual(
+ textwrap.dedent(expected_source).strip(),
+ source.strip())
+ self.assertEqual(2, module.f(1))
+ with open(module.__file__, 'r') as temp_output:
self.assertEqual(
- textwrap.dedent("""
- def f(a):
- return a + 1
- """).strip(),
+ textwrap.dedent(expected_source).strip(),
temp_output.read().strip())
class QN(object):
- """Represents a qualified name.
-
- """
+ """Represents a qualified name."""
def __init__(self, base, attr=None):
if attr:
self._parent = base
self.qn = base.qn + (attr,)
else:
- self._parent = None
- self.qn = tuple(base.split('.'))
+ if isinstance(base, QN):
+ if base.is_composite():
+ self._parent = base.parent
+ else:
+ self._parent = None
+ self.qn = base.qn
+ else:
+ self._parent = None
+ self.qn = tuple(base.split('.'))
def is_composite(self):
return len(self.qn) > 1
--- /dev/null
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for qual_names module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import textwrap
+
+from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct import qual_names
+from tensorflow.python.platform import test
+
+
+class QNTest(test.TestCase):
+
+ def test_basic(self):
+ a = qual_names.QN('a')
+ self.assertEqual(a.qn, ('a',))
+ self.assertEqual(str(a), 'a')
+ self.assertEqual(a.ssf(), 'a')
+ self.assertEqual(a.ast().id, 'a')
+ self.assertFalse(a.is_composite())
+ with self.assertRaises(ValueError):
+ _ = a.parent
+
+ a_b = qual_names.QN(a, 'b')
+ self.assertEqual(a_b.qn, ('a', 'b'))
+ self.assertEqual(str(a_b), 'a.b')
+ self.assertEqual(a_b.ssf(), 'a_b')
+ self.assertEqual(a_b.ast().value.id, 'a')
+ self.assertEqual(a_b.ast().attr, 'b')
+ self.assertTrue(a_b.is_composite())
+ self.assertEqual(a_b.parent.qn, ('a',))
+
+ a2 = qual_names.QN(a)
+ self.assertEqual(a2.qn, ('a',))
+ with self.assertRaises(ValueError):
+ _ = a.parent
+
+ a_b2 = qual_names.QN(a_b)
+ self.assertEqual(a_b2.qn, ('a', 'b'))
+ self.assertEqual(a_b2.parent.qn, ('a',))
+
+ self.assertTrue(a2 == a)
+ self.assertFalse(a2 is a)
+
+ self.assertTrue(a_b.parent == a)
+ self.assertTrue(a_b2.parent == a)
+
+ self.assertTrue(a_b2 == a_b)
+ self.assertFalse(a_b2 is a_b)
+ self.assertFalse(a_b2 == a)
+
+ with self.assertRaises(ValueError):
+ qual_names.QN('a', 'b')
+
+ def test_hashable(self):
+ d = {qual_names.QN('a'): 'a', qual_names.QN('b'): 'b'}
+
+ self.assertEqual(d[qual_names.QN('a')], 'a')
+ self.assertEqual(d[qual_names.QN('b')], 'b')
+ self.assertTrue(qual_names.QN('c') not in d)
+
+
+class QNResolverTest(test.TestCase):
+
+ def assertQNStringIs(self, node, qn_str):
+ self.assertEqual(str(anno.getanno(node, anno.Basic.QN)), qn_str)
+
+ def test_resolve(self):
+ samples = """
+ a
+ a.b
+ (c, d.e)
+ [f, (g.h.i)]
+ j(k, l)
+ """
+ nodes = qual_names.resolve(parser.parse_str(textwrap.dedent(samples)))
+ nodes = tuple(n.value for n in nodes.body)
+
+ self.assertQNStringIs(nodes[0], 'a')
+ self.assertQNStringIs(nodes[1], 'a.b')
+ self.assertQNStringIs(nodes[2].elts[0], 'c')
+ self.assertQNStringIs(nodes[2].elts[1], 'd.e')
+ self.assertQNStringIs(nodes[3].elts[0], 'f')
+ self.assertQNStringIs(nodes[3].elts[1], 'g.h.i')
+ self.assertQNStringIs(nodes[4].func, 'j')
+ self.assertQNStringIs(nodes[4].args[0], 'k')
+ self.assertQNStringIs(nodes[4].args[1], 'l')
+
+
+if __name__ == '__main__':
+ test.main()
repl = self.replacements[node.name]
if not isinstance(repl, (gast.Name, ast.Name)):
raise ValueError(
- 'A function name can only be replaced by a Name node. Found: %s',
+ 'A function name can only be replaced by a Name node. Found: %s' %
repl)
node.name = repl.id
return node
node.ctx = gast.Load()
elif isinstance(node, gast.Name):
node.ctx = ctx
+ elif isinstance(node, (gast.Str, gast.Num)):
+ pass
else:
raise ValueError('unexpected node type "%s"' % node)
"""
node = templates.replace(template, b=('a', 'c'))[0]
- result = compiler.ast_to_object(node)
+ result, _ = compiler.ast_to_object(node)
+
self.assertEquals((2, 3), result.test_fn(2, 3))
def test_replace_variable(self):
"""
node = templates.replace(template, a='b')[0]
- result = compiler.ast_to_object(node)
+ result, _ = compiler.ast_to_object(node)
self.assertEquals(7, result.test_fn(2))
def test_replace_function_name(self):
"""
node = templates.replace(template, fname='test_fn')[0]
- result = compiler.ast_to_object(node)
+ result, _ = compiler.ast_to_object(node)
self.assertEquals(7, result.test_fn(2))
def test_code_block(self):
gast.Name('a', None, None)
], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))),
] * 2)[0]
- result = compiler.ast_to_object(node)
+ result, _ = compiler.ast_to_object(node)
self.assertEquals(3, result.test_fn(1))