implement 'yield' inside of 'finally' clause
authorStefan Behnel <stefan_ml@behnel.de>
Wed, 18 Dec 2013 18:02:09 +0000 (19:02 +0100)
committerStefan Behnel <stefan_ml@behnel.de>
Wed, 18 Dec 2013 18:02:09 +0000 (19:02 +0100)
CHANGES.rst
Cython/Compiler/Nodes.py
tests/run/tryfinally.pyx

index 5da5a36..032df0c 100644 (file)
@@ -8,6 +8,8 @@ Cython Changelog
 Features added
 --------------
 
+* ``yield`` is supported in ``finally`` clauses.
+
 * The C code generated for finally blocks is duplicated for each exit
   case to allow for better optimisations by the C compiler.
 
index 90429b4..6b4faf5 100644 (file)
@@ -6348,6 +6348,7 @@ class TryFinallyStatNode(StatNode):
     def analyse_expressions(self, env):
         self.body = self.body.analyse_expressions(env)
         self.finally_clause = self.finally_clause.analyse_expressions(env)
+        self.func_return_type = env.return_type
         return self
 
     nogil_check = Node.gil_error
@@ -6396,20 +6397,8 @@ class TryFinallyStatNode(StatNode):
             if self.is_try_finally_in_nogil:
                 code.declare_gilstate()
 
-            code.putln("PyObject *%s, *%s, *%s;" % Naming.exc_vars)
-            if needs_success_cleanup:
-                code.putln("int %s;" % Naming.exc_lineno_name)
-            exc_var_init_zero = ''.join(
-                ["%s = 0; " % var for var in Naming.exc_vars])
-            if needs_success_cleanup:
-                exc_var_init_zero += '%s = 0;' % Naming.exc_lineno_name
-        else:
-            exc_var_init_zero = None
-
         if not self.body.is_terminator:
             code.putln('/*normal exit:*/{')
-            if exc_var_init_zero:
-                code.putln(exc_var_init_zero)
             fresh_finally_clause().generate_execution_code(code)
             if not self.finally_clause.is_terminator:
                 code.put_goto(catch_label)
@@ -6417,10 +6406,22 @@ class TryFinallyStatNode(StatNode):
 
         if preserve_error:
             code.putln('/*exception exit:*/{')
-            code.put('%s: ' % new_error_label)
-            code.putln(exc_var_init_zero)
+            if needs_success_cleanup:
+                exc_lineno_cnames = tuple([
+                    code.funcstate.allocate_temp(PyrexTypes.c_int_type, manage_ref=False)
+                    for _ in range(2)])
+                exc_filename_cname = code.funcstate.allocate_temp(
+                    PyrexTypes.CPtrType(PyrexTypes.c_const_type(PyrexTypes.c_char_type)),
+                    manage_ref=False)
+            else:
+                exc_lineno_cnames = exc_filename_cname = None
+            exc_vars = tuple([
+                code.funcstate.allocate_temp(py_object_type, manage_ref=False)
+                for _ in range(3)])
+            code.put_label(new_error_label)
+            code.putln("%s = 0; %s = 0; %s = 0;" % exc_vars)
             self.put_error_catcher(
-                code, temps_to_clean_up, include_lineno=needs_success_cleanup)
+                code, temps_to_clean_up, exc_vars, exc_lineno_cnames, exc_filename_cname)
             finally_old_labels = code.all_new_labels()
 
             code.putln('{')
@@ -6428,18 +6429,27 @@ class TryFinallyStatNode(StatNode):
             code.putln('}')
 
             if needs_success_cleanup:
-                self.put_error_uncatcher(code)
+                self.put_error_uncatcher(code, exc_vars, exc_lineno_cnames, exc_filename_cname)
+                if exc_lineno_cnames:
+                    for cname in exc_lineno_cnames:
+                        code.funcstate.release_temp(cname)
+                if exc_filename_cname:
+                    code.funcstate.release_temp(exc_filename_cname)
                 code.put_goto(old_error_label)
 
             for new_label, old_label in zip(code.get_all_labels(), finally_old_labels):
                 if not code.label_used(new_label):
                     continue
                 code.put_label(new_label)
-                self.put_error_cleaner(code)
+                self.put_error_cleaner(code, exc_vars)
                 code.put_goto(old_label)
+
+            for cname in exc_vars:
+                code.funcstate.release_temp(cname)
             code.putln('}')
 
         code.set_all_labels(old_labels)
+        return_label = code.return_label
         for i, (new_label, old_label) in enumerate(zip(new_labels, old_labels)):
             if not code.label_used(new_label):
                 continue
@@ -6447,10 +6457,25 @@ class TryFinallyStatNode(StatNode):
                 continue  # handled above
 
             code.put('%s: ' % new_label)
-            if exc_var_init_zero:
-                code.putln(exc_var_init_zero)
             code.putln('{')
+            ret_temp = None
+            if old_label == return_label and not self.finally_clause.is_terminator:
+                # store away return value for later reuse
+                if self.func_return_type:
+                    # pure safety check, func_return_type should
+                    # always be set when return label is used
+                    ret_temp = code.funcstate.allocate_temp(
+                        self.func_return_type, manage_ref=False)
+                    code.putln("%s = %s;" % (ret_temp, Naming.retval_cname))
+                    if self.func_return_type.is_pyobject:
+                        code.putln("%s = 0;" % Naming.retval_cname)
             fresh_finally_clause().generate_execution_code(code)
+            if ret_temp:
+                code.putln("%s = %s;" % (Naming.retval_cname, ret_temp))
+                if self.func_return_type.is_pyobject:
+                    code.putln("%s = 0;" % ret_temp)
+                code.funcstate.release_temp(ret_temp)
+                ret_temp = None
             if not self.finally_clause.is_terminator:
                 code.put_goto(old_label)
             code.putln('}')
@@ -6464,7 +6489,8 @@ class TryFinallyStatNode(StatNode):
         self.body.generate_function_definitions(env, code)
         self.finally_clause.generate_function_definitions(env, code)
 
-    def put_error_catcher(self, code, temps_to_clean_up, include_lineno):
+    def put_error_catcher(self, code, temps_to_clean_up, exc_vars,
+                          exc_lineno_cnames, exc_filename_cname):
         code.globalstate.use_utility_code(restore_exception_utility_code)
 
         if self.is_try_finally_in_nogil:
@@ -6473,33 +6499,43 @@ class TryFinallyStatNode(StatNode):
         for temp_name, type in temps_to_clean_up:
             code.put_xdecref_clear(temp_name, type)
 
-        code.putln("__Pyx_ErrFetch(&%s, &%s, &%s);" % Naming.exc_vars)
-        if include_lineno:
-            code.putln("%s = %s;" % (Naming.exc_lineno_name, Naming.lineno_cname))
+        code.putln("__Pyx_ErrFetch(&%s, &%s, &%s);" % exc_vars)
+        for var in exc_vars:
+            code.put_xgotref(var)
+        if exc_lineno_cnames:
+            code.putln("%s = %s; %s = %s; %s = %s;" % (
+                exc_lineno_cnames[0], Naming.lineno_cname,
+                exc_lineno_cnames[1], Naming.clineno_cname,
+                exc_filename_cname, Naming.filename_cname))
 
         if self.is_try_finally_in_nogil:
             code.put_release_ensured_gil()
 
-    def put_error_uncatcher(self, code):
+    def put_error_uncatcher(self, code, exc_vars, exc_lineno_cnames, exc_filename_cname):
         code.globalstate.use_utility_code(restore_exception_utility_code)
 
         if self.is_try_finally_in_nogil:
             code.put_ensure_gil(declare_gilstate=False)
 
-        code.putln("__Pyx_ErrRestore(%s, %s, %s);" % Naming.exc_vars)
-        code.putln("%s = %s;" % (Naming.lineno_cname, Naming.exc_lineno_name))
+        for var in exc_vars:
+            code.put_xgiveref(var)
+        code.putln("__Pyx_ErrRestore(%s, %s, %s);" % exc_vars)
 
         if self.is_try_finally_in_nogil:
             code.put_release_ensured_gil()
 
-        for var in Naming.exc_vars:
-            code.putln("%s = 0;" % var)
+        code.putln("%s = 0; %s = 0; %s = 0;" % exc_vars)
+        if exc_lineno_cnames:
+            code.putln("%s = %s; %s = %s; %s = %s;" % (
+                Naming.lineno_cname, exc_lineno_cnames[0],
+                Naming.clineno_cname, exc_lineno_cnames[1],
+                Naming.filename_cname, exc_filename_cname))
 
-    def put_error_cleaner(self, code):
+    def put_error_cleaner(self, code, exc_vars):
         if self.is_try_finally_in_nogil:
             code.put_ensure_gil(declare_gilstate=False)
-        for var in Naming.exc_vars:
-            code.putln("Py_XDECREF(%s);" % var)
+        for var in exc_vars:
+            code.put_xdecref_clear(var, py_object_type)
         if self.is_try_finally_in_nogil:
             code.put_release_ensured_gil()
 
index a9d8cdf..49ac5b7 100644 (file)
@@ -3,6 +3,12 @@
 
 cimport cython
 
+try:
+    next
+except NameError:
+    def next(it): return it.next()
+
+
 def finally_except():
     """
     >>> try:
@@ -161,3 +167,85 @@ def empty_try_in_except_raise(raise_in_finally):
             if raise_in_finally:
                 raise TypeError('OLA')
         raise
+
+
+def try_all_cases(x):
+    """
+    >>> try_all_cases(None)
+    2
+    >>> try_all_cases('break')
+    4
+    >>> try_all_cases('raise')
+    Traceback (most recent call last):
+    ValueError
+    >>> try_all_cases('return')
+    3
+    >>> try_all_cases('tryraise')
+    Traceback (most recent call last):
+    TypeError
+    >>> try_all_cases('trybreak')
+    4
+    """
+    for i in range(3):
+        try:
+            if i == 0:
+                pass
+            elif i == 1:
+                continue
+            elif x == 'trybreak':
+                break
+            elif x == 'tryraise':
+                raise TypeError()
+            else:
+                return 2
+        finally:
+            if x == 'raise':
+                raise ValueError()
+            elif x == 'break':
+                break
+            elif x == 'return':
+                return 3
+    return 4
+
+
+def finally_yield(x):
+    """
+    >>> g = finally_yield(None)
+    >>> next(g)  # 1
+    1
+    >>> next(g)  # 2
+    1
+    >>> next(g)  # 3
+    Traceback (most recent call last):
+    StopIteration
+
+    >>> g = finally_yield('raise')
+    >>> next(g)  # raise 1
+    1
+    >>> next(g)  # raise 2
+    1
+    >>> next(g)  # raise 3
+    Traceback (most recent call last):
+    TypeError
+
+    >>> g = finally_yield('break')
+    >>> next(g)   # break 1
+    1
+    >>> next(g)   # break 2
+    1
+    >>> next(g)   # break 3
+    Traceback (most recent call last):
+    StopIteration
+    """
+    for i in range(3):
+        try:
+            if i == 0:
+                continue
+            elif x == 'raise':
+                raise TypeError()
+            elif x == 'break':
+                break
+            else:
+                return
+        finally:
+            yield 1