fix type inference for overloaded C++ operators
authorStefan Behnel <stefan_ml@behnel.de>
Tue, 24 Jul 2012 13:57:54 +0000 (15:57 +0200)
committerStefan Behnel <stefan_ml@behnel.de>
Tue, 24 Jul 2012 13:57:54 +0000 (15:57 +0200)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Optimize.py
Cython/Compiler/PyrexTypes.py
tests/run/cpp_operators.pyx
tests/run/cpp_type_inference.pyx [new file with mode: 0644]

index a0344be..8dcbf97 100755 (executable)
@@ -6988,6 +6988,13 @@ class UnopNode(ExprNode):
 
     def infer_type(self, env):
         operand_type = self.operand.infer_type(env)
+        if operand_type.is_cpp_class or operand_type.is_ptr:
+            cpp_type = operand_type.find_cpp_operation_type(self.operator)
+            if cpp_type is not None:
+                return cpp_type
+        return self.infer_unop_type(env, operand_type)
+
+    def infer_unop_type(self, env, operand_type):
         if operand_type.is_pyobject:
             return py_object_type
         else:
@@ -7042,30 +7049,23 @@ class UnopNode(ExprNode):
         self.type = PyrexTypes.error_type
 
     def analyse_cpp_operation(self, env):
-        type = self.operand.type
-        if type.is_ptr:
-            type = type.base_type
-        function = type.scope.lookup("operator%s" % self.operator)
-        if not function:
-            error(self.pos, "'%s' operator not defined for %s"
-                % (self.operator, type))
+        cpp_type = self.operand.type.find_cpp_operation_type(self.operator)
+        if cpp_type is None:
+            error(self.pos, "'%s' operator not defined for %s" % (
+                self.operator, type))
             self.type_error()
             return
-        func_type = function.type
-        if func_type.is_ptr:
-            func_type = func_type.base_type
-        self.type = func_type.return_type
+        self.type = cpp_type
 
 
-class NotNode(ExprNode):
+class NotNode(UnopNode):
     #  'not' operator
     #
     #  operand   ExprNode
+    operator = '!'
 
     type = PyrexTypes.c_bint_type
 
-    subexprs = ['operand']
-
     def calculate_constant_result(self):
         self.constant_result = not self.operand.constant_result
 
@@ -7076,23 +7076,19 @@ class NotNode(ExprNode):
         except Exception, e:
             self.compile_time_value_error(e)
 
-    def infer_type(self, env):
+    def infer_unop_type(self, env, operand_type):
         return PyrexTypes.c_bint_type
 
     def analyse_types(self, env):
         self.operand.analyse_types(env)
-        if self.operand.type.is_cpp_class:
-            type = self.operand.type
-            function = type.scope.lookup("operator!")
-            if not function:
-                error(self.pos, "'!' operator not defined for %s"
-                    % (type))
+        operand_type = self.operand.type
+        if operand_type.is_cpp_class:
+            cpp_type = operand_type.find_cpp_operation_type(self.operator)
+            if not cpp_type:
+                error(self.pos, "'!' operator not defined for %s" % operand_type)
                 self.type = PyrexTypes.error_type
                 return
-            func_type = function.type
-            if func_type.is_ptr:
-                func_type = func_type.base_type
-            self.type = func_type.return_type
+            self.type = cpp_type
         else:
             self.operand = self.operand.coerce_to_boolean(env)
 
@@ -7181,6 +7177,12 @@ class DereferenceNode(CUnopNode):
 
     operator = '*'
 
+    def infer_unop_type(self, env, operand_type):
+        if operand_type.is_ptr:
+            return operand_type.base_type
+        else:
+            return PyrexTypes.error_type
+
     def analyse_c_operation(self, env):
         if self.operand.type.is_ptr:
             self.type = self.operand.type.base_type
@@ -7213,19 +7215,23 @@ def inc_dec_constructor(is_prefix, operator):
     return lambda pos, **kwds: DecrementIncrementNode(pos, is_prefix=is_prefix, operator=operator, **kwds)
 
 
-class AmpersandNode(ExprNode):
+class AmpersandNode(CUnopNode):
     #  The C address-of operator.
     #
     #  operand  ExprNode
+    operator = '&'
 
-    subexprs = ['operand']
-
-    def infer_type(self, env):
-        return PyrexTypes.c_ptr_type(self.operand.infer_type(env))
+    def infer_unop_type(self, env, operand_type):
+        return PyrexTypes.c_ptr_type(operand_type)
 
     def analyse_types(self, env):
         self.operand.analyse_types(env)
         argtype = self.operand.type
+        if argtype.is_cpp_class:
+            cpp_type = argtype.find_cpp_operation_type(self.operator)
+            if cpp_type is not None:
+                self.type = cpp_type
+                return
         if not (argtype.is_cfunction or argtype.is_reference or self.operand.is_addressable()):
             if argtype.is_memoryviewslice:
                 self.error("Cannot take address of memoryview slice")
@@ -7932,6 +7938,16 @@ class CBinopNode(BinopNode):
             self.operator,
             self.operand2.result())
 
+    def compute_c_result_type(self, type1, type2):
+        cpp_type = None
+        if type1.is_cpp_class or type1.is_ptr:
+            cpp_type = type1.find_cpp_operation_type(self.operator, type2)
+        # FIXME: handle the reversed case?
+        #if cpp_type is None and (type2.is_cpp_class or type2.is_ptr):
+        #    cpp_type = type2.find_cpp_operation_type(self.operator, type1)
+        # FIXME: do we need to handle other cases here?
+        return cpp_type
+
 
 def c_binop_constructor(operator):
     def make_binop_node(pos, **operands):
index ccca385..978f556 100644 (file)
@@ -2966,11 +2966,14 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
             return node
         if not node.operand.is_literal:
             return node
-        if isinstance(node.operand, ExprNodes.BoolNode):
-            return ExprNodes.IntNode(node.pos, value = str(node.constant_result),
+        if isinstance(node, ExprNodes.NotNode):
+            return ExprNodes.BoolNode(node.pos, value = bool(node.constant_result),
+                                      constant_result = bool(node.constant_result))
+        elif isinstance(node.operand, ExprNodes.BoolNode):
+            return ExprNodes.IntNode(node.pos, value = str(int(node.constant_result)),
                                      type = PyrexTypes.c_int_type,
-                                     constant_result = node.constant_result)
-        if node.operator == '+':
+                                     constant_result = int(node.constant_result))
+        elif node.operator == '+':
             return self._handle_UnaryPlusNode(node)
         elif node.operator == '-':
             return self._handle_UnaryMinusNode(node)
index 3d3d5b7..1bb5529 100755 (executable)
@@ -2292,6 +2292,11 @@ class CPtrType(CPointerBaseType):
     def invalid_value(self):
         return "1"
 
+    def find_cpp_operation_type(self, operator, operand_type=None):
+        if self.base_type.is_cpp_class:
+            return self.base_type.find_cpp_operation_type(operator, operand_type=None)
+        return None
+
 class CNullPtrType(CPtrType):
 
     is_null_ptr = 1
@@ -3164,6 +3169,19 @@ class CppClassType(CType):
     def attributes_known(self):
         return self.scope is not None
 
+    def find_cpp_operation_type(self, operator, operand_type=None):
+        operands = [self]
+        if operand_type is not None:
+            operands.append(operand_type)
+        # pos == None => no errors
+        operator_entry = self.scope.lookup_operator_for_types(None, operator, operands)
+        if not operator_entry:
+            return None
+        func_type = operator_entry.type
+        if func_type.is_ptr:
+            func_type = func_type.base_type
+        return func_type.return_type
+
 
 class TemplatePlaceholderType(CType):
 
index e214ea4..99133d8 100644 (file)
 # tag: cpp
 
+from cython cimport typeof
+
 cimport cython.operator
 from cython.operator cimport dereference as deref
 
-cdef out(s):
-    print s.decode('ASCII')
+from libc.string cimport const_char
+
+cdef out(s, result_type=None):
+    print '%s [%s]' % (s.decode('ascii'), result_type)
 
 cdef extern from "cpp_operators_helper.h":
     cdef cppclass TestOps:
 
-        char* operator+()
-        char* operator-()
-        char* operator*()
-        char* operator~()
-        char* operator!()
-
-        char* operator++()
-        char* operator--()
-        char* operator++(int)
-        char* operator--(int)
-
-        char* operator+(int)
-        char* operator-(int)
-        char* operator*(int)
-        char* operator/(int)
-        char* operator%(int)
-
-        char* operator|(int)
-        char* operator&(int)
-        char* operator^(int)
-        char* operator,(int)
-
-        char* operator<<(int)
-        char* operator>>(int)
-
-        char* operator==(int)
-        char* operator!=(int)
-        char* operator>=(int)
-        char* operator<=(int)
-        char* operator>(int)
-        char* operator<(int)
-
-        char* operator[](int)
-        char* operator()(int)
+        const_char* operator+()
+        const_char* operator-()
+        const_char* operator*()
+        const_char* operator~()
+        const_char* operator!()
+
+        const_char* operator++()
+        const_char* operator--()
+        const_char* operator++(int)
+        const_char* operator--(int)
+
+        const_char* operator+(int)
+        const_char* operator-(int)
+        const_char* operator*(int)
+        const_char* operator/(int)
+        const_char* operator%(int)
+
+        const_char* operator|(int)
+        const_char* operator&(int)
+        const_char* operator^(int)
+        const_char* operator,(int)
+
+        const_char* operator<<(int)
+        const_char* operator>>(int)
+
+        const_char* operator==(int)
+        const_char* operator!=(int)
+        const_char* operator>=(int)
+        const_char* operator<=(int)
+        const_char* operator>(int)
+        const_char* operator<(int)
+
+        const_char* operator[](int)
+        const_char* operator()(int)
 
 def test_unops():
     """
     >>> test_unops()
-    unary +
-    unary -
-    unary ~
-    unary *
-    unary !
+    unary + [const_char *]
+    unary - [const_char *]
+    unary ~ [const_char *]
+    unary * [const_char *]
+    unary ! [const_char *]
     """
     cdef TestOps* t = new TestOps()
-    out(+t[0])
-    out(-t[0])
-    out(~t[0])
-    out(deref(t[0]))
-    out(not t[0])
+    out(+t[0], typeof(+t[0]))
+    out(-t[0], typeof(-t[0]))
+    out(~t[0], typeof(~t[0]))
+    x = deref(t[0])
+    out(x, typeof(x))
+    out(not t[0], typeof(not t[0]))
     del t
 
 def test_incdec():
     """
     >>> test_incdec()
-    unary ++
-    unary --
-    post ++
-    post --
+    unary ++ [const_char *]
+    unary -- [const_char *]
+    post ++ [const_char *]
+    post -- [const_char *]
     """
     cdef TestOps* t = new TestOps()
-    out(cython.operator.preincrement(t[0]))
-    out(cython.operator.predecrement(t[0]))
-    out(cython.operator.postincrement(t[0]))
-    out(cython.operator.postdecrement(t[0]))
+    a = cython.operator.preincrement(t[0])
+    out(a, typeof(a))
+    b = cython.operator.predecrement(t[0])
+    out(b, typeof(b))
+    c = cython.operator.postincrement(t[0])
+    out(c, typeof(c))
+    d = cython.operator.postdecrement(t[0])
+    out(d, typeof(d))
     del t
 
 def test_binop():
     """
     >>> test_binop()
-    binary +
-    binary -
-    binary *
-    binary /
-    binary %
-    binary &
-    binary |
-    binary ^
-    binary <<
-    binary >>
-    binary COMMA
+    binary + [const_char *]
+    binary - [const_char *]
+    binary * [const_char *]
+    binary / [const_char *]
+    binary % [const_char *]
+    binary & [const_char *]
+    binary | [const_char *]
+    binary ^ [const_char *]
+    binary << [const_char *]
+    binary >> [const_char *]
+    binary COMMA [const_char *]
     """
     cdef TestOps* t = new TestOps()
-    out(t[0] + 1)
-    out(t[0] - 1)
-    out(t[0] * 1)
-    out(t[0] / 1)
-    out(t[0] % 1)
+    out(t[0] + 1, typeof(t[0] + 1))
+    out(t[0] - 1, typeof(t[0] - 1))
+    out(t[0] * 1, typeof(t[0] * 1))
+    out(t[0] / 1, typeof(t[0] / 1))
+    out(t[0] % 1, typeof(t[0] % 1))
 
-    out(t[0] & 1)
-    out(t[0] | 1)
-    out(t[0] ^ 1)
+    out(t[0] & 1, typeof(t[0] & 1))
+    out(t[0] | 1, typeof(t[0] | 1))
+    out(t[0] ^ 1, typeof(t[0] ^ 1))
 
-    out(t[0] << 1)
-    out(t[0] >> 1)
+    out(t[0] << 1, typeof(t[0] << 1))
+    out(t[0] >> 1, typeof(t[0] >> 1))
 
-    out(cython.operator.comma(t[0], 1))
+    x = cython.operator.comma(t[0], 1)
+    out(x, typeof(x))
     del t
 
 def test_cmp():
     """
     >>> test_cmp()
-    binary ==
-    binary !=
-    binary >=
-    binary >
-    binary <=
-    binary <
+    binary == [const_char *]
+    binary != [const_char *]
+    binary >= [const_char *]
+    binary > [const_char *]
+    binary <= [const_char *]
+    binary < [const_char *]
     """
     cdef TestOps* t = new TestOps()
-    out(t[0] == 1)
-    out(t[0] != 1)
-    out(t[0] >= 1)
-    out(t[0] > 1)
-    out(t[0] <= 1)
-    out(t[0] < 1)
+    out(t[0] == 1, typeof(t[0] == 1))
+    out(t[0] != 1, typeof(t[0] != 1))
+    out(t[0] >= 1, typeof(t[0] >= 1))
+    out(t[0] > 1, typeof(t[0] > 1))
+    out(t[0] <= 1, typeof(t[0] <= 1))
+    out(t[0] < 1, typeof(t[0] < 1))
     del t
 
 def test_index_call():
     """
     >>> test_index_call()
-    binary []
-    binary ()
+    binary [] [const_char *]
+    binary () [const_char *]
     """
     cdef TestOps* t = new TestOps()
-    out(t[0][100])
-    out(t[0](100))
+    out(t[0][100], typeof(t[0][100]))
+    out(t[0](100), typeof(t[0](100)))
     del t
diff --git a/tests/run/cpp_type_inference.pyx b/tests/run/cpp_type_inference.pyx
new file mode 100644 (file)
index 0000000..264edda
--- /dev/null
@@ -0,0 +1,24 @@
+# tag: cpp
+
+from cython cimport typeof
+
+from cython.operator cimport dereference as d
+from cython.operator cimport preincrement as incr
+from libcpp.vector cimport vector
+
+def test_reversed_vector_iteration(L):
+    """
+    >>> test_reversed_vector_iteration([1,2,3])
+    int: 3
+    int: 2
+    int: 1
+    int
+    """
+    cdef vector[int] v = L
+
+    it = v.rbegin()
+    while it != v.rend():
+        a = d(it)
+        incr(it)
+        print('%s: %s' % (typeof(a), a))
+    print(typeof(a))