implement type inference across closures
authorStefan Behnel <stefan_ml@behnel.de>
Sun, 23 Dec 2012 18:56:42 +0000 (19:56 +0100)
committerStefan Behnel <stefan_ml@behnel.de>
Sun, 23 Dec 2012 18:56:42 +0000 (19:56 +0100)
Cython/Compiler/Symtab.py
Cython/Compiler/TypeInference.py
tests/run/cross_closure_type_inference.pyx [new file with mode: 0644]
tests/run/generator_type_inference.pyx

index 91e755d..a0d907c 100644 (file)
@@ -186,6 +186,7 @@ class Entry(object):
     from_cython_utility_code = None
     error_on_uninitialized = False
     cf_used = True
+    outer_entry = None
 
     def __init__(self, name, cname, type, pos = None, init = None):
         self.name = name
@@ -196,6 +197,7 @@ class Entry(object):
         self.overloaded_alternatives = []
         self.cf_assignments = []
         self.cf_references = []
+        self.inner_entries = []
 
     def __repr__(self):
         return "Entry(name=%s, type=%s)" % (self.name, self.type)
@@ -207,6 +209,22 @@ class Entry(object):
     def all_alternatives(self):
         return [self] + self.overloaded_alternatives
 
+    def all_entries(self):
+        """
+        Returns all entries for this entry, including the equivalent ones
+        in other closures.
+        """
+        if self.from_closure:
+            return self.outer_entry.all_entries()
+
+        entries = []
+        def collect_inner_entries(entry):
+            entries.append(entry)
+            for e in entry.inner_entries:
+                collect_inner_entries(e)
+        collect_inner_entries(self)
+        return entries
+
 
 class Scope(object):
     # name              string             Unqualified name
@@ -1524,6 +1542,7 @@ class LocalScope(Scope):
                 inner_entry.from_closure = True
                 inner_entry.is_declared_generic = entry.is_declared_generic
                 self.entries[name] = inner_entry
+                entry.inner_entries.append(inner_entry)
                 return inner_entry
         return entry
 
index 4411e0a..9b7af1d 100644 (file)
@@ -335,6 +335,9 @@ class PyObjectTypeInferer(object):
 class SimpleAssignmentTypeInferer(object):
     """
     Very basic type inference.
+
+    Note: in order to support cross-closure type inference, this must be
+    applies to nested scopes in top-down order.
     """
     # TODO: Implement a real type inference algorithm.
     # (Something more powerful than just extending this one...)
@@ -357,13 +360,11 @@ class SimpleAssignmentTypeInferer(object):
         ready_to_infer = []
         for name, entry in scope.entries.items():
             if entry.type is unspecified_type:
-                if entry.in_closure or entry.from_closure:
-                    # cross-closure type inference is not currently supported
-                    entry.type = py_object_type
-                    continue
+                entries = entry.all_entries()
                 all = set()
-                for assmt in entry.cf_assignments:
-                    all.update(assmt.type_dependencies(scope))
+                for e in entries:
+                    for assmt in e.cf_assignments:
+                        all.update(assmt.type_dependencies(e.scope))
                 if all:
                     dependancies_by_entry[entry] = all
                     for dep in all:
@@ -387,14 +388,20 @@ class SimpleAssignmentTypeInferer(object):
         while True:
             while ready_to_infer:
                 entry = ready_to_infer.pop()
-                types = [assmt.rhs.infer_type(scope)
-                         for assmt in entry.cf_assignments]
+                types = [
+                    assmt.rhs.infer_type(scope)
+                    for e in entry.all_entries()
+                    for assmt in e.cf_assignments
+                    ]
                 if types and Utils.all(types):
-                    entry.type = spanning_type(types, entry.might_overflow, entry.pos)
+                    entry_type = spanning_type(types, entry.might_overflow, entry.pos)
                 else:
                     # FIXME: raise a warning?
                     # print "No assignments", entry.pos, entry
-                    entry.type = py_object_type
+                    entry_type = py_object_type
+                # propagate entry type to all nested scopes
+                for e in entry.all_entries():
+                    e.type = entry_type
                 if verbose:
                     message(entry.pos, "inferred '%s' to be of type '%s'" % (entry.name, entry.type))
                 resolve_dependancy(entry)
diff --git a/tests/run/cross_closure_type_inference.pyx b/tests/run/cross_closure_type_inference.pyx
new file mode 100644 (file)
index 0000000..fbfcb88
--- /dev/null
@@ -0,0 +1,89 @@
+# mode: run
+# tag: typeinference
+
+cimport cython
+
+
+def test_outer_inner_double():
+    """
+    >>> print(test_outer_inner_double())
+    double
+    """
+    x = 1.0
+    def inner():
+        nonlocal x
+        x = 2.0
+    inner()
+    assert x == 2.0
+    return cython.typeof(x)
+
+
+def test_outer_inner_double_int():
+    """
+    >>> print(test_outer_inner_double_int())
+    ('double', 'double')
+    """
+    x = 1.0
+    y = 2
+    def inner():
+        nonlocal x, y
+        x = 1
+        y = 2.0
+    inner()
+    return cython.typeof(x), cython.typeof(y)
+
+
+def test_outer_inner_pyarg():
+    """
+    >>> print(test_outer_inner_pyarg())
+    2
+    long
+    """
+    x = 1
+    def inner(y):
+        return x + y
+    print inner(1)
+    return cython.typeof(x)
+
+
+def test_outer_inner_carg():
+    """
+    >>> print(test_outer_inner_carg())
+    2.0
+    long
+    """
+    x = 1
+    def inner(double y):
+        return x + y
+    print inner(1)
+    return cython.typeof(x)
+
+
+def test_outer_inner_incompatible():
+    """
+    >>> print(test_outer_inner_incompatible())
+    Python object
+    """
+    x = 1.0
+    def inner():
+        nonlocal x
+        x = 'test'
+    inner()
+    return cython.typeof(x)
+
+
+def test_outer_inner2_double():
+    """
+    >>> print(test_outer_inner2_double())
+    double
+    """
+    x = 1.0
+    def inner1():
+        nonlocal x
+        x = 2
+    def inner2():
+        nonlocal x
+        x = 3.0
+    inner1()
+    inner2()
+    return cython.typeof(x)
index 6a5c90e..7148b8f 100644 (file)
@@ -31,12 +31,12 @@ def test_unicode_loop():
         print 2, cython.typeof(c)
         yield c
 
-def test_nonlocal_disables_inference():
+def test_with_nonlocal():
     """
-    >>> chars = list(test_nonlocal_disables_inference())
-    1 Python object
-    2 Python object
-    2 Python object
+    >>> chars = list(test_with_nonlocal())
+    1 Py_UCS4
+    2 Py_UCS4
+    2 Py_UCS4
     >>> len(chars)
     2
     >>> ''.join(chars) == 'ab'