Allow declaring C++ classes in Cython files.
authorRobert Bradshaw <robertwb@gmail.com>
Tue, 21 Aug 2012 05:37:51 +0000 (22:37 -0700)
committerRobert Bradshaw <robertwb@gmail.com>
Tue, 21 Aug 2012 08:38:27 +0000 (01:38 -0700)
This is necessary to use some C++ APIs, and ugly if not impossible
to work around using cname specifiers and external files.

Cython/Compiler/ExprNodes.py
Cython/Compiler/ModuleNode.py
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Parsing.py
Cython/Compiler/PyrexTypes.py
Cython/Compiler/Symtab.py
Cython/Compiler/Visitor.py

index 78e13e1..b798221 100755 (executable)
@@ -2010,7 +2010,7 @@ class IteratorNode(ExprNode):
         elif sequence_type.is_cpp_class:
             begin = sequence_type.scope.lookup("begin")
             if begin is not None:
-                return begin.type.base_type.return_type
+                return begin.type.return_type
         elif sequence_type.is_pyobject:
             return sequence_type
         return py_object_type
@@ -2022,25 +2022,23 @@ class IteratorNode(ExprNode):
         begin = sequence_type.scope.lookup("begin")
         end = sequence_type.scope.lookup("end")
         if (begin is None
-            or not begin.type.is_ptr
-            or not begin.type.base_type.is_cfunction
-            or begin.type.base_type.args):
+            or not begin.type.is_cfunction
+            or begin.type.args):
             error(self.pos, "missing begin() on %s" % self.sequence.type)
             self.type = error_type
             return
         if (end is None
-            or not end.type.is_ptr
-            or not end.type.base_type.is_cfunction
-            or end.type.base_type.args):
+            or not end.type.is_cfunction
+            or end.type.args):
             error(self.pos, "missing end() on %s" % self.sequence.type)
             self.type = error_type
             return
-        iter_type = begin.type.base_type.return_type
+        iter_type = begin.type.return_type
         if iter_type.is_cpp_class:
             if env.lookup_operator_for_types(
                     self.pos,
                     "!=",
-                    [iter_type, end.type.base_type.return_type]) is None:
+                    [iter_type, end.type.return_type]) is None:
                 error(self.pos, "missing operator!= on result of begin() on %s" % self.sequence.type)
                 self.type = error_type
                 return
@@ -2054,7 +2052,7 @@ class IteratorNode(ExprNode):
                 return
             self.type = iter_type
         elif iter_type.is_ptr:
-            if not (iter_type == end.type.base_type.return_type):
+            if not (iter_type == end.type.return_type):
                 error(self.pos, "incompatible types for begin() and end()")
             self.type = iter_type
         else:
@@ -2234,7 +2232,7 @@ class NextNode(AtomicExprNode):
         if iterator_type.is_ptr or iterator_type.is_array:
             return iterator_type.base_type
         elif iterator_type.is_cpp_class:
-            item_type = env.lookup_operator_for_types(self.pos, "*", [iterator_type]).type.base_type.return_type
+            item_type = env.lookup_operator_for_types(self.pos, "*", [iterator_type]).type.return_type
             if item_type.is_reference:
                 item_type = item_type.ref_base_type
             return item_type
@@ -2587,7 +2585,7 @@ class IndexNode(ExprNode):
             ]
             index_func = env.lookup_operator('[]', operands)
             if index_func is not None:
-                return index_func.type.base_type.return_type
+                return index_func.type.return_type
 
         # may be slicing or indexing, we don't know
         if base_type in (unicode_type, str_type):
index c5494b1..9d0463d 100644 (file)
@@ -612,7 +612,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
                 type = entry.type
                 if type.is_typedef: # Must test this first!
                     pass
-                elif type.is_struct_or_union:
+                elif type.is_struct_or_union or type.is_cpp_class:
                     self.generate_struct_union_predeclaration(entry, code)
                 elif type.is_extension_type:
                     self.generate_objstruct_predeclaration(type, code)
@@ -627,6 +627,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
                     self.generate_enum_definition(entry, code)
                 elif type.is_struct_or_union:
                     self.generate_struct_union_definition(entry, code)
+                elif type.is_cpp_class:
+                    self.generate_cpp_class_definition(entry, code)
                 elif type.is_extension_type:
                     self.generate_objstruct_definition(type, code)
 
@@ -666,6 +668,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
 
     def generate_struct_union_predeclaration(self, entry, code):
         type = entry.type
+        if type.is_cpp_class and type.templates:
+            code.putln("template <class %s>" % ", class ".join([T.declaration_code("") for T in type.templates]))
         code.putln(self.sue_predeclaration(type, type.kind, type.cname))
 
     def sue_header_footer(self, type, kind, name):
@@ -709,6 +713,28 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
                 code.putln("  #pragma pack(pop)")
                 code.putln("#endif")
 
+    def generate_cpp_class_definition(self, entry, code):
+        code.mark_pos(entry.pos)
+        type = entry.type
+        scope = type.scope
+        if scope:
+            if type.templates:
+                code.putln("template <class %s>" % ", class ".join([T.declaration_code("") for T in type.templates]))
+            # Just let everything be public.
+            code.put("struct %s" % type.cname)
+            if type.base_classes:
+                base_class_decl = ", public ".join(
+                    [base_class.declaration_code("") for base_class in type.base_classes])
+                code.put(" : public %s" % base_class_decl)
+            code.putln(" {")
+            for attr in scope.var_entries:
+                if attr.type.is_cfunction:
+                    code.put("virtual ")
+                code.putln(
+                    "%s;" %
+                        attr.type.declaration_code(attr.cname))
+            code.putln("};")
+
     def generate_enum_definition(self, entry, code):
         code.mark_pos(entry.pos)
         type = entry.type
index 9060b90..4dac0b7 100644 (file)
@@ -1162,7 +1162,7 @@ class CStructOrUnionDefNode(StatNode):
         pass
 
 
-class CppClassNode(CStructOrUnionDefNode):
+class CppClassNode(CStructOrUnionDefNode, BlockNode):
 
     #  name          string
     #  cname         string or None
@@ -1197,11 +1197,30 @@ class CppClassNode(CStructOrUnionDefNode):
         if self.entry is None:
             return
         self.entry.is_cpp_class = 1
+        scope.class_namespace = self.entry.type.declaration_code("")
+        defined_funcs = []
         if self.attributes is not None:
             if self.in_pxd and not env.in_cinclude:
                 self.entry.defined_in_pxd = 1
             for attr in self.attributes:
                 attr.analyse_declarations(scope)
+                if isinstance(attr, CFuncDefNode):
+                    defined_funcs.append(attr)
+        self.body = StatListNode(self.pos, stats=defined_funcs)
+        self.scope = scope
+
+    def analyse_expressions(self, env):
+        self.body.analyse_expressions(self.entry.type.scope)
+
+    def generate_function_definitions(self, env, code):
+        self.body.generate_function_definitions(self.entry.type.scope, code)
+
+    def generate_execution_code(self, code):
+        self.body.generate_execution_code(code)
+
+    def annotate(self, code):
+        self.body.annotate(code)
+
 
 class CEnumDefNode(StatNode):
     #  name           string or None
@@ -2122,7 +2141,7 @@ class CFuncDefNode(FuncDefNode):
         if cname is None:
             cname = self.entry.func_cname
         entity = type.function_header_code(cname, ', '.join(arg_decls))
-        if self.entry.visibility == 'private':
+        if self.entry.visibility == 'private' and '::' not in cname:
             storage_class = "static "
         else:
             storage_class = ""
index 94e64f9..481a512 100644 (file)
@@ -1640,6 +1640,12 @@ if VALUE is not None:
         node.analyse_declarations(self.env_stack[-1])
         return node
 
+    def visit_CppClassNode(self, node):
+        if node.visibility == 'extern':
+            return None
+        else:
+            return self.visit_ClassDefNode(node)
+    
     def visit_CStructOrUnionDefNode(self, node):
         # Create a wrapper node if needed.
         # We want to use the struct type information (so it can't happen
index 2ec8c49..07a2edd 100644 (file)
@@ -2513,8 +2513,6 @@ def p_cdef_statement(s, ctx):
             error(pos, "Extension types cannot be declared cpdef")
         return p_c_class_definition(s, pos, ctx)
     elif s.sy == 'IDENT' and s.systring == 'cppclass':
-        if ctx.visibility != 'extern':
-            error(pos, "C++ classes need to be declared extern")
         return p_cpp_class_definition(s, pos, ctx)
     elif s.sy == 'IDENT' and s.systring in struct_enum_union:
         if ctx.level not in ('module', 'module_pxd'):
@@ -2706,7 +2704,7 @@ def p_c_func_or_var_declaration(s, pos, ctx):
                                 assignable = 1, nonempty = 1)
     declarator.overridable = ctx.overridable
     if s.sy == ':':
-        if ctx.level not in ('module', 'c_class', 'module_pxd', 'c_class_pxd') and not ctx.templates:
+        if ctx.level not in ('module', 'c_class', 'module_pxd', 'c_class_pxd', 'cpp_class') and not ctx.templates:
             s.error("C function definition not allowed here")
         doc, suite = p_suite(s, Ctx(level = 'function'), with_doc = 1)
         result = Nodes.CFuncDefNode(pos,
@@ -3055,7 +3053,7 @@ def p_cpp_class_definition(s, pos,  ctx):
         s.expect('NEWLINE')
         s.expect_indent()
         attributes = []
-        body_ctx = Ctx(visibility = ctx.visibility)
+        body_ctx = Ctx(visibility = ctx.visibility, level='cpp_class')
         body_ctx.templates = templates
         while s.sy != 'DEDENT':
             if s.systring == 'cppclass':
index d9849c6..d844adb 100755 (executable)
@@ -3004,6 +3004,11 @@ class CppClassType(CType):
     has_attributes = 1
     exception_check = True
     namespace = None
+    
+    # For struct-like declaration.
+    kind = "struct"
+    packed = False
+    typedef_flag = False
 
     subtypes = ['templates']
 
index d2b23be..0b2cb1b 100644 (file)
@@ -486,10 +486,11 @@ class Scope(object):
     def declare_cpp_class(self, name, scope,
             pos, cname = None, base_classes = (),
             visibility = 'extern', templates = None):
-        if visibility != 'extern':
-            error(pos, "C++ classes may only be extern")
         if cname is None:
-            cname = name
+            if self.in_cinclude or (visibility != 'private'):
+                cname = name
+            else:
+                cname = self.mangle(Naming.type_prefix, name)
         base_classes = list(base_classes)
         entry = self.lookup_here(name)
         if not entry:
@@ -497,6 +498,7 @@ class Scope(object):
                 name, scope, cname, base_classes, templates = templates)
             entry = self.declare_type(name, type, pos, cname,
                 visibility = visibility, defining = scope is not None)
+            self.sue_entries.append(entry)
         else:
             if not (entry.is_type and entry.type.is_cpp_class):
                 error(pos, "'%s' redeclared " % name)
@@ -525,6 +527,7 @@ class Scope(object):
                     entry.type.scope.declare_inherited_cpp_attributes(base_class.scope)
         if entry.type.scope:
             declare_inherited_attributes(entry, base_classes)
+            entry.type.scope.declare_var(name="this", cname="this", type=PyrexTypes.CPtrType(entry.type), pos=entry.pos)
         if self.is_cpp_class_scope:
             entry.type.namespace = self.outer_scope.lookup(self.name).type
         return entry
@@ -1992,6 +1995,7 @@ class CppClassScope(Scope):
     is_cpp_class_scope = 1
 
     default_constructor = None
+    class_namespace = None
 
     def __init__(self, name, outer_scope, templates=None):
         Scope.__init__(self, name, outer_scope, None)
@@ -2010,10 +2014,10 @@ class CppClassScope(Scope):
         # Add an entry for an attribute.
         if not cname:
             cname = name
-        if type.is_cfunction:
-            type = PyrexTypes.CPtrType(type)
         entry = self.declare(name, cname, type, pos, visibility)
         entry.is_variable = 1
+        if type.is_cfunction and self.class_namespace:
+            entry.func_cname = "%s::%s" % (self.class_namespace, cname)
         self.var_entries.append(entry)
         if type.is_pyobject and not allow_pyobject:
             error(pos,
index 21af9f2..4643362 100644 (file)
@@ -350,6 +350,12 @@ class EnvTransform(CythonTransform):
         self.env_stack.pop()
         return node
 
+    def visit_CStructOrUnionDefNode(self, node):
+        self.env_stack.append((node, node.scope))
+        self.visitchildren(node)
+        self.env_stack.pop()
+        return node
+
     def visit_ScopedExprNode(self, node):
         if node.expr_scope:
             self.env_stack.append((node, node.expr_scope))