Add support for external C++ template functions.
authorRobert Bradshaw <robertwb@gmail.com>
Thu, 26 Dec 2013 23:40:58 +0000 (15:40 -0800)
committerRobert Bradshaw <robertwb@gmail.com>
Thu, 26 Dec 2013 23:40:58 +0000 (15:40 -0800)
The syntax follows that of template classes, namely

    cdef T foo[T](T, ...)

Cython/Compiler/ExprNodes.py
Cython/Compiler/Nodes.py
Cython/Compiler/Pipeline.py
Cython/Compiler/PyrexTypes.py
Cython/Compiler/Symtab.py

index 8e71851..0c9261e 100644 (file)
@@ -2257,7 +2257,7 @@ class IteratorNode(ExprNode):
         elif sequence_type.is_pyobject:
             return sequence_type
         return py_object_type
-    
+
     def analyse_cpp_types(self, env):
         sequence_type = self.sequence.type
         if sequence_type.is_ptr:
@@ -2721,6 +2721,7 @@ class IndexNode(ExprNode):
     #  base     ExprNode
     #  index    ExprNode
     #  indices  [ExprNode]
+    #  type_indices  [PyrexType]
     #  is_buffer_access boolean Whether this is a buffer access.
     #
     #  indices is used on buffer access, index on non-buffer access.
@@ -2732,6 +2733,7 @@ class IndexNode(ExprNode):
 
     subexprs = ['base', 'index', 'indices']
     indices = None
+    type_indices = None
 
     is_subscript = True
     is_fused_index = False
@@ -3103,8 +3105,7 @@ class IndexNode(ExprNode):
         else:
             base_type = self.base.type
 
-            fused_index_operation = base_type.is_cfunction and base_type.is_fused
-            if not fused_index_operation:
+            if not base_type.is_cfunction:
                 if isinstance(self.index, TupleNode):
                     self.index = self.index.analyse_types(
                         env, skip_children=skip_child_analysis)
@@ -3188,8 +3189,17 @@ class IndexNode(ExprNode):
                     self.type = func_type.return_type
                     if setting and not func_type.return_type.is_reference:
                         error(self.pos, "Can't set non-reference result '%s'" % self.type)
-                elif fused_index_operation:
-                    self.parse_indexed_fused_cdef(env)
+                elif base_type.is_cfunction:
+                    if base_type.is_fused:
+                        self.parse_indexed_fused_cdef(env)
+                    else:
+                        self.type_indices = self.parse_index_as_types(env)
+                        if base_type.templates is None:
+                            error(self.pos, "Can only parameterize template functions.")
+                        elif len(base_type.templates) != len(self.type_indices):
+                            error(self.pos, "Wrong number of template arguments: expected %s, got %s" % (
+                                    (len(base_type.templates), len(self.type_indices))))
+                        self.type = base_type.specialize(dict(zip(base_type.templates, self.type_indices)))
                 else:
                     error(self.pos,
                           "Attempting to index non-array type '%s'" %
@@ -3215,6 +3225,20 @@ class IndexNode(ExprNode):
 
         self.base = self.base.as_none_safe_node(msg)
 
+    def parse_index_as_types(self, env, required=True):
+        if isinstance(self.index, TupleNode):
+            indices = self.index.args
+        else:
+            indices = [self.index]
+        type_indices = []
+        for index in indices:
+            type_indices.append(index.analyse_as_type(env))
+            if type_indices[-1] is None:
+                if required:
+                    error(index.pos, "not parsable as a type")
+                return None
+        return type_indices
+
     def parse_indexed_fused_cdef(self, env):
         """
         Interpret fused_cdef_func[specific_type1, ...]
@@ -3234,16 +3258,12 @@ class IndexNode(ExprNode):
 
         if self.index.is_name or self.index.is_attribute:
             positions.append(self.index.pos)
-            specific_types.append(self.index.analyse_as_type(env))
         elif isinstance(self.index, TupleNode):
             for arg in self.index.args:
                 positions.append(arg.pos)
-                specific_type = arg.analyse_as_type(env)
-                specific_types.append(specific_type)
-        else:
-            specific_types = [False]
+        specific_types = self.parse_index_as_types(env, required=False)
 
-        if not Utils.all(specific_types):
+        if specific_types is None:
             self.index = self.index.analyse_types(env)
 
             if not self.base.entry.as_variable:
@@ -3362,6 +3382,10 @@ class IndexNode(ExprNode):
                 index_code = "((unsigned char)(PyByteArray_AS_STRING(%s)[%s]))"
             else:
                 assert False, "unexpected base type in indexing: %s" % self.base.type
+        elif self.base.type.is_cfunction:
+            return "%s<%s>" % (
+                self.base.result(),
+                ",".join([param.declaration_code("") for param in self.type_indices]))
         else:
             if (self.type.is_ptr or self.type.is_array) and self.type == self.base.type:
                 error(self.pos, "Invalid use of pointer slice")
@@ -3388,7 +3412,9 @@ class IndexNode(ExprNode):
 
     def generate_subexpr_evaluation_code(self, code):
         self.base.generate_evaluation_code(code)
-        if self.indices is None:
+        if self.type_indices is not None:
+            pass
+        elif self.indices is None:
             self.index.generate_evaluation_code(code)
         else:
             for i in self.indices:
@@ -3396,7 +3422,9 @@ class IndexNode(ExprNode):
 
     def generate_subexpr_disposal_code(self, code):
         self.base.generate_disposal_code(code)
-        if self.indices is None:
+        if self.type_indices is not None:
+            pass
+        elif self.indices is None:
             self.index.generate_disposal_code(code)
         else:
             for i in self.indices:
@@ -3866,7 +3894,7 @@ class SliceIndexNode(ExprNode):
             if (dst_type not in (bytes_type, bytearray_type)
                     and not env.directives['c_string_encoding']):
                 error(self.pos,
-                    "default encoding required for conversion from '%s' to '%s'" % 
+                    "default encoding required for conversion from '%s' to '%s'" %
                     (self.base.type, dst_type))
             self.type = dst_type
         return super(SliceIndexNode, self).coerce_to(dst_type, env)
@@ -3876,7 +3904,7 @@ class SliceIndexNode(ExprNode):
             error(self.pos,
                   "Slicing is not currently supported for '%s'." % self.type)
             return
-            
+
         base_result = self.base.result()
         result = self.result()
         start_code = self.start_code()
@@ -3929,8 +3957,8 @@ class SliceIndexNode(ExprNode):
                         code.error_goto_if_null(result, self.pos)))
 
         elif self.base.type is unicode_type:
-            code.globalstate.use_utility_code( 
-                          UtilityCode.load_cached("PyUnicode_Substring", "StringTools.c")) 
+            code.globalstate.use_utility_code(
+                          UtilityCode.load_cached("PyUnicode_Substring", "StringTools.c"))
             code.putln(
                 "%s = __Pyx_PyUnicode_Substring(%s, %s, %s); %s" % (
                     result,
@@ -10599,7 +10627,7 @@ class CoerceToPyTypeNode(CoercionNode):
             if (type not in (bytes_type, bytearray_type)
                     and not env.directives['c_string_encoding']):
                 error(arg.pos,
-                    "default encoding required for conversion from '%s' to '%s'" % 
+                    "default encoding required for conversion from '%s' to '%s'" %
                     (arg.type, type))
             self.type = type
         else:
index bb3938b..4b2c206 100644 (file)
@@ -19,8 +19,8 @@ import Naming
 import PyrexTypes
 import TypeSlots
 from PyrexTypes import py_object_type, error_type
-from Symtab import ModuleScope, LocalScope, ClosureScope, \
-    StructOrUnionScope, PyClassScope, CppClassScope
+from Symtab import (ModuleScope, LocalScope, ClosureScope,
+    StructOrUnionScope, PyClassScope, CppClassScope, TemplateScope)
 from Code import UtilityCode
 from StringEncoding import EncodedString, escape_byte_string, split_string_literal
 import Options
@@ -465,6 +465,9 @@ class CDeclaratorNode(Node):
 
     calling_convention = ""
 
+    def analyse_templates(self):
+        # Only C++ functions have templates.
+        return None
 
 class CNameDeclaratorNode(CDeclaratorNode):
     #  name    string             The Cython name being declared
@@ -523,7 +526,7 @@ class CArrayDeclaratorNode(CDeclaratorNode):
     child_attrs = ["base", "dimension"]
 
     def analyse(self, base_type, env, nonempty = 0):
-        if base_type.is_cpp_class:
+        if base_type.is_cpp_class or base_type.is_cfunction:
             from ExprNodes import TupleNode
             if isinstance(self.dimension, TupleNode):
                 args = self.dimension.args
@@ -565,6 +568,7 @@ class CArrayDeclaratorNode(CDeclaratorNode):
 class CFuncDeclaratorNode(CDeclaratorNode):
     # base             CDeclaratorNode
     # args             [CArgDeclNode]
+    # templates        [TemplatePlaceholderType]
     # has_varargs      boolean
     # exception_value  ConstNode
     # exception_check  boolean    True if PyErr_Occurred check needed
@@ -575,6 +579,28 @@ class CFuncDeclaratorNode(CDeclaratorNode):
 
     overridable = 0
     optional_arg_count = 0
+    templates = None
+
+    def analyse_templates(self):
+        if isinstance(self.base, CArrayDeclaratorNode):
+            from ExprNodes import TupleNode, NameNode
+            template_node = self.base.dimension
+            if isinstance(template_node, TupleNode):
+                template_nodes = template_node.args
+            elif isinstance(template_node, NameNode):
+                template_nodes = [template_node]
+            else:
+                error(template_node.pos, "Template arguments must be a list of names")
+            self.templates = []
+            for template in template_nodes:
+                if isinstance(template, NameNode):
+                    self.templates.append(PyrexTypes.TemplatePlaceholderType(template.name))
+                else:
+                    error(template.pos, "Template arguments must be a list of names")
+            self.base = self.base.base
+            return self.templates
+        else:
+            return None
 
     def analyse(self, return_type, env, nonempty = 0, directive_locals = {}):
         if nonempty:
@@ -659,7 +685,8 @@ class CFuncDeclaratorNode(CDeclaratorNode):
             optional_arg_count = self.optional_arg_count,
             exception_value = exc_val, exception_check = exc_check,
             calling_convention = self.base.calling_convention,
-            nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable)
+            nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable,
+            templates = self.templates)
 
         if self.optional_arg_count:
             if func_type.is_fused:
@@ -892,7 +919,7 @@ class CSimpleBaseTypeNode(CBaseTypeNode):
                     else:
                         scope = None
                         break
-                
+
                 if scope is None:
                     # Maybe it's a cimport.
                     scope = env.find_imported_module(self.module_path, self.pos)
@@ -1164,6 +1191,21 @@ class CVarDefNode(StatNode):
         if not dest_scope:
             dest_scope = env
         self.dest_scope = dest_scope
+
+        if self.declarators:
+            templates = self.declarators[0].analyse_templates()
+        else:
+            templates = None
+        if templates is not None:
+            if self.visibility != 'extern':
+                error(self.pos, "Only extern functions allowed")
+            if len(self.declarators) > 1:
+                error(self.declarators[1].pos, "Can't multiply declare template types")
+            env = TemplateScope('func_template', env)
+            env.directives = env.outer_scope.directives
+            for template_param in templates:
+                env.declare_type(template_param.name, template_param, self.pos)
+
         base_type = self.base_type.analyse(env)
 
         if base_type.is_fused and not self.in_pxd and (env.is_c_class_scope or
@@ -1175,12 +1217,12 @@ class CVarDefNode(StatNode):
         visibility = self.visibility
 
         for declarator in self.declarators:
-            
+
             if (len(self.declarators) > 1
                 and not isinstance(declarator, CNameDeclaratorNode)
                 and env.directives['warn.multiple_declarators']):
                 warning(declarator.pos, "Non-trivial type declarators in shared declaration.", 1)
-            
+
             if isinstance(declarator, CFuncDeclaratorNode):
                 name_declarator, type = declarator.analyse(base_type, env, directive_locals=self.directive_locals)
             else:
index 48c2cb8..420307c 100644 (file)
@@ -119,7 +119,7 @@ class UseUtilityCodeDefinitions(CythonTransform):
         self.process_entry(node.entry)
         self.process_entry(node.type_entry)
         return node
-                     
+
 #
 # Pipeline factories
 #
index 84096b3..c4c5737 100644 (file)
@@ -2525,11 +2525,6 @@ class CFuncType(CType):
         return '(%s)' % s
 
     def specialize(self, values):
-        if self.templates is None:
-            new_templates = None
-        else:
-            new_templates = [v.specialize(values) for v in self.templates]
-
         result = CFuncType(self.return_type.specialize(values),
                            [arg.specialize(values) for arg in self.args],
                            has_varargs = self.has_varargs,
@@ -2540,7 +2535,7 @@ class CFuncType(CType):
                            with_gil = self.with_gil,
                            is_overridable = self.is_overridable,
                            optional_arg_count = self.optional_arg_count,
-                           templates = new_templates)
+                           templates = self.templates)
 
         result.from_fused = self.is_fused
         return result
index 7c1d21e..1c59799 100644 (file)
@@ -276,7 +276,7 @@ class Scope(object):
     # qualified_name    string             "modname" or "modname.classname"
     #                                        Python strings in this scope
     # nogil             boolean            In a nogil section
-    # directives       dict                Helper variable for the recursive
+    # directives        dict               Helper variable for the recursive
     #                                      analysis, contains directive values.
     # is_internal       boolean            Is only used internally (simpler setup)
 
@@ -2195,7 +2195,7 @@ class CppClassScope(Scope):
                                   entry.pos,
                                   entry.cname,
                                   entry.visibility)
-                
+
         return scope
 
 
@@ -2237,3 +2237,8 @@ class CConstScope(Scope):
             entry = copy.copy(entry)
             entry.type = PyrexTypes.c_const_type(entry.type)
             return entry
+
+class TemplateScope(Scope):
+    def __init__(self, name, outer_scope):
+        Scope.__init__(self, name, outer_scope, None)
+        self.directives = outer_scope.directives