if not annotations:
delattr(node, field_name)
node._fields = tuple(f for f in node._fields if f != field_name)
+
+
+def copyanno(from_node, to_node, key, field_name='___pyct_anno'):
+ if hasanno(from_node, key, field_name):
+ setanno(to_node, key, getanno(from_node, key, field_name), field_name)
from tensorflow.python.platform import test
+# TODO(mdan): Consider strong types instead of primitives.
+
+
class AnnoTest(test.TestCase):
def test_basic(self):
with self.assertRaises(AttributeError):
anno.getanno(node, 'foo')
+ def test_copyanno(self):
+ node_1 = ast.Name()
+ anno.setanno(node_1, 'foo', 3)
+
+ node_2 = ast.Name()
+ anno.copyanno(node_1, node_2, 'foo')
+ anno.copyanno(node_1, node_2, 'bar')
+
+ self.assertTrue(anno.hasanno(node_2, 'foo'))
+ self.assertFalse(anno.hasanno(node_2, 'bar'))
+
if __name__ == '__main__':
test.main()