Add a helper function to copy annotations between nodes.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 6 Mar 2018 19:00:46 +0000 (11:00 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 6 Mar 2018 19:05:00 +0000 (11:05 -0800)
PiperOrigin-RevId: 188047677

tensorflow/contrib/py2tf/pyct/anno.py
tensorflow/contrib/py2tf/pyct/anno_test.py

index 7a0528b..cc4a7ed 100644 (file)
@@ -70,3 +70,8 @@ def delanno(node, key, field_name='___pyct_anno'):
   if not annotations:
     delattr(node, field_name)
     node._fields = tuple(f for f in node._fields if f != field_name)
+
+
+def copyanno(from_node, to_node, key, field_name='___pyct_anno'):
+  if hasanno(from_node, key, field_name):
+    setanno(to_node, key, getanno(from_node, key, field_name), field_name)
index ff40bfe..6c29918 100644 (file)
@@ -24,6 +24,9 @@ from tensorflow.contrib.py2tf.pyct import anno
 from tensorflow.python.platform import test
 
 
+# TODO(mdan): Consider strong types instead of primitives.
+
+
 class AnnoTest(test.TestCase):
 
   def test_basic(self):
@@ -42,6 +45,17 @@ class AnnoTest(test.TestCase):
     with self.assertRaises(AttributeError):
       anno.getanno(node, 'foo')
 
+  def test_copyanno(self):
+    node_1 = ast.Name()
+    anno.setanno(node_1, 'foo', 3)
+
+    node_2 = ast.Name()
+    anno.copyanno(node_1, node_2, 'foo')
+    anno.copyanno(node_1, node_2, 'bar')
+
+    self.assertTrue(anno.hasanno(node_2, 'foo'))
+    self.assertFalse(anno.hasanno(node_2, 'bar'))
+
 
 if __name__ == '__main__':
   test.main()