Allow replacing attributes in templates.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 2 Mar 2018 01:58:07 +0000 (17:58 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 2 Mar 2018 02:02:28 +0000 (18:02 -0800)
PiperOrigin-RevId: 187562864

tensorflow/contrib/py2tf/pyct/templates.py
tensorflow/contrib/py2tf/pyct/templates_test.py

index 6ee6c0c5ceb70d87779ee313670135cadc5214b5..7021e2ba93743deb5ba6fecfe88428600b9489db 100644 (file)
@@ -79,6 +79,17 @@ class ReplaceTransformer(gast.NodeTransformer):
     else:
       raise ValueError('unexpected node type "%s"' % node)
 
+  def visit_Attribute(self, node):
+    node = self.generic_visit(node)
+    if node.attr not in self.replacements:
+      return node
+    repl = self.replacements[node.attr]
+    if not isinstance(repl, gast.Name):
+      raise ValueError(
+          'An attribute can only be replaced by a Name node. Found: %s' % repl)
+    node.attr = repl.id
+    return node
+
   def visit_Name(self, node):
     if node.id not in self.replacements:
       return node
index 8ccfde8573724741b0bbe4eacb3c54beb381ee7e..0d1c1c5d9ecf3fb9d7956f35bfce736389c0ec57 100644 (file)
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import imp
+
 import gast
 
 from tensorflow.contrib.py2tf.pyct import compiler
@@ -62,7 +64,7 @@ class TemplatesTest(test.TestCase):
     result, _ = compiler.ast_to_object(node)
     self.assertEquals(7, result.test_fn(2))
 
-  def test_code_block(self):
+  def test_replace_code_block(self):
     template = """
       def test_fn(a):
         block
@@ -79,6 +81,21 @@ class TemplatesTest(test.TestCase):
     result, _ = compiler.ast_to_object(node)
     self.assertEquals(3, result.test_fn(1))
 
+  def test_replace_attribute(self):
+    template = """
+      def test_fn(a):
+        return a.foo
+    """
+
+    node = templates.replace(template, foo='b')[0]
+    result, _ = compiler.ast_to_object(node)
+    mod = imp.new_module('test')
+    mod.b = 3
+    self.assertEquals(3, result.test_fn(mod))
+
+    with self.assertRaises(ValueError):
+      templates.replace(template, foo=1)
+
 
 if __name__ == '__main__':
   test.main()