infer type of items in tuple/list constants on indexing (especially where all items...
authorStefan Behnel <stefan_ml@behnel.de>
Fri, 6 Dec 2013 17:18:37 +0000 (18:18 +0100)
committerStefan Behnel <stefan_ml@behnel.de>
Fri, 6 Dec 2013 17:18:37 +0000 (18:18 +0100)
--HG--
extra : amend_source : 1e8bf2d512c244f48654f43e37283d800e5d961a

Cython/Compiler/ExprNodes.py
tests/run/type_inference.pyx

index 3fa4783..64ee623 100644 (file)
@@ -121,6 +121,32 @@ def check_negative_indices(*nodes):
                     "undefined", level=1)
 
 
+def infer_sequence_item_type(env, seq_node, index_node=None, seq_type=None):
+    if not seq_node.is_sequence_constructor:
+        if seq_type is None:
+            seq_type = seq_node.infer_type(env)
+        if seq_type is tuple_type:
+            # tuples are immutable => we can safely follow assignments
+            if seq_node.cf_state and len(seq_node.cf_state) == 1:
+                try:
+                    seq_node = seq_node.cf_state[0].rhs
+                except AttributeError:
+                    pass
+    if seq_node is not None and seq_node.is_sequence_constructor:
+        if index_node is not None and index_node.has_constant_result():
+            try:
+                item = seq_node.args[index_node.constant_result]
+            except (ValueError, TypeError, IndexError):
+                pass
+            else:
+                return item.infer_type(env)
+        # if we're lucky, all items have the same type
+        item_types = set([item.infer_type(env) for item in seq_node.args])
+        if len(item_types) == 1:
+            return item_types.pop()
+    return None
+
+
 class ExprNode(Node):
     #  subexprs     [string]     Class var holding names of subexpr node attrs
     #  type         PyrexType    Type of the result
@@ -2802,6 +2828,12 @@ class IndexNode(ExprNode):
                 #    return PyrexTypes.c_char_type
                 # Py2/3 return different types on indexing bytes objects
                 return py_object_type
+            elif base_type in (tuple_type, list_type):
+                # if base is a literal, take a look at its values
+                item_type = infer_sequence_item_type(
+                    env, self.base, self.index, seq_type=base_type)
+                if item_type is not None:
+                    return item_type
             elif base_type.is_ptr or base_type.is_array:
                 return base_type.base_type
 
@@ -3093,10 +3125,16 @@ class IndexNode(ExprNode):
                 elif is_slice and base_type in (bytes_type, str_type, unicode_type, list_type, tuple_type):
                     self.type = base_type
                 else:
+                    item_type = None
+                    if base_type in (list_type, tuple_type) and self.index.type.is_int:
+                        item_type = infer_sequence_item_type(
+                            env, self.base, self.index, seq_type=base_type)
+                    if item_type is None:
+                        item_type = py_object_type
+                    self.type = item_type
                     if base_type in (list_type, tuple_type, dict_type):
                         # do the None check explicitly (not in a helper) to allow optimising it away
                         self.base = self.base.as_none_safe_node("'NoneType' object is not subscriptable")
-                    self.type = py_object_type
             else:
                 if base_type.is_ptr or base_type.is_array:
                     self.type = base_type.base_type
index 8b220fe..c117cce 100644 (file)
@@ -110,7 +110,23 @@ def indexing():
     t = (4,5,6)
     assert typeof(t) == "tuple object", typeof(t)
     t1 = t[1]
-    assert typeof(t1) == "Python object", typeof(t1)
+    assert typeof(t1) == "long", typeof(t1)
+    t2 = ('abc', 'def', 'ghi')
+    assert typeof(t2) == "tuple object", typeof(t2)
+    t2_1 = t2[1]
+    assert typeof(t2_1) == "str object", typeof(t2_1)
+    t2_2 = t2[t[0]-3]
+    assert typeof(t2_2) == "str object", typeof(t2_2)
+    t5 = (b'abc', 'def', u'ghi')
+    t5_0 = t5[0]
+    assert typeof(t5_0) == "bytes object", typeof(t5_0)
+    t5_1 = t5[1]
+    assert typeof(t5_1) == "str object", typeof(t5_1)
+    t5_2 = t5[2]
+    assert typeof(t5_2) == "unicode object", typeof(t5_2)
+    t5_3 = t5[t[0]-3]
+    assert typeof(t5_3) == "Python object", typeof(t5_3)
+
 
 def multiple_assignments():
     """