Fix C++ template subclassing.
authorRobert Bradshaw <robertwb@math.washington.edu>
Tue, 14 Aug 2012 18:43:28 +0000 (11:43 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Tue, 14 Aug 2012 18:43:28 +0000 (11:43 -0700)
Cython/Compiler/Nodes.py
Cython/Compiler/Parsing.py
Cython/Compiler/Symtab.py
tests/run/cpp_template_subclasses.pyx [new file with mode: 0644]
tests/run/cpp_template_subclasses_helper.h [new file with mode: 0644]

index 75b3bfa..23c8eca 100644 (file)
@@ -1170,7 +1170,7 @@ class CppClassNode(CStructOrUnionDefNode):
     #  in_pxd        boolean
     #  attributes    [CVarDefNode] or None
     #  entry         Entry
-    #  base_classes  [string]
+    #  base_classes  [CBaseTypeNode]
     #  templates     [string] or None
 
     def declare(self, env):
@@ -1185,16 +1185,8 @@ class CppClassNode(CStructOrUnionDefNode):
     def analyse_declarations(self, env):
         scope = None
         if self.attributes is not None:
-            scope = CppClassScope(self.name, env)
-        base_class_types = []
-        for base_class_name in self.base_classes:
-            base_class_entry = env.lookup(base_class_name)
-            if base_class_entry is None:
-                error(self.pos, "'%s' not found" % base_class_name)
-            elif not base_class_entry.is_type or not base_class_entry.type.is_cpp_class:
-                error(self.pos, "'%s' is not a cpp class type" % base_class_name)
-            else:
-                base_class_types.append(base_class_entry.type)
+            scope = CppClassScope(self.name, env, templates = self.templates)
+        base_class_types = [b.analyse(scope) for b in self.base_classes]
         if self.templates is None:
             template_types = None
         else:
index a6dd21f..2ec8c49 100644 (file)
@@ -3041,10 +3041,10 @@ def p_cpp_class_definition(s, pos,  ctx):
         templates = None
     if s.sy == '(':
         s.next()
-        base_classes = [p_dotted_name(s, False)[2]]
+        base_classes = [p_c_base_type(s, templates = templates)]
         while s.sy == ',':
             s.next()
-            base_classes.append(p_dotted_name(s, False)[2])
+            base_classes.append(p_c_base_type(s, templates = templates))
         s.expect(')')
     else:
         base_classes = []
index 41c5637..be513ee 100644 (file)
@@ -514,10 +514,6 @@ class Scope(object):
             if templates or entry.type.templates:
                 if templates != entry.type.templates:
                     error(pos, "Template parameters do not match previous declaration")
-        if templates is not None and entry.type.scope is not None:
-            for T in templates:
-                template_entry = entry.type.scope.declare(T.name, T.name, T, None, 'extern')
-                template_entry.is_type = 1
 
         def declare_inherited_attributes(entry, base_classes):
             for base_class in base_classes:
@@ -1993,10 +1989,15 @@ class CppClassScope(Scope):
 
     default_constructor = None
 
-    def __init__(self, name, outer_scope):
+    def __init__(self, name, outer_scope, templates=None):
         Scope.__init__(self, name, outer_scope, None)
         self.directives = outer_scope.directives
         self.inherited_var_entries = []
+        if templates is not None:
+            for T in templates:
+                template_entry = self.declare(
+                    T, T, PyrexTypes.TemplatePlaceholderType(T), None, 'extern')
+                template_entry.is_type = 1
 
     def declare_var(self, name, type, pos,
                     cname = None, visibility = 'extern',
diff --git a/tests/run/cpp_template_subclasses.pyx b/tests/run/cpp_template_subclasses.pyx
new file mode 100644 (file)
index 0000000..43fc50d
--- /dev/null
@@ -0,0 +1,116 @@
+# tag: cpp
+
+from cython.operator import dereference as deref
+from libcpp.pair cimport pair
+from libcpp.vector cimport vector
+
+cdef extern from "cpp_template_subclasses_helper.h":
+    cdef cppclass Base:
+        char* name()
+
+    cdef cppclass A[A1](Base):
+        A1 funcA(A1)
+
+    cdef cppclass B[B1, B2](A[B2]):
+        pair[B1, B2] funcB(B1, B2)
+
+    cdef cppclass C[C1](B[long, C1]):
+        C1 funcC(C1)
+
+    cdef cppclass D[D1](C[pair[D1, D1]]):
+        pass
+
+    cdef cppclass E(D[double]):
+        pass
+
+def testA(x):
+    """
+    >>> testA(10)
+    10.0
+    """
+    cdef Base *base
+    cdef A[double] *a = NULL
+    try:
+        a = new A[double]()
+        base = a
+        assert base.name() == b"A", base.name()
+        return a.funcA(x)
+    finally:
+        del a
+
+def testB(x, y):
+    """
+    >>> testB(1, 2)
+    >>> testB(1, 1.5)
+    """
+    cdef Base *base
+    cdef A[double] *a
+    cdef B[long, double] *b = NULL
+    try:
+        base = a = b = new B[long, double]()
+        assert base.name() == b"B", base.name()
+        assert a.funcA(y) == y
+        assert <object>b.funcB(x, y) == (x, y)
+    finally:
+        del b
+
+def testC(x, y):
+    """
+    >>> testC(37, [1, 37])
+    >>> testC(25, [1, 5, 25])
+    >>> testC(105, [1, 3, 5, 7, 15, 21, 35, 105])
+    """
+    cdef Base *base
+    cdef A[vector[long]] *a
+    cdef B[long, vector[long]] *b
+    cdef C[vector[long]] *c = NULL
+    try:
+        base = a = b = c = new C[vector[long]]()
+        assert base.name() == b"C", base.name()
+        assert <object>a.funcA(y) == y
+        assert <object>b.funcB(x, y) == (x, y)
+        assert <object>c.funcC(y) == y
+    finally:
+        del c
+
+def testD(x, y):
+    """
+    >>> testD(1, 1.0)
+    >>> testD(2, 0.5)
+    >>> testD(4, 0.25)
+    """
+    cdef Base *base
+    cdef A[pair[double, double]] *a
+    cdef B[long, pair[double, double]] *b
+    cdef C[pair[double, double]] *c
+    cdef D[double] *d = NULL
+    try:
+        base = a = b = c = d = new D[double]()
+        assert base.name() == b"D", base.name()
+        assert <object>a.funcA((y, y)) == (y, y)
+        assert <object>b.funcB(x, (y, y + 1)) == (x, (y, y + 1))
+        assert <object>c.funcC((y, y)) == (y, y)
+    finally:
+        del d
+
+def testE(x, y):
+    """
+    >>> testD(1, 1.0)
+    >>> testD(2, 0.5)
+    >>> testD(4, 0.25)
+    """
+    cdef Base *base
+    cdef A[pair[double, double]] *a
+    cdef B[long, pair[double, double]] *b
+    cdef C[pair[double, double]] *c
+    cdef D[double] *d
+    cdef E *e = NULL
+    try:
+        base = a = b = c = d = e = new E()
+        assert base.name() == b"E", base.name()
+        assert <object>a.funcA((y, y)) == (y, y)
+        assert <object>b.funcB(x, (y, y + 1)) == (x, (y, y + 1))
+        assert <object>c.funcC((y, y)) == (y, y)
+    finally:
+        del e
+
diff --git a/tests/run/cpp_template_subclasses_helper.h b/tests/run/cpp_template_subclasses_helper.h
new file mode 100644 (file)
index 0000000..a1266f1
--- /dev/null
@@ -0,0 +1,36 @@
+using namespace std;
+
+class Base {
+public:
+    virtual const char* name() { return "Base"; }
+};
+
+template <class A1>
+class A : public Base {
+public:
+    virtual const char* name() { return "A"; }
+    A1 funcA(A1 x) { return x; }
+};
+
+template <class B1, class B2>
+class B : public A<B2> {
+public:
+    virtual const char* name() { return "B"; }
+    pair<B1, B2> funcB(B1 x, B2 y) { return pair<B1, B2>(x, y); }
+};
+
+template <class C1>
+class C : public B<long, C1> {
+public:
+    virtual const char* name() { return "C"; }
+    C1 funcC(C1 x) { return x; }
+};
+
+template <class D1>
+class D : public C<pair<D1, D1> > {
+    virtual const char* name() { return "D"; }
+};
+
+class E : public D<double> {
+    virtual const char* name() { return "E"; }
+};