Support C++ iterators in for..in loops.
authorRobert Bradshaw <robertwb@gmail.com>
Mon, 2 Jul 2012 19:40:37 +0000 (12:40 -0700)
committerRobert Bradshaw <robertwb@gmail.com>
Mon, 2 Jul 2012 21:07:57 +0000 (14:07 -0700)
Cython/Compiler/ExprNodes.py

index 4dfed75..d4c50d8 100755 (executable)
@@ -1982,6 +1982,8 @@ class IteratorNode(ExprNode):
                 not self.sequence.type.is_string:
             # C array iteration will be transformed later on
             self.type = self.sequence.type
+        elif self.sequence.type.is_cpp_class:
+            self.analyse_cpp_types(env)
         else:
             self.sequence = self.sequence.coerce_to_pyobject(env)
             if self.sequence.type is list_type or \
@@ -1995,9 +1997,46 @@ class IteratorNode(ExprNode):
         PyrexTypes.py_object_type, [
             PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None),
             ]))
-
+    def analyse_cpp_types(self, env):
+        begin = self.sequence.type.scope.lookup("begin")
+        end = self.sequence.type.scope.lookup("end")
+        if begin is None:
+            error(self.pos, "missing begin() on %s" % self.sequence.type)
+            self.type = error_type
+            return
+        if end is None:
+            error(self.pos, "missing end() on %s" % self.sequence.type)
+            self.type = error_type
+            return
+        iter_type = begin.type.base_type.return_type
+        if iter_type.is_cpp_class:
+            # TODO(robertwb): Check argument types.
+            if iter_type.scope.lookup("operator!=") is None:
+                error(self.pos, "missing operator!= on result of begin() on %s" % self.sequence.type)
+                self.type = error_type
+                return
+            if iter_type.scope.lookup("operator++") is None:
+                error(self.pos, "missing operator++ on result of begin() on %s" % self.sequence.type)
+                self.type = error_type
+                return
+            if iter_type.scope.lookup("operator*") is None:
+                error(self.pos, "missing operator* on result of begin() on %s" % self.sequence.type)
+                self.type = error_type
+                return
+            self.type = iter_type
+        elif iter_type.is_ptr:
+            self.type = iter_type
+        else:
+            error(self.pos, "result type of begin() on %s must be a C++ class or pointer" % self.sequence.type)
+            self.type = error_type
+            return
+    
     def generate_result_code(self, code):
         sequence_type = self.sequence.type
+        if sequence_type.is_cpp_class:
+            # TODO: Limit scope.
+            code.putln("%s = %s.begin();" % (self.result(), self.sequence.result()))
+            return
         if sequence_type.is_array or sequence_type.is_ptr:
             raise InternalError("for in carray slice not transformed")
         is_builtin_sequence = sequence_type is list_type or \
@@ -2080,7 +2119,17 @@ class IteratorNode(ExprNode):
         sequence_type = self.sequence.type
         if self.reversed:
             code.putln("if (%s < 0) break;" % self.counter_cname)
-        if sequence_type is list_type:
+        if sequence_type.is_cpp_class:
+            # TODO: Cache end() call?
+            code.putln("if (%s == %s.end()) break;" % (
+                            self.result(),
+                            self.sequence.result()));
+            code.putln("%s = *%s;" % (
+                            result_name,
+                            self.result()))
+            code.putln("++%s;" % self.result())
+            return
+        elif sequence_type is list_type:
             self.generate_next_sequence_item('List', result_name, code)
             return
         elif sequence_type is tuple_type:
@@ -2127,13 +2176,16 @@ class NextNode(AtomicExprNode):
     #
     #  iterator   IteratorNode
 
-    type = py_object_type
-
     def __init__(self, iterator):
         self.pos = iterator.pos
         self.iterator = iterator
-        if iterator.type.is_ptr or iterator.type.is_array:
-            self.type = iterator.type.base_type
+        iterator_type = iterator.type
+        if iterator_type.is_ptr or iterator_type.is_array:
+            self.type = iterator_type.base_type
+        elif iterator_type.is_cpp_class:
+            self.type = iterator_type.scope.lookup("operator*").type.base_type.return_type
+        else:
+            self.type = py_object_type
         self.is_temp = 1
 
     def generate_result_code(self, code):
@@ -2459,6 +2511,18 @@ class IndexNode(ExprNode):
             elif base_type.is_ptr or base_type.is_array:
                 return base_type.base_type
 
+        if base_type.is_cpp_class:
+            class FakeOperand:
+                def __init__(self, **kwds):
+                    self.__dict__.update(kwds)
+            operands = [
+                FakeOperand(pos=self.pos, type=base_type),
+                FakeOperand(pos=self.pos, type=index_type),
+            ]
+            index_func = env.lookup_operator('[]', operands)
+            if index_func is not None:
+                return index_func.type.base_type.return_type
+
         # may be slicing or indexing, we don't know
         if base_type in (unicode_type, str_type):
             # these types always returns their own type on Python indexing/slicing