from_cython_utility_code = None
error_on_uninitialized = False
cf_used = True
+ outer_entry = None
def __init__(self, name, cname, type, pos = None, init = None):
self.name = name
self.overloaded_alternatives = []
self.cf_assignments = []
self.cf_references = []
+ self.inner_entries = []
def __repr__(self):
return "Entry(name=%s, type=%s)" % (self.name, self.type)
def all_alternatives(self):
return [self] + self.overloaded_alternatives
+ def all_entries(self):
+ """
+ Returns all entries for this entry, including the equivalent ones
+ in other closures.
+ """
+ if self.from_closure:
+ return self.outer_entry.all_entries()
+
+ entries = []
+ def collect_inner_entries(entry):
+ entries.append(entry)
+ for e in entry.inner_entries:
+ collect_inner_entries(e)
+ collect_inner_entries(self)
+ return entries
+
class Scope(object):
# name string Unqualified name
inner_entry.from_closure = True
inner_entry.is_declared_generic = entry.is_declared_generic
self.entries[name] = inner_entry
+ entry.inner_entries.append(inner_entry)
return inner_entry
return entry
class SimpleAssignmentTypeInferer(object):
"""
Very basic type inference.
+
+ Note: in order to support cross-closure type inference, this must be
+ applies to nested scopes in top-down order.
"""
# TODO: Implement a real type inference algorithm.
# (Something more powerful than just extending this one...)
ready_to_infer = []
for name, entry in scope.entries.items():
if entry.type is unspecified_type:
- if entry.in_closure or entry.from_closure:
- # cross-closure type inference is not currently supported
- entry.type = py_object_type
- continue
+ entries = entry.all_entries()
all = set()
- for assmt in entry.cf_assignments:
- all.update(assmt.type_dependencies(scope))
+ for e in entries:
+ for assmt in e.cf_assignments:
+ all.update(assmt.type_dependencies(e.scope))
if all:
dependancies_by_entry[entry] = all
for dep in all:
while True:
while ready_to_infer:
entry = ready_to_infer.pop()
- types = [assmt.rhs.infer_type(scope)
- for assmt in entry.cf_assignments]
+ types = [
+ assmt.rhs.infer_type(scope)
+ for e in entry.all_entries()
+ for assmt in e.cf_assignments
+ ]
if types and Utils.all(types):
- entry.type = spanning_type(types, entry.might_overflow, entry.pos)
+ entry_type = spanning_type(types, entry.might_overflow, entry.pos)
else:
# FIXME: raise a warning?
# print "No assignments", entry.pos, entry
- entry.type = py_object_type
+ entry_type = py_object_type
+ # propagate entry type to all nested scopes
+ for e in entry.all_entries():
+ e.type = entry_type
if verbose:
message(entry.pos, "inferred '%s' to be of type '%s'" % (entry.name, entry.type))
resolve_dependancy(entry)
--- /dev/null
+# mode: run
+# tag: typeinference
+
+cimport cython
+
+
+def test_outer_inner_double():
+ """
+ >>> print(test_outer_inner_double())
+ double
+ """
+ x = 1.0
+ def inner():
+ nonlocal x
+ x = 2.0
+ inner()
+ assert x == 2.0
+ return cython.typeof(x)
+
+
+def test_outer_inner_double_int():
+ """
+ >>> print(test_outer_inner_double_int())
+ ('double', 'double')
+ """
+ x = 1.0
+ y = 2
+ def inner():
+ nonlocal x, y
+ x = 1
+ y = 2.0
+ inner()
+ return cython.typeof(x), cython.typeof(y)
+
+
+def test_outer_inner_pyarg():
+ """
+ >>> print(test_outer_inner_pyarg())
+ 2
+ long
+ """
+ x = 1
+ def inner(y):
+ return x + y
+ print inner(1)
+ return cython.typeof(x)
+
+
+def test_outer_inner_carg():
+ """
+ >>> print(test_outer_inner_carg())
+ 2.0
+ long
+ """
+ x = 1
+ def inner(double y):
+ return x + y
+ print inner(1)
+ return cython.typeof(x)
+
+
+def test_outer_inner_incompatible():
+ """
+ >>> print(test_outer_inner_incompatible())
+ Python object
+ """
+ x = 1.0
+ def inner():
+ nonlocal x
+ x = 'test'
+ inner()
+ return cython.typeof(x)
+
+
+def test_outer_inner2_double():
+ """
+ >>> print(test_outer_inner2_double())
+ double
+ """
+ x = 1.0
+ def inner1():
+ nonlocal x
+ x = 2
+ def inner2():
+ nonlocal x
+ x = 3.0
+ inner1()
+ inner2()
+ return cython.typeof(x)
print 2, cython.typeof(c)
yield c
-def test_nonlocal_disables_inference():
+def test_with_nonlocal():
"""
- >>> chars = list(test_nonlocal_disables_inference())
- 1 Python object
- 2 Python object
- 2 Python object
+ >>> chars = list(test_with_nonlocal())
+ 1 Py_UCS4
+ 2 Py_UCS4
+ 2 Py_UCS4
>>> len(chars)
2
>>> ''.join(chars) == 'ab'