self.break_uses = []
def _create_break_check(self):
-
- def template(var_name):
- (not var_name) # pylint:disable=pointless-statement
-
- expr, = templates.replace(
- template, var_name=gast.Name(self.break_uses[-1][1], None, None))
+ template = """
+ (not var_name)
+ """
+ expr, = templates.replace(template, var_name=self.break_uses[-1][1])
return expr.value
def _create_break_trigger(self):
-
- def template(var_name): # pylint:disable=unused-argument
+ template = """
var_name = True
-
- block = templates.replace(
- template, var_name=gast.Name(self.break_uses[-1][1], None, None))
+ """
+ block = templates.replace(template, var_name=self.break_uses[-1][1])
block.append(gast.Continue())
return block
def _create_break_init(self):
-
- def template(var_name): # pylint:disable=unused-argument
+ template = """
var_name = False
-
- assign, = templates.replace(
- template, var_name=gast.Name(self.break_uses[-1][1], None, None))
+ """
+ assign, = templates.replace(template, var_name=self.break_uses[-1][1])
return assign
# TODO(mdan): Surely the transformer supports this better?
# TODO(mdan): Bring print_functions in here.
def _convert_len(self, node):
-
- def template(args):
- tf.shape(args)[0] # pylint:disable=undefined-variable,expression-not-assigned
-
+ template = """
+ tf.shape(args)[0]
+ """
new_call = templates.replace(template, args=node.args)[0].value
return new_call
else:
new_name = self.namer.compiled_function_name(
'__'.join(target_fqn), live_object=target_obj)
- node.func = gast.Name(id=new_name, ctx=gast.Load(), annotation=None)
+ node.func = gast.Name(new_name, gast.Load(), None)
return node
def _rename_member_function_of_known_type(self, node):
def _wrap_to_py_func_no_return(self, node):
args_scope = anno.getanno(node, 'args_scope')
# TODO(mdan): Properly handle varargs, kwargs, etc.
- args = tuple(gast.Name(n, gast.Load(), None) for n in args_scope.used)
-
- # pylint:disable=undefined-variable,unused-argument,function-redefined
-
- def template(call, wrapper, args):
-
+ template = """
def wrapper(args):
call(args)
return 1
-
tf.py_func(wrapper, [args], [tf.int64])
-
- # pylint:enable=undefined-variable,unused-argument,function-redefined
-
- wrapper_name = self.namer.compiled_function_name(node.func.id)
+ """
wrapper_def, call_expr = templates.replace(
template,
call=node.func,
- wrapper=gast.Name(wrapper_name, gast.Load(), None),
- args=args)
+ wrapper=self.namer.compiled_function_name(node.func.id),
+ args=tuple(gast.Name(n, gast.Load(), None) for n in args_scope.used))
anno.setanno(call_expr.value, 'args_scope', args_scope)
# TODO(mdan): Rename this annotation to 'graph_ready'
anno.setanno(wrapper_def, 'skip_processing', True)
self.continuation_uses = []
def _create_continuation_check(self):
-
- def template(var_name):
+ template = """
if not var_name:
pass
-
- cond, = templates.replace(
- template, var_name=gast.Name(self.continuation_uses[-1][1], None, None))
+ """
+ cond, = templates.replace(template, var_name=self.continuation_uses[-1][1])
cond.body = []
return cond
def _create_continuation_trigger(self):
-
- def template(var_name): # pylint:disable=unused-argument
+ template = """
var_name = True
-
+ """
assign, = templates.replace(
- template, var_name=gast.Name(self.continuation_uses[-1][1], None, None))
+ template, var_name=self.continuation_uses[-1][1])
return assign
def _create_continuation_init(self):
-
- def template(var_name): # pylint:disable=unused-argument
+ template = """
var_name = False
-
+ """
assign, = templates.replace(
- template, var_name=gast.Name(self.continuation_uses[-1][1], None, None))
+ template, var_name=self.continuation_uses[-1][1])
return assign
def _visit_and_reindent_if_necessary(self, nodes):
raise ValueError(
'The else branch creates new symbols that the if branch does not.')
- def template( # pylint:disable=missing-docstring
- test,
- body_name,
- body,
- orelse_name,
- orelse,
- aliased,
- aliases, # pylint:disable=unused-argument
- aliased_results,
- results): # pylint:disable=unused-argument
-
- def body_name(): # pylint:disable=function-redefined
- aliases, = aliased, # pylint:disable=unused-variable
- body # pylint:disable=pointless-statement
- return (aliased_results,)
-
- def orelse_name(): # pylint:disable=function-redefined
- aliases, = aliased, # pylint:disable=unused-variable
- orelse # pylint:disable=pointless-statement
- return (aliased_results,)
-
- results = tf.cond(test, body_name, orelse_name) # pylint:disable=undefined-variable
-
all_modified = tuple(body_scope.modified | orelse_scope.modified)
all_referenced = body_scope.referenced | orelse_scope.referenced
need_alias = (
(body_scope.modified | orelse_scope.modified) -
(body_scope.created | orelse_scope.created))
- aliased = tuple(need_alias)
- aliases = tuple(
- self.namer.new_symbol(s, all_referenced) for s in aliased)
- alias_map = dict(zip(aliased, aliases))
+ aliased_orig_names = tuple(need_alias)
+ aliased_new_names = tuple(
+ self.namer.new_symbol(s, all_referenced) for s in aliased_orig_names)
+ alias_map = dict(zip(aliased_orig_names, aliased_new_names))
node_body = node.body
node_body = [SymbolRenamer(alias_map).visit(n) for n in node_body]
node_orelse = node.orelse
results = gast.Tuple(
tuple(gast.Name(s, None, None) for s in all_modified), None)
+ template = """
+ def body_name():
+ aliased_new_names, = aliased_orig_names,
+ body
+ return (all_results,)
+ def orelse_name():
+ aliased_new_names, = aliased_orig_names,
+ orelse
+ return (all_results,)
+ results = tf.cond(test, body_name, orelse_name)
+ """
+ body_name = self.namer.new_symbol('if_true', all_referenced)
return templates.replace(
template,
test=node.test,
- body_name=gast.Name(
- self.namer.new_symbol('if_true', all_referenced), None, None),
+ body_name=body_name,
body=node_body,
- orelse_name=gast.Name(
- self.namer.new_symbol('if_false', all_referenced), None, None),
+ orelse_name=self.namer.new_symbol('if_false', all_referenced),
orelse=node_orelse,
- aliased=tuple(gast.Name(s, None, None) for s in aliased),
- aliases=tuple(gast.Name(s, None, None) for s in aliases),
- aliased_results=tuple(
- gast.Name(alias_map[s] if s in aliased else s, None, None)
- for s in all_modified),
+ aliased_orig_names=tuple(aliased_orig_names),
+ aliased_new_names=tuple(aliased_new_names),
+ all_results=tuple(alias_map[s] if s in aliased_orig_names else s
+ for s in all_modified),
results=results)
def visit_While(self, node):
body_scope = anno.getanno(node, 'body_scope')
body_closure = tuple(body_scope.modified - body_scope.created)
- def template(
- state, # pylint:disable=unused-argument
- state_ast_tuple, # pylint:disable=unused-argument
- test_name,
- test, # pylint:disable=unused-argument
- body_name,
- body):
-
- def test_name(state): # pylint:disable=function-redefined,unused-argument
- return test
-
- def body_name(state): # pylint:disable=function-redefined,unused-argument
- body # pylint:disable=pointless-statement
- return state,
-
- state_ast_tuple = tf.while_loop(test_name, body_name, [state]) # pylint:disable=undefined-variable
-
- test_name = self.namer.new_symbol('loop_test', body_scope.referenced)
- body_name = self.namer.new_symbol('loop_body', body_scope.referenced)
if len(body_closure) == 1:
- state = gast.Name(body_closure[0], None, None)
+ state = body_closure[0]
state_ast_tuple = state
else:
- state = tuple(gast.Name(n, None, None) for n in body_closure)
- state_ast_tuple = gast.Tuple(state, None)
+ state = tuple(body_closure)
+ state_ast_tuple = gast.Tuple(
+ tuple(gast.Name(n, None, None) for n in state), None)
+ template = """
+ def test_name(state):
+ return test
+ def body_name(state):
+ body
+ return state,
+ state_ast_tuple = tf.while_loop(test_name, body_name, [state])
+ """
node = templates.replace(
template,
state=state,
state_ast_tuple=state_ast_tuple,
- test_name=gast.Name(test_name, gast.Load(), None),
+ test_name=self.namer.new_symbol('loop_test', body_scope.referenced),
test=node.test,
- body_name=gast.Name(body_name, gast.Load(), None),
+ body_name=self.namer.new_symbol('loop_body', body_scope.referenced),
body=node.body)
return node
# Or maybe we should replace range with tf.range?
if anno.hasanno(node, 'extra_cond'):
-
- def template(loop_iter, target, body, i, n, extra_cond): # pylint:disable=unused-argument
+ template = """
i = 0
- n = len(loop_iter) # pylint:disable=undefined-variable
+ n = len(loop_iter)
while i < n and extra_cond:
# TODO(mdan): Use TensorListFromTensor(loop_iter) here.
target = loop_iter[i]
- body # pylint:disable=pointless-statement
+ body
i += 1
-
+ """
return templates.replace(
template,
loop_iter=node.iter,
target=node.target,
body=node.body,
- i=gast.Name(
- self.namer.new_symbol('i', body_scope.referenced), None, None),
- n=gast.Name(
- self.namer.new_symbol('n', body_scope.referenced), None, None),
+ i=self.namer.new_symbol('i', body_scope.referenced),
+ n=self.namer.new_symbol('n', body_scope.referenced),
extra_cond=anno.getanno(node, 'extra_cond'))
else:
-
- def template(loop_iter, target, body, i, n): # pylint:disable=unused-argument
+ template = """
i = 0
- n = len(loop_iter) # pylint:disable=undefined-variable
+ n = len(loop_iter)
while i < n:
# TODO(mdan): Use TensorListFromTensor(loop_iter) here.
target = loop_iter[i]
body # pylint:disable=pointless-statement
i += 1
-
+ """
return templates.replace(
template,
loop_iter=node.iter,
target=node.target,
body=node.body,
- i=gast.Name(
- self.namer.new_symbol('i', body_scope.referenced), None, None),
- n=gast.Name(
- self.namer.new_symbol('n', body_scope.referenced), None, None))
+ i=self.namer.new_symbol('i', body_scope.referenced),
+ n=self.namer.new_symbol('n', body_scope.referenced))
def visit_Continue(self, node):
assert False, 'continue statement should be desugared at this point'
return node
def _gate_symbols(self, guard_statement, guarded_args):
-
- def template(args): # pylint:disable=unused-argument
- (args,) = (tf.identity(a) for a in (args,)) # pylint:disable=undefined-variable
-
- guards = templates.replace(
- template, args=tuple(gast.Name(a, None, None) for a in guarded_args))
+ template = """
+ (args,) = (tf.identity(a) for a in (args,))
+ """
+ guards = templates.replace(template, args=tuple(guarded_args))
guard_statement.body.extend(guards)
return guard_statement
# opt.minimize(loss)
# or:
# tf.py_func(...)
-
args_scope = anno.getanno(node.value, 'args_scope')
temp_name = self.namer.new_symbol('temp', args_scope.parent.referenced)
# TODO(mdan): Unsafe reference modification!
args_scope.mark_write(temp_name)
-
- def template(call, temp_result):
+ template = """
temp_result = call
if temp_result is not None:
if not isinstance(temp_result, (list, tuple)):
temp_result = (temp_result,)
- ctx = tf.control_dependencies(temp_result) # pylint:disable=undefined-variable
+ ctx = tf.control_dependencies(temp_result)
else:
- ctx = contextmanager(lambda: (yield))() # pylint:disable=undefined-variable
+ ctx = contextmanager(lambda: (yield))()
with ctx:
# TODO(mdan): Also insert ops to re-fetch if variables are involved.
pass # Will be removed below.
-
- # TODO(mdan): This is brittle. Reorganize this mechanism.
+ """
+ # TODO(mdan): This is brittle. Reorganize the mechanism.
statements = templates.replace(
- template,
- call=node.value,
- temp_result=gast.Name(temp_name, None, None))
+ template, call=node.value, temp_result=temp_name)
control_deps_guard = statements[-1]
control_deps_guard.body = []
return node
+def _strings_to_names(n):
+ if isinstance(n, str):
+ # Note: the node will receive the ctx value from the template, see
+ # ReplaceTransformer.visit_Name.
+ return gast.Name(id=n, ctx=None, annotation=None)
+ if isinstance(n, list):
+ return [_strings_to_names(e) for e in n]
+ if isinstance(n, tuple):
+ return tuple(_strings_to_names(e) for e in n)
+ return n
+
+
def replace(template, **replacements):
"""Replace placeholders in a Python template.
+ AST Name and Tuple nodes always receive the context that inferred from
+ the template. However, when replacing more complex nodes (that can potentially
+ contain Name children), then the caller is responsible for setting the
+ appropriate context.
+
Args:
- template: A function to be used as a template. Any placeholder is expected
- to also be a function argument.
+ template: A string representing Python code. Any symbol name can be used
+ that appears in the template code can be used as placeholder.
**replacements: A mapping from placeholder names to (lists of) AST nodes
- that these placeholders will be replaced by.
+ that these placeholders will be replaced by. String values are also
+ supported as a shorthand for AST Name nodes with the respective ID.
Returns:
- body: An AST node or list of AST nodes with the replacements made. If the
- template was a function, a list will be returned. If the template was a
- node, the same node will be returned. If the template was a string, an
- AST node will be returned (a `Module` node in the case of a multi-line
- string, an `Expr` node otherwise).
+ An AST node or list of AST nodes with the replacements made. If the
+ template was a function, a list will be returned. If the template was a
+ node, the same node will be returned. If the template was a string, an
+ AST node will be returned (a `Module` node in the case of a multi-line
+ string, an `Expr` node otherwise).
Raises:
- ValueError: If a function is used as a template and an incorrect set of
- replacements was passed.
+ ValueError: if the arguments are incorrect.
"""
- tree = parser.parse_object(template).body[0]
- placeholders = set(arg.id for arg in tree.args.args)
- tree.args.args = []
- if tree.args.vararg:
- placeholders.add(tree.args.vararg)
- tree.args.vararg = None
- if set(replacements.keys()) != placeholders:
- raise ValueError(
- 'too many or few replacements. replacements: %s; placeholders: %s' %
- (replacements.keys(), placeholders))
-
- # Perform the replacement, stripping the function into which the template was
- # wrapped.
+ if not isinstance(template, str):
+ raise ValueError('Expected string template, got %s' % type(template))
+ tree = parser.parse_str(template)
+ for k in replacements:
+ replacements[k] = _strings_to_names(replacements[k])
return ReplaceTransformer(replacements).visit(tree).body
class TemplatesTest(test.TestCase):
def test_replace_variable(self):
- def template(a): # pylint:disable=unused-argument
- def test_fn(a): # pylint:disable=unused-variable
+ template = """
+ def test_fn(a):
a += 1
a = 2 * a + 1
- return b # pylint:disable=undefined-variable
+ return b
+ """
- node = templates.replace(
- template, a=gast.Name('b', gast.Load(), None))[0]
+ node = templates.replace(template, a='b')[0]
result = compiler.ast_to_object(node)
self.assertEquals(7, result.test_fn(2))
def test_replace_function_name(self):
- def template(fname): # pylint:disable=unused-argument
- def fname(a): # pylint:disable=function-redefined
+ template = """
+ def fname(a):
a += 1
a = 2 * a + 1
return a
+ """
- node = templates.replace(
- template, fname=gast.Name('test_fn', gast.Load(), None))[0]
+ node = templates.replace(template, fname='test_fn')[0]
result = compiler.ast_to_object(node)
self.assertEquals(7, result.test_fn(2))
def test_code_block(self):
- def template(block): # pylint:disable=unused-argument
- def test_fn(a): # pylint:disable=unused-variable
- block # pylint:disable=pointless-statement
+ template = """
+ def test_fn(a):
+ block
return a
+ """
node = templates.replace(
template,
block=[
- gast.Assign(
- [
- gast.Name('a', gast.Store(), None)
- ],
- gast.BinOp(
- gast.Name('a', gast.Load(), None),
- gast.Add(),
- gast.Num(1))),
+ gast.Assign([
+ gast.Name('a', None, None)
+ ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))),
] * 2)[0]
result = compiler.ast_to_object(node)
self.assertEquals(3, result.test_fn(1))