[clang.py] Store reference to TranslationUnit in Cursor and Type
authorGregory Szorc <gregory.szorc@gmail.com>
Tue, 15 May 2012 19:51:02 +0000 (19:51 +0000)
committerGregory Szorc <gregory.szorc@gmail.com>
Tue, 15 May 2012 19:51:02 +0000 (19:51 +0000)
llvm-svn: 156846

clang/bindings/python/clang/cindex.py
clang/bindings/python/tests/cindex/test_cursor.py
clang/bindings/python/tests/cindex/test_type.py

index 54a3bfd..c599faa 100644 (file)
@@ -929,7 +929,12 @@ class Cursor(Structure):
 
     @staticmethod
     def from_location(tu, location):
-        return Cursor_get(tu, location)
+        # We store a reference to the TU in the instance so the TU won't get
+        # collected before the cursor.
+        cursor = Cursor_get(tu, location)
+        cursor._tu = tu
+
+        return cursor
 
     def __eq__(self, other):
         return Cursor_eq(self, other)
@@ -1127,6 +1132,13 @@ class Cursor(Structure):
 
         return self._lexical_parent
 
+    @property
+    def translation_unit(self):
+        """Returns the TranslationUnit to which this Cursor belongs."""
+        # If this triggers an AttributeError, the instance was not properly
+        # created.
+        return self._tu
+
     def get_children(self):
         """Return an iterator for accessing the children of this cursor."""
 
@@ -1135,6 +1147,9 @@ class Cursor(Structure):
             # FIXME: Document this assertion in API.
             # FIXME: There should just be an isNull method.
             assert child != Cursor_null()
+
+            # Create reference to TU so it isn't GC'd before Cursor.
+            child._tu = self._tu
             children.append(child)
             return 1 # continue
         children = []
@@ -1147,6 +1162,22 @@ class Cursor(Structure):
         # FIXME: There should just be an isNull method.
         if res == Cursor_null():
             return None
+
+        # Store a reference to the TU in the Python object so it won't get GC'd
+        # before the Cursor.
+        tu = None
+        for arg in args:
+            if isinstance(arg, TranslationUnit):
+                tu = arg
+                break
+
+            if hasattr(arg, 'translation_unit'):
+                tu = arg.translation_unit
+                break
+
+        assert tu is not None
+
+        res._tu = tu
         return res
 
 
@@ -1324,9 +1355,26 @@ class Type(Structure):
 
         return result
 
+    @property
+    def translation_unit(self):
+        """The TranslationUnit to which this Type is associated."""
+        # If this triggers an AttributeError, the instance was not properly
+        # instantiated.
+        return self._tu
+
     @staticmethod
     def from_result(res, fn, args):
         assert isinstance(res, Type)
+
+        tu = None
+        for arg in args:
+            if hasattr(arg, 'translation_unit'):
+                tu = arg.translation_unit
+                break
+
+        assert tu is not None
+        res._tu = tu
+
         return res
 
     def get_canonical(self):
index c80cf10..9a37d2f 100644 (file)
@@ -1,4 +1,7 @@
+import gc
+
 from clang.cindex import CursorKind
+from clang.cindex import TranslationUnit
 from clang.cindex import TypeKind
 from .util import get_cursor
 from .util import get_cursors
@@ -38,6 +41,8 @@ def test_get_children():
     tu_nodes = list(it)
 
     assert len(tu_nodes) == 3
+    for cursor in tu_nodes:
+        assert cursor.translation_unit is not None
 
     assert tu_nodes[0] != tu_nodes[1]
     assert tu_nodes[0].kind == CursorKind.STRUCT_DECL
@@ -47,6 +52,7 @@ def test_get_children():
     assert tu_nodes[0].location.line == 4
     assert tu_nodes[0].location.column == 8
     assert tu_nodes[0].hash > 0
+    assert tu_nodes[0].translation_unit is not None
 
     s0_nodes = list(tu_nodes[0].get_children())
     assert len(s0_nodes) == 2
@@ -67,6 +73,23 @@ def test_get_children():
     assert tu_nodes[2].displayname == 'f0(int, int)'
     assert tu_nodes[2].is_definition() == True
 
+def test_references():
+    """Ensure that references to TranslationUnit are kept."""
+    tu = get_tu('int x;')
+    cursors = list(tu.cursor.get_children())
+    assert len(cursors) > 0
+
+    cursor = cursors[0]
+    assert isinstance(cursor.translation_unit, TranslationUnit)
+
+    # Delete reference to TU and perform a full GC.
+    del tu
+    gc.collect()
+    assert isinstance(cursor.translation_unit, TranslationUnit)
+
+    # If the TU was destroyed, this should cause a segfault.
+    parent = cursor.semantic_parent
+
 def test_canonical():
     source = 'struct X; struct X; struct X { int member; };'
     tu = get_tu(source)
index 9b5a16e..28e4411 100644 (file)
@@ -1,4 +1,7 @@
+import gc
+
 from clang.cindex import CursorKind
+from clang.cindex import TranslationUnit
 from clang.cindex import TypeKind
 from nose.tools import raises
 from .util import get_cursor
@@ -28,6 +31,7 @@ def test_a_struct():
     assert teststruct is not None, "Could not find teststruct."
     fields = list(teststruct.get_children())
     assert all(x.kind == CursorKind.FIELD_DECL for x in fields)
+    assert all(x.translation_unit is not None for x in fields)
 
     assert fields[0].spelling == 'a'
     assert not fields[0].type.is_const_qualified()
@@ -72,6 +76,26 @@ def test_a_struct():
     assert fields[7].type.get_pointee().get_pointee().kind == TypeKind.POINTER
     assert fields[7].type.get_pointee().get_pointee().get_pointee().kind == TypeKind.INT
 
+def test_references():
+    """Ensure that a Type maintains a reference to a TranslationUnit."""
+
+    tu = get_tu('int x;')
+    children = list(tu.cursor.get_children())
+    assert len(children) > 0
+
+    cursor = children[0]
+    t = cursor.type
+
+    assert isinstance(t.translation_unit, TranslationUnit)
+
+    # Delete main TranslationUnit reference and force a GC.
+    del tu
+    gc.collect()
+    assert isinstance(t.translation_unit, TranslationUnit)
+
+    # If the TU was destroyed, this should cause a segfault.
+    decl = t.get_declaration()
+
 constarrayInput="""
 struct teststruct {
   void *A[2];