From e440f7792a52145a27233746e5fb309d17893de8 Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Sun, 23 Dec 2012 19:56:42 +0100 Subject: [PATCH] implement type inference across closures --- Cython/Compiler/Symtab.py | 19 +++++++ Cython/Compiler/TypeInference.py | 27 +++++---- tests/run/cross_closure_type_inference.pyx | 89 ++++++++++++++++++++++++++++++ tests/run/generator_type_inference.pyx | 10 ++-- 4 files changed, 130 insertions(+), 15 deletions(-) create mode 100644 tests/run/cross_closure_type_inference.pyx diff --git a/Cython/Compiler/Symtab.py b/Cython/Compiler/Symtab.py index 91e755d..a0d907c 100644 --- a/Cython/Compiler/Symtab.py +++ b/Cython/Compiler/Symtab.py @@ -186,6 +186,7 @@ class Entry(object): from_cython_utility_code = None error_on_uninitialized = False cf_used = True + outer_entry = None def __init__(self, name, cname, type, pos = None, init = None): self.name = name @@ -196,6 +197,7 @@ class Entry(object): self.overloaded_alternatives = [] self.cf_assignments = [] self.cf_references = [] + self.inner_entries = [] def __repr__(self): return "Entry(name=%s, type=%s)" % (self.name, self.type) @@ -207,6 +209,22 @@ class Entry(object): def all_alternatives(self): return [self] + self.overloaded_alternatives + def all_entries(self): + """ + Returns all entries for this entry, including the equivalent ones + in other closures. + """ + if self.from_closure: + return self.outer_entry.all_entries() + + entries = [] + def collect_inner_entries(entry): + entries.append(entry) + for e in entry.inner_entries: + collect_inner_entries(e) + collect_inner_entries(self) + return entries + class Scope(object): # name string Unqualified name @@ -1524,6 +1542,7 @@ class LocalScope(Scope): inner_entry.from_closure = True inner_entry.is_declared_generic = entry.is_declared_generic self.entries[name] = inner_entry + entry.inner_entries.append(inner_entry) return inner_entry return entry diff --git a/Cython/Compiler/TypeInference.py b/Cython/Compiler/TypeInference.py index 4411e0a..9b7af1d 100644 --- a/Cython/Compiler/TypeInference.py +++ b/Cython/Compiler/TypeInference.py @@ -335,6 +335,9 @@ class PyObjectTypeInferer(object): class SimpleAssignmentTypeInferer(object): """ Very basic type inference. + + Note: in order to support cross-closure type inference, this must be + applies to nested scopes in top-down order. """ # TODO: Implement a real type inference algorithm. # (Something more powerful than just extending this one...) @@ -357,13 +360,11 @@ class SimpleAssignmentTypeInferer(object): ready_to_infer = [] for name, entry in scope.entries.items(): if entry.type is unspecified_type: - if entry.in_closure or entry.from_closure: - # cross-closure type inference is not currently supported - entry.type = py_object_type - continue + entries = entry.all_entries() all = set() - for assmt in entry.cf_assignments: - all.update(assmt.type_dependencies(scope)) + for e in entries: + for assmt in e.cf_assignments: + all.update(assmt.type_dependencies(e.scope)) if all: dependancies_by_entry[entry] = all for dep in all: @@ -387,14 +388,20 @@ class SimpleAssignmentTypeInferer(object): while True: while ready_to_infer: entry = ready_to_infer.pop() - types = [assmt.rhs.infer_type(scope) - for assmt in entry.cf_assignments] + types = [ + assmt.rhs.infer_type(scope) + for e in entry.all_entries() + for assmt in e.cf_assignments + ] if types and Utils.all(types): - entry.type = spanning_type(types, entry.might_overflow, entry.pos) + entry_type = spanning_type(types, entry.might_overflow, entry.pos) else: # FIXME: raise a warning? # print "No assignments", entry.pos, entry - entry.type = py_object_type + entry_type = py_object_type + # propagate entry type to all nested scopes + for e in entry.all_entries(): + e.type = entry_type if verbose: message(entry.pos, "inferred '%s' to be of type '%s'" % (entry.name, entry.type)) resolve_dependancy(entry) diff --git a/tests/run/cross_closure_type_inference.pyx b/tests/run/cross_closure_type_inference.pyx new file mode 100644 index 0000000..fbfcb88 --- /dev/null +++ b/tests/run/cross_closure_type_inference.pyx @@ -0,0 +1,89 @@ +# mode: run +# tag: typeinference + +cimport cython + + +def test_outer_inner_double(): + """ + >>> print(test_outer_inner_double()) + double + """ + x = 1.0 + def inner(): + nonlocal x + x = 2.0 + inner() + assert x == 2.0 + return cython.typeof(x) + + +def test_outer_inner_double_int(): + """ + >>> print(test_outer_inner_double_int()) + ('double', 'double') + """ + x = 1.0 + y = 2 + def inner(): + nonlocal x, y + x = 1 + y = 2.0 + inner() + return cython.typeof(x), cython.typeof(y) + + +def test_outer_inner_pyarg(): + """ + >>> print(test_outer_inner_pyarg()) + 2 + long + """ + x = 1 + def inner(y): + return x + y + print inner(1) + return cython.typeof(x) + + +def test_outer_inner_carg(): + """ + >>> print(test_outer_inner_carg()) + 2.0 + long + """ + x = 1 + def inner(double y): + return x + y + print inner(1) + return cython.typeof(x) + + +def test_outer_inner_incompatible(): + """ + >>> print(test_outer_inner_incompatible()) + Python object + """ + x = 1.0 + def inner(): + nonlocal x + x = 'test' + inner() + return cython.typeof(x) + + +def test_outer_inner2_double(): + """ + >>> print(test_outer_inner2_double()) + double + """ + x = 1.0 + def inner1(): + nonlocal x + x = 2 + def inner2(): + nonlocal x + x = 3.0 + inner1() + inner2() + return cython.typeof(x) diff --git a/tests/run/generator_type_inference.pyx b/tests/run/generator_type_inference.pyx index 6a5c90e..7148b8f 100644 --- a/tests/run/generator_type_inference.pyx +++ b/tests/run/generator_type_inference.pyx @@ -31,12 +31,12 @@ def test_unicode_loop(): print 2, cython.typeof(c) yield c -def test_nonlocal_disables_inference(): +def test_with_nonlocal(): """ - >>> chars = list(test_nonlocal_disables_inference()) - 1 Python object - 2 Python object - 2 Python object + >>> chars = list(test_with_nonlocal()) + 1 Py_UCS4 + 2 Py_UCS4 + 2 Py_UCS4 >>> len(chars) 2 >>> ''.join(chars) == 'ab' -- 2.7.4