return py_object_type
def analyse_cpp_types(self, env):
- begin = self.sequence.type.scope.lookup("begin")
- end = self.sequence.type.scope.lookup("end")
- if begin is None:
+ sequence_type = self.sequence.type
+ if sequence_type.is_ptr:
+ sequence_type = sequence_type.base_type
+ begin = sequence_type.scope.lookup("begin")
+ end = sequence_type.scope.lookup("end")
+ if (begin is None
+ or not begin.type.is_ptr
+ or not begin.type.base_type.is_cfunction
+ or begin.type.base_type.args):
error(self.pos, "missing begin() on %s" % self.sequence.type)
self.type = error_type
return
- if end is None:
+ if (end is None
+ or not end.type.is_ptr
+ or not end.type.base_type.is_cfunction
+ or end.type.base_type.args):
error(self.pos, "missing end() on %s" % self.sequence.type)
self.type = error_type
return
iter_type = begin.type.base_type.return_type
if iter_type.is_cpp_class:
- # TODO(robertwb): Check argument types.
- if iter_type.scope.lookup("operator!=") is None:
+ if env.lookup_operator_for_types(
+ self.pos,
+ "!=",
+ [iter_type, end.type.base_type.return_type]) is None:
error(self.pos, "missing operator!= on result of begin() on %s" % self.sequence.type)
self.type = error_type
return
- if iter_type.scope.lookup("operator++") is None:
+ if env.lookup_operator_for_types(self.pos, '++', [iter_type]) is None:
error(self.pos, "missing operator++ on result of begin() on %s" % self.sequence.type)
self.type = error_type
return
- if iter_type.scope.lookup("operator*") is None:
+ if env.lookup_operator_for_types(self.pos, '*', [iter_type]) is None:
error(self.pos, "missing operator* on result of begin() on %s" % self.sequence.type)
self.type = error_type
return
self.type = iter_type
elif iter_type.is_ptr:
+ if not (iter_type == end.type.base_type.return_type):
+ error(self.pos, "incompatible types for begin() and end()")
self.type = iter_type
else:
error(self.pos, "result type of begin() on %s must be a C++ class or pointer" % self.sequence.type)
code.putln("if (%s < 0) break;" % self.counter_cname)
if sequence_type.is_cpp_class:
# TODO: Cache end() call?
- code.putln("if (%s == %s.end()) break;" % (
+ code.putln("if (!(%s != %s.end())) break;" % (
self.result(),
self.sequence.result()));
code.putln("%s = *%s;" % (
if iterator_type.is_ptr or iterator_type.is_array:
return iterator_type.base_type
elif iterator_type.is_cpp_class:
- item_type = iterator_type.scope.lookup("operator*").type.base_type.return_type
+ item_type = env.lookup_operator_for_types(self.pos, "*", [iterator_type]).type.base_type.return_type
if item_type.is_reference:
item_type = item_type.ref_base_type
return item_type
return None
return PyrexTypes.best_match(operands, function.all_alternatives())
+ def lookup_operator_for_types(self, pos, operator, types):
+ from Nodes import Node
+ class FakeOperand(Node):
+ pass
+ operands = [FakeOperand(pos, type=type) for type in types]
+ return self.lookup_operator(operator, operands)
+
def use_utility_code(self, new_code):
self.global_scope().use_utility_code(new_code)
from libcpp.vector cimport vector
+cdef extern from "cpp_iterators_simple.h":
+ cdef cppclass DoublePointerIter:
+ DoublePointerIter(double* start, int len)
+ double* begin()
+ double* end()
+
def test_vector(py_v):
"""
>>> test_vector([1, 2, 3])
v.push_back(&b)
v.push_back(&c)
return [item[0] for item in v]
+
+def test_custom():
+ """
+ >>> test_custom()
+ [1.0, 2.0, 3.0]
+ """
+ cdef double* values = [1, 2, 3]
+ cdef DoublePointerIter* iter
+ try:
+ iter = new DoublePointerIter(values, 3)
+ # TODO: It'd be nice to automatically dereference this in a way that
+ # would not conflict with the pointer slicing iteration.
+ return [x for x in iter[0]]
+ finally:
+ del iter