move object finalisation code from tp_dealloc() into tp_finalize() in Py3.4
authorStefan Behnel <stefan_ml@behnel.de>
Sun, 4 Aug 2013 12:58:05 +0000 (14:58 +0200)
committerStefan Behnel <stefan_ml@behnel.de>
Sun, 4 Aug 2013 12:58:05 +0000 (14:58 +0200)
CHANGES.rst
Cython/Compiler/ModuleNode.py
Cython/Compiler/Symtab.py
Cython/Compiler/TypeSlots.py
Cython/Utility/ExtensionTypes.c

index 612a9c5..002379a 100644 (file)
@@ -8,6 +8,10 @@ Cython Changelog
 Features added
 --------------
 
+* Starting with CPython 3.4, the user provided finalisation code in the
+  ``__dealloc__()`` special method is called by ``tp_finalize()`` instead
+  of ``tp_dealloc()`` to provide a safer execution environment.
+
 * During cyclic garbage collection, attributes of extension types that
   cannot create reference cycles due to their type (e.g. strings) are
   no longer considered for traversal or clearing.
index 6f2a114..6fe933b 100644 (file)
@@ -1006,6 +1006,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
                 if scope: # could be None if there was an error
                     self.generate_exttype_vtable(scope, code)
                     self.generate_new_function(scope, code, entry)
+                    self.generate_finalize_function(scope, code)
                     self.generate_dealloc_function(scope, code)
                     if scope.needs_gc():
                         self.generate_traverse_function(scope, code, entry)
@@ -1172,32 +1173,39 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         slot_func = scope.mangle_internal("tp_dealloc")
         base_type = scope.parent_type.base_type
         if tp_slot.slot_code(scope) != slot_func:
-            return # never used
+            return  # never used
+
+        cpp_class_attrs = [entry for entry in scope.var_entries
+                           if entry.type.is_cpp_class]
+        needs_finalisation = scope.needs_finalisation()
+        needs_gc = scope.needs_gc()
 
         slot_func_cname = scope.mangle_internal("tp_dealloc")
         code.putln("")
         code.putln(
             "static void %s(PyObject *o) {" % slot_func_cname)
-
-        weakref_slot = scope.lookup_here("__weakref__")
-        _, (py_attrs, _, memoryview_slices) = scope.get_refcounted_entries()
-        cpp_class_attrs = [entry for entry in scope.var_entries if entry.type.is_cpp_class]
-
-        if (py_attrs
-            or cpp_class_attrs
-            or memoryview_slices
-            or weakref_slot in scope.var_entries):
+        if cpp_class_attrs:
             self.generate_self_cast(scope, code)
-        
-        # We must mark ths object as (gc) untracked while tearing it down, lest
-        # the garbage collection is invoked while running this destructor.
-        if scope.needs_gc():
+
+        if needs_gc and not base_type:
             code.putln("PyObject_GC_UnTrack(o);")
 
-        # call the user's __dealloc__
-        self.generate_usr_dealloc_call(scope, code)
-        if weakref_slot in scope.var_entries:
-            code.putln("if (p->__weakref__) PyObject_ClearWeakRefs(o);")
+        if needs_finalisation:
+            # before Py3.4, tp_finalize() isn't available, so this is
+            # the earliest possible time where we can call it ourselves
+            code.putln("#if PY_VERSION_HEX < 0x030400a1")
+            if needs_gc and base_type:
+                # We must mark ths object as (gc) untracked while tearing
+                # it down, lest the garbage collection is invoked while
+                # running this destructor.
+                code.putln("PyObject_GC_UnTrack(o);")
+            slot_func_cname = scope.mangle_internal("tp_finalize")
+            code.putln("%s(o);" % slot_func_cname)
+            if needs_gc and base_type:
+                # The base class deallocator probably expects this to be
+                # tracked, so undo the untracking above.
+                code.putln("PyObject_GC_Track(o);")
+            code.putln("#endif")
 
         for entry in cpp_class_attrs:
             split_cname = entry.type.cname.split('::')
@@ -1206,23 +1214,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
             while destructor_name.count('<') != destructor_name.count('>'):
                 destructor_name = split_cname.pop() + '::' + destructor_name
             destructor_name = destructor_name.split('<',1)[0]
-            code.putln("p->%s.%s::~%s();" %
-                (entry.cname, entry.type.declaration_code(""), destructor_name))
-
-        for entry in py_attrs:
-            code.put_xdecref_clear("p->%s" % entry.cname, entry.type, nanny=False,
-                                   clear_before_decref=True)
-
-        for entry in memoryview_slices:
-            code.put_xdecref_memoryviewslice("p->%s" % entry.cname,
-                                             have_gil=True)
+            code.putln("p->%s.%s::~%s();" % (
+                entry.cname, entry.type.declaration_code(""), destructor_name))
 
         if base_type:
-            # The base class deallocator probably expects this to be tracked, so
-            # undo the untracking above.
-            if scope.needs_gc():
-                code.putln("PyObject_GC_Track(o);")
-
             tp_dealloc = TypeSlots.get_base_slot_function(scope, tp_slot)
             if tp_dealloc is not None:
                 code.putln("%s(o);" % tp_dealloc)
@@ -1234,7 +1229,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
                 # the module cleanup, which may already have cleared it.
                 # In that case, fall back to traversing the type hierarchy.
                 base_cname = base_type.typeptr_cname
-                code.putln("if (likely(%s)) %s->tp_dealloc(o); else __Pyx_call_next_tp_dealloc(o, %s);" % (
+                code.putln("if (likely(%s)) %s->tp_dealloc(o); "
+                           "else __Pyx_call_next_tp_dealloc(o, %s);" % (
                     base_cname, base_cname, slot_func_cname))
                 code.globalstate.use_utility_code(
                     UtilityCode.load_cached("CallNextTpDealloc", "ExtensionTypes.c"))
@@ -1256,28 +1252,86 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         code.putln(
             "}")
 
-    def generate_usr_dealloc_call(self, scope, code):
+    def generate_finalize_function(self, scope, code):
+        if not scope.needs_finalisation():
+            return
+
         entry = scope.lookup_here("__dealloc__")
         if entry:
-            code.putln(
-                "{")
-            code.putln(
-                    "PyObject *etype, *eval, *etb;")
-            code.putln(
-                    "PyErr_Fetch(&etype, &eval, &etb);")
-            code.putln(
-                    "++Py_REFCNT(o);")
-            code.putln(
-                    "%s(o);" %
-                        entry.func_cname)
-            code.putln(
-                    "if (PyErr_Occurred()) PyErr_WriteUnraisable(o);")
-            code.putln(
-                    "--Py_REFCNT(o);")
-            code.putln(
-                    "PyErr_Restore(etype, eval, etb);")
-            code.putln(
-                "}")
+            dealloc_user_cfunc = entry.func_cname
+        else:
+            dealloc_user_cfunc = None
+
+        weakref_slot = scope.lookup_here("__weakref__")
+        if weakref_slot not in scope.var_entries:
+            weakref_slot = None
+        _, (py_attrs, _, memoryview_slices) = scope.get_refcounted_entries()
+
+        slot_func_cname = scope.mangle_internal("tp_finalize")
+        code.putln("")
+        code.putln("static void %s(PyObject *o) {" % slot_func_cname)
+
+        if py_attrs or memoryview_slices or weakref_slot:
+            self.generate_self_cast(scope, code)
+
+        if dealloc_user_cfunc:
+            code.putln("PyObject *etype, *eval, *etb;")
+            code.putln("PyErr_Fetch(&etype, &eval, &etb);")
+
+            code.putln("#if PY_VERSION_HEX < 0x030400a1")
+            code.putln("++Py_REFCNT(o);")
+            code.putln("#endif")
+
+            code.putln("%s(o);" % dealloc_user_cfunc)
+            code.putln("if (PyErr_Occurred()) PyErr_WriteUnraisable(o);")
+
+            code.putln("#if PY_VERSION_HEX < 0x030400a1")
+            code.putln("--Py_REFCNT(o);")
+            code.putln("#endif")
+
+            code.putln("PyErr_Restore(etype, eval, etb);")
+
+        if weakref_slot:
+            # Not using a preprocessor test here to avoid warning about
+            # "unused variable p".  Py3.4+ already cleared the weak refs
+            # when we get here.
+            # FIXME: so did Py<3.4 I think? Isn't this redundant?
+            code.putln("if ((PY_VERSION_HEX < 0x030400a1) && p->__weakref__) "
+                       "PyObject_ClearWeakRefs(o);")
+
+        for entry in py_attrs:
+            code.put_xdecref_clear("p->%s" % entry.cname, entry.type, nanny=False,
+                                   clear_before_decref=True)
+
+        for entry in memoryview_slices:
+            code.put_xdecref_memoryviewslice("p->%s" % entry.cname,
+                                             have_gil=True)
+
+        base_type = scope.parent_type.base_type
+        if base_type:
+            code.putln("#if PY_VERSION_HEX >= 0x030400a1")
+            # need to call base_type method
+            tp_finalize = TypeSlots.get_base_slot_function(
+                scope, TypeSlots.FinaliserSlot())
+            if tp_finalize is not None:
+                code.putln("%s(o);" % tp_finalize)
+            elif base_type.is_builtin_type:
+                code.putln("if (%s->tp_finalize) %s->tp_finalize(o);" % (
+                    base_type.typeptr_cname, base_type.typeptr_cname))
+            else:
+                # This is an externally defined type.  Calling through the
+                # cimported base type pointer directly interacts badly with
+                # the module cleanup, which may already have cleared it.
+                # In that case, fall back to traversing the type hierarchy.
+                base_cname = base_type.typeptr_cname
+                code.putln("if (likely(%s && %s->tp_finalize)) %s->tp_finalize(o); "
+                           "else __Pyx_call_next_tp_finalize(o, %s);" % (
+                    base_cname, base_cname, base_cname, slot_func_cname))
+                code.globalstate.use_utility_code(
+                    UtilityCode.load_cached("CallNextTpFinalize", "ExtensionTypes.c"))
+            code.putln("#endif")
+
+        code.putln("}")
 
     def generate_traverse_function(self, scope, code, cclass_entry):
         tp_slot = TypeSlots.GCDependentSlot("tp_traverse")
index a1bc497..7bb019d 100644 (file)
@@ -832,25 +832,6 @@ class Scope(object):
     def add_include_file(self, filename):
         self.outer_scope.add_include_file(filename)
 
-    def get_refcounted_entries(self, include_weakref=False,
-                               include_gc_simple=True):
-        py_attrs = []
-        py_buffers = []
-        memoryview_slices = []
-
-        for entry in self.var_entries:
-            if entry.type.is_pyobject:
-                if include_weakref or entry.name != "__weakref__":
-                    if include_gc_simple or not entry.type.is_gc_simple:
-                        py_attrs.append(entry)
-            elif entry.type == PyrexTypes.c_py_buffer_type:
-                py_buffers.append(entry)
-            elif entry.type.is_memoryviewslice:
-                memoryview_slices.append(entry)
-
-        have_entries = py_attrs or py_buffers or memoryview_slices
-        return have_entries, (py_attrs, py_buffers, memoryview_slices)
-
 
 class PreImportScope(Scope):
 
@@ -1804,6 +1785,33 @@ class CClassScope(ClassScope):
             return not self.parent_type.is_gc_simple
         return False
 
+    def get_refcounted_entries(self, include_weakref=False,
+                               include_gc_simple=True):
+        py_attrs = []
+        py_buffers = []
+        memoryview_slices = []
+
+        for entry in self.var_entries:
+            if entry.type.is_pyobject:
+                if include_weakref or entry.name != "__weakref__":
+                    if include_gc_simple or not entry.type.is_gc_simple:
+                        py_attrs.append(entry)
+            elif entry.type == PyrexTypes.c_py_buffer_type:
+                py_buffers.append(entry)
+            elif entry.type.is_memoryviewslice:
+                memoryview_slices.append(entry)
+
+        have_entries = py_attrs or py_buffers or memoryview_slices
+        return have_entries, (py_attrs, py_buffers, memoryview_slices)
+
+    def needs_finalisation(self):
+        if self.lookup_here("__dealloc__"):
+            return True
+        has_gc_entries, _ = self.get_refcounted_entries(include_weakref=True)
+        if has_gc_entries:
+            return True
+        return False
+
     def declare_var(self, name, type, pos,
                     cname = None, visibility = 'private',
                     api = 0, in_pxd = 0, is_cdef = 0):
@@ -1865,7 +1873,6 @@ class CClassScope(ClassScope):
             self.namespace_cname = "(PyObject *)%s" % self.parent_type.typeptr_cname
             return entry
 
-
     def declare_pyfunction(self, name, pos, allow_redefine=False):
         # Add an entry for a method.
         if name in ('__eq__', '__ne__', '__lt__', '__gt__', '__le__', '__ge__'):
@@ -2033,6 +2040,7 @@ class CClassScope(ClassScope):
             if base_entry.utility_code:
                 entry.utility_code = base_entry.utility_code
 
+
 class CppClassScope(Scope):
     #  Namespace of a C++ class.
 
index e3b7a6d..2038338 100644 (file)
@@ -331,6 +331,21 @@ class GCDependentSlot(InternalMethodSlot):
         return InternalMethodSlot.slot_code(self, scope)
 
 
+class FinaliserSlot(InternalMethodSlot):
+    """
+    Descriptor for tp_finalize().
+    """
+    def __init__(self):
+        InternalMethodSlot.__init__(
+            self, 'tp_finalize',
+            ifdef="PY_VERSION_HEX >= 0x030400a1")
+
+    def slot_code(self, scope):
+        if not scope.needs_finalisation():
+            return '0'
+        return InternalMethodSlot.slot_code(self, scope)
+
+
 class ConstructorSlot(InternalMethodSlot):
     #  Descriptor for tp_new and tp_dealloc.
 
@@ -788,8 +803,7 @@ slot_table = (
     EmptySlot("tp_weaklist"),
     EmptySlot("tp_del"),
     EmptySlot("tp_version_tag", ifdef="PY_VERSION_HEX >= 0x02060000"),
-    # TODO: change __dealloc__ to be called by tp_finalize (PEP 442)
-    EmptySlot("tp_finalize", ifdef="PY_VERSION_HEX >= 0x03040a00"),
+    FinaliserSlot(),  # 'tp_finalize'
 )
 
 #------------------------------------------------------------------------------------------
index 423ed96..f1fcd84 100644 (file)
@@ -16,6 +16,27 @@ static void __Pyx_call_next_tp_dealloc(PyObject* obj, destructor current_tp_deal
         type->tp_dealloc(obj);
 }
 
+/////////////// CallNextTpFinalize.proto ///////////////
+
+#if PY_VERSION_HEX >= 0x030400a1
+static void __Pyx_call_next_tp_finalize(PyObject* obj, destructor current_tp_finalize);
+#endif
+
+/////////////// CallNextTpFinalize ///////////////
+
+#if PY_VERSION_HEX >= 0x030400a1
+static void __Pyx_call_next_tp_finalize(PyObject* obj, destructor current_tp_finalize) {
+    PyTypeObject* type = Py_TYPE(obj);
+    /* try to find the first parent type that has a different tp_finalize() function */
+    while (type && type->tp_finalize != current_tp_finalize)
+        type = type->tp_base;
+    while (type && (!type->tp_finalize || type->tp_finalize == current_tp_finalize))
+        type = type->tp_base;
+    if (type)
+        type->tp_finalize(obj);
+}
+#endif
+
 /////////////// CallNextTpTraverse.proto ///////////////
 
 static int __Pyx_call_next_tp_traverse(PyObject* obj, visitproc v, void *a, traverseproc current_tp_traverse);