else:
raise ValueError('unexpected node type "%s"' % node)
+ def visit_Attribute(self, node):
+ node = self.generic_visit(node)
+ if node.attr not in self.replacements:
+ return node
+ repl = self.replacements[node.attr]
+ if not isinstance(repl, gast.Name):
+ raise ValueError(
+ 'An attribute can only be replaced by a Name node. Found: %s' % repl)
+ node.attr = repl.id
+ return node
+
def visit_Name(self, node):
if node.id not in self.replacements:
return node
from __future__ import division
from __future__ import print_function
+import imp
+
import gast
from tensorflow.contrib.py2tf.pyct import compiler
result, _ = compiler.ast_to_object(node)
self.assertEquals(7, result.test_fn(2))
- def test_code_block(self):
+ def test_replace_code_block(self):
template = """
def test_fn(a):
block
result, _ = compiler.ast_to_object(node)
self.assertEquals(3, result.test_fn(1))
+ def test_replace_attribute(self):
+ template = """
+ def test_fn(a):
+ return a.foo
+ """
+
+ node = templates.replace(template, foo='b')[0]
+ result, _ = compiler.ast_to_object(node)
+ mod = imp.new_module('test')
+ mod.b = 3
+ self.assertEquals(3, result.test_fn(mod))
+
+ with self.assertRaises(ValueError):
+ templates.replace(template, foo=1)
+
if __name__ == '__main__':
test.main()