Better type checking for C++ iterators.
authorRobert Bradshaw <robertwb@gmail.com>
Mon, 2 Jul 2012 21:06:23 +0000 (14:06 -0700)
committerRobert Bradshaw <robertwb@gmail.com>
Mon, 2 Jul 2012 21:07:57 +0000 (14:07 -0700)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Symtab.py
tests/run/cpp_iterators.pyx
tests/run/cpp_iterators_simple.h [new file with mode: 0644]

index df46e08..f0d9c1d 100755 (executable)
@@ -2010,33 +2010,46 @@ class IteratorNode(ExprNode):
         return py_object_type
     
     def analyse_cpp_types(self, env):
-        begin = self.sequence.type.scope.lookup("begin")
-        end = self.sequence.type.scope.lookup("end")
-        if begin is None:
+        sequence_type = self.sequence.type
+        if sequence_type.is_ptr:
+            sequence_type = sequence_type.base_type
+        begin = sequence_type.scope.lookup("begin")
+        end = sequence_type.scope.lookup("end")
+        if (begin is None
+            or not begin.type.is_ptr
+            or not begin.type.base_type.is_cfunction
+            or begin.type.base_type.args):
             error(self.pos, "missing begin() on %s" % self.sequence.type)
             self.type = error_type
             return
-        if end is None:
+        if (end is None
+            or not end.type.is_ptr
+            or not end.type.base_type.is_cfunction
+            or end.type.base_type.args):
             error(self.pos, "missing end() on %s" % self.sequence.type)
             self.type = error_type
             return
         iter_type = begin.type.base_type.return_type
         if iter_type.is_cpp_class:
-            # TODO(robertwb): Check argument types.
-            if iter_type.scope.lookup("operator!=") is None:
+            if env.lookup_operator_for_types(
+                    self.pos,
+                    "!=",
+                    [iter_type, end.type.base_type.return_type]) is None:
                 error(self.pos, "missing operator!= on result of begin() on %s" % self.sequence.type)
                 self.type = error_type
                 return
-            if iter_type.scope.lookup("operator++") is None:
+            if env.lookup_operator_for_types(self.pos, '++', [iter_type]) is None:
                 error(self.pos, "missing operator++ on result of begin() on %s" % self.sequence.type)
                 self.type = error_type
                 return
-            if iter_type.scope.lookup("operator*") is None:
+            if env.lookup_operator_for_types(self.pos, '*', [iter_type]) is None:
                 error(self.pos, "missing operator* on result of begin() on %s" % self.sequence.type)
                 self.type = error_type
                 return
             self.type = iter_type
         elif iter_type.is_ptr:
+            if not (iter_type == end.type.base_type.return_type):
+                error(self.pos, "incompatible types for begin() and end()")
             self.type = iter_type
         else:
             error(self.pos, "result type of begin() on %s must be a C++ class or pointer" % self.sequence.type)
@@ -2133,7 +2146,7 @@ class IteratorNode(ExprNode):
             code.putln("if (%s < 0) break;" % self.counter_cname)
         if sequence_type.is_cpp_class:
             # TODO: Cache end() call?
-            code.putln("if (%s == %s.end()) break;" % (
+            code.putln("if (!(%s != %s.end())) break;" % (
                             self.result(),
                             self.sequence.result()));
             code.putln("%s = *%s;" % (
@@ -2198,7 +2211,7 @@ class NextNode(AtomicExprNode):
         if iterator_type.is_ptr or iterator_type.is_array:
             return iterator_type.base_type
         elif iterator_type.is_cpp_class:
-            item_type = iterator_type.scope.lookup("operator*").type.base_type.return_type
+            item_type = env.lookup_operator_for_types(self.pos, "*", [iterator_type]).type.base_type.return_type
             if item_type.is_reference:
                 item_type = item_type.ref_base_type
             return item_type
index 1139d07..6922b8d 100644 (file)
@@ -758,6 +758,13 @@ class Scope(object):
             return None
         return PyrexTypes.best_match(operands, function.all_alternatives())
 
+    def lookup_operator_for_types(self, pos, operator, types):
+        from Nodes import Node
+        class FakeOperand(Node):
+            pass
+        operands = [FakeOperand(pos, type=type) for type in types]
+        return self.lookup_operator(operator, operands)
+
     def use_utility_code(self, new_code):
         self.global_scope().use_utility_code(new_code)
 
index a07b1ee..ceb25a8 100644 (file)
@@ -2,6 +2,12 @@
 
 from libcpp.vector cimport vector
 
+cdef extern from "cpp_iterators_simple.h":
+    cdef cppclass DoublePointerIter:
+        DoublePointerIter(double* start, int len)
+        double* begin()
+        double* end()
+
 def test_vector(py_v):
     """
     >>> test_vector([1, 2, 3])
@@ -27,3 +33,18 @@ def test_ptrs():
     v.push_back(&b)
     v.push_back(&c)
     return [item[0] for item in v]
+
+def test_custom():
+    """
+    >>> test_custom()
+    [1.0, 2.0, 3.0]
+    """
+    cdef double* values = [1, 2, 3]
+    cdef DoublePointerIter* iter
+    try:
+        iter = new DoublePointerIter(values, 3)
+        # TODO: It'd be nice to automatically dereference this in a way that
+        # would not conflict with the pointer slicing iteration.
+        return [x for x in iter[0]]
+    finally:
+        del iter
diff --git a/tests/run/cpp_iterators_simple.h b/tests/run/cpp_iterators_simple.h
new file mode 100644 (file)
index 0000000..3a4b50e
--- /dev/null
@@ -0,0 +1,10 @@
+class DoublePointerIter {
+public:
+    DoublePointerIter(double* start, int len) : start_(start), len_(len) { }
+    double* begin() { return start_; }
+    double* end() { return start_ + len_; }
+private:
+    double* start_;
+    int len_;
+};
+