From 404c8968026a355731192f1b48a6f7f16394a5dd Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 26 Dec 2013 15:40:58 -0800 Subject: [PATCH] Add support for external C++ template functions. The syntax follows that of template classes, namely cdef T foo[T](T, ...) --- Cython/Compiler/ExprNodes.py | 64 +++++++++++++++++++++++++++++++------------ Cython/Compiler/Nodes.py | 56 ++++++++++++++++++++++++++++++++----- Cython/Compiler/Pipeline.py | 2 +- Cython/Compiler/PyrexTypes.py | 7 +---- Cython/Compiler/Symtab.py | 9 ++++-- 5 files changed, 104 insertions(+), 34 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 8e71851..0c9261e 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -2257,7 +2257,7 @@ class IteratorNode(ExprNode): elif sequence_type.is_pyobject: return sequence_type return py_object_type - + def analyse_cpp_types(self, env): sequence_type = self.sequence.type if sequence_type.is_ptr: @@ -2721,6 +2721,7 @@ class IndexNode(ExprNode): # base ExprNode # index ExprNode # indices [ExprNode] + # type_indices [PyrexType] # is_buffer_access boolean Whether this is a buffer access. # # indices is used on buffer access, index on non-buffer access. @@ -2732,6 +2733,7 @@ class IndexNode(ExprNode): subexprs = ['base', 'index', 'indices'] indices = None + type_indices = None is_subscript = True is_fused_index = False @@ -3103,8 +3105,7 @@ class IndexNode(ExprNode): else: base_type = self.base.type - fused_index_operation = base_type.is_cfunction and base_type.is_fused - if not fused_index_operation: + if not base_type.is_cfunction: if isinstance(self.index, TupleNode): self.index = self.index.analyse_types( env, skip_children=skip_child_analysis) @@ -3188,8 +3189,17 @@ class IndexNode(ExprNode): self.type = func_type.return_type if setting and not func_type.return_type.is_reference: error(self.pos, "Can't set non-reference result '%s'" % self.type) - elif fused_index_operation: - self.parse_indexed_fused_cdef(env) + elif base_type.is_cfunction: + if base_type.is_fused: + self.parse_indexed_fused_cdef(env) + else: + self.type_indices = self.parse_index_as_types(env) + if base_type.templates is None: + error(self.pos, "Can only parameterize template functions.") + elif len(base_type.templates) != len(self.type_indices): + error(self.pos, "Wrong number of template arguments: expected %s, got %s" % ( + (len(base_type.templates), len(self.type_indices)))) + self.type = base_type.specialize(dict(zip(base_type.templates, self.type_indices))) else: error(self.pos, "Attempting to index non-array type '%s'" % @@ -3215,6 +3225,20 @@ class IndexNode(ExprNode): self.base = self.base.as_none_safe_node(msg) + def parse_index_as_types(self, env, required=True): + if isinstance(self.index, TupleNode): + indices = self.index.args + else: + indices = [self.index] + type_indices = [] + for index in indices: + type_indices.append(index.analyse_as_type(env)) + if type_indices[-1] is None: + if required: + error(index.pos, "not parsable as a type") + return None + return type_indices + def parse_indexed_fused_cdef(self, env): """ Interpret fused_cdef_func[specific_type1, ...] @@ -3234,16 +3258,12 @@ class IndexNode(ExprNode): if self.index.is_name or self.index.is_attribute: positions.append(self.index.pos) - specific_types.append(self.index.analyse_as_type(env)) elif isinstance(self.index, TupleNode): for arg in self.index.args: positions.append(arg.pos) - specific_type = arg.analyse_as_type(env) - specific_types.append(specific_type) - else: - specific_types = [False] + specific_types = self.parse_index_as_types(env, required=False) - if not Utils.all(specific_types): + if specific_types is None: self.index = self.index.analyse_types(env) if not self.base.entry.as_variable: @@ -3362,6 +3382,10 @@ class IndexNode(ExprNode): index_code = "((unsigned char)(PyByteArray_AS_STRING(%s)[%s]))" else: assert False, "unexpected base type in indexing: %s" % self.base.type + elif self.base.type.is_cfunction: + return "%s<%s>" % ( + self.base.result(), + ",".join([param.declaration_code("") for param in self.type_indices])) else: if (self.type.is_ptr or self.type.is_array) and self.type == self.base.type: error(self.pos, "Invalid use of pointer slice") @@ -3388,7 +3412,9 @@ class IndexNode(ExprNode): def generate_subexpr_evaluation_code(self, code): self.base.generate_evaluation_code(code) - if self.indices is None: + if self.type_indices is not None: + pass + elif self.indices is None: self.index.generate_evaluation_code(code) else: for i in self.indices: @@ -3396,7 +3422,9 @@ class IndexNode(ExprNode): def generate_subexpr_disposal_code(self, code): self.base.generate_disposal_code(code) - if self.indices is None: + if self.type_indices is not None: + pass + elif self.indices is None: self.index.generate_disposal_code(code) else: for i in self.indices: @@ -3866,7 +3894,7 @@ class SliceIndexNode(ExprNode): if (dst_type not in (bytes_type, bytearray_type) and not env.directives['c_string_encoding']): error(self.pos, - "default encoding required for conversion from '%s' to '%s'" % + "default encoding required for conversion from '%s' to '%s'" % (self.base.type, dst_type)) self.type = dst_type return super(SliceIndexNode, self).coerce_to(dst_type, env) @@ -3876,7 +3904,7 @@ class SliceIndexNode(ExprNode): error(self.pos, "Slicing is not currently supported for '%s'." % self.type) return - + base_result = self.base.result() result = self.result() start_code = self.start_code() @@ -3929,8 +3957,8 @@ class SliceIndexNode(ExprNode): code.error_goto_if_null(result, self.pos))) elif self.base.type is unicode_type: - code.globalstate.use_utility_code( - UtilityCode.load_cached("PyUnicode_Substring", "StringTools.c")) + code.globalstate.use_utility_code( + UtilityCode.load_cached("PyUnicode_Substring", "StringTools.c")) code.putln( "%s = __Pyx_PyUnicode_Substring(%s, %s, %s); %s" % ( result, @@ -10599,7 +10627,7 @@ class CoerceToPyTypeNode(CoercionNode): if (type not in (bytes_type, bytearray_type) and not env.directives['c_string_encoding']): error(arg.pos, - "default encoding required for conversion from '%s' to '%s'" % + "default encoding required for conversion from '%s' to '%s'" % (arg.type, type)) self.type = type else: diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index bb3938b..4b2c206 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -19,8 +19,8 @@ import Naming import PyrexTypes import TypeSlots from PyrexTypes import py_object_type, error_type -from Symtab import ModuleScope, LocalScope, ClosureScope, \ - StructOrUnionScope, PyClassScope, CppClassScope +from Symtab import (ModuleScope, LocalScope, ClosureScope, + StructOrUnionScope, PyClassScope, CppClassScope, TemplateScope) from Code import UtilityCode from StringEncoding import EncodedString, escape_byte_string, split_string_literal import Options @@ -465,6 +465,9 @@ class CDeclaratorNode(Node): calling_convention = "" + def analyse_templates(self): + # Only C++ functions have templates. + return None class CNameDeclaratorNode(CDeclaratorNode): # name string The Cython name being declared @@ -523,7 +526,7 @@ class CArrayDeclaratorNode(CDeclaratorNode): child_attrs = ["base", "dimension"] def analyse(self, base_type, env, nonempty = 0): - if base_type.is_cpp_class: + if base_type.is_cpp_class or base_type.is_cfunction: from ExprNodes import TupleNode if isinstance(self.dimension, TupleNode): args = self.dimension.args @@ -565,6 +568,7 @@ class CArrayDeclaratorNode(CDeclaratorNode): class CFuncDeclaratorNode(CDeclaratorNode): # base CDeclaratorNode # args [CArgDeclNode] + # templates [TemplatePlaceholderType] # has_varargs boolean # exception_value ConstNode # exception_check boolean True if PyErr_Occurred check needed @@ -575,6 +579,28 @@ class CFuncDeclaratorNode(CDeclaratorNode): overridable = 0 optional_arg_count = 0 + templates = None + + def analyse_templates(self): + if isinstance(self.base, CArrayDeclaratorNode): + from ExprNodes import TupleNode, NameNode + template_node = self.base.dimension + if isinstance(template_node, TupleNode): + template_nodes = template_node.args + elif isinstance(template_node, NameNode): + template_nodes = [template_node] + else: + error(template_node.pos, "Template arguments must be a list of names") + self.templates = [] + for template in template_nodes: + if isinstance(template, NameNode): + self.templates.append(PyrexTypes.TemplatePlaceholderType(template.name)) + else: + error(template.pos, "Template arguments must be a list of names") + self.base = self.base.base + return self.templates + else: + return None def analyse(self, return_type, env, nonempty = 0, directive_locals = {}): if nonempty: @@ -659,7 +685,8 @@ class CFuncDeclaratorNode(CDeclaratorNode): optional_arg_count = self.optional_arg_count, exception_value = exc_val, exception_check = exc_check, calling_convention = self.base.calling_convention, - nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable) + nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable, + templates = self.templates) if self.optional_arg_count: if func_type.is_fused: @@ -892,7 +919,7 @@ class CSimpleBaseTypeNode(CBaseTypeNode): else: scope = None break - + if scope is None: # Maybe it's a cimport. scope = env.find_imported_module(self.module_path, self.pos) @@ -1164,6 +1191,21 @@ class CVarDefNode(StatNode): if not dest_scope: dest_scope = env self.dest_scope = dest_scope + + if self.declarators: + templates = self.declarators[0].analyse_templates() + else: + templates = None + if templates is not None: + if self.visibility != 'extern': + error(self.pos, "Only extern functions allowed") + if len(self.declarators) > 1: + error(self.declarators[1].pos, "Can't multiply declare template types") + env = TemplateScope('func_template', env) + env.directives = env.outer_scope.directives + for template_param in templates: + env.declare_type(template_param.name, template_param, self.pos) + base_type = self.base_type.analyse(env) if base_type.is_fused and not self.in_pxd and (env.is_c_class_scope or @@ -1175,12 +1217,12 @@ class CVarDefNode(StatNode): visibility = self.visibility for declarator in self.declarators: - + if (len(self.declarators) > 1 and not isinstance(declarator, CNameDeclaratorNode) and env.directives['warn.multiple_declarators']): warning(declarator.pos, "Non-trivial type declarators in shared declaration.", 1) - + if isinstance(declarator, CFuncDeclaratorNode): name_declarator, type = declarator.analyse(base_type, env, directive_locals=self.directive_locals) else: diff --git a/Cython/Compiler/Pipeline.py b/Cython/Compiler/Pipeline.py index 48c2cb8..420307c 100644 --- a/Cython/Compiler/Pipeline.py +++ b/Cython/Compiler/Pipeline.py @@ -119,7 +119,7 @@ class UseUtilityCodeDefinitions(CythonTransform): self.process_entry(node.entry) self.process_entry(node.type_entry) return node - + # # Pipeline factories # diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index 84096b3..c4c5737 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -2525,11 +2525,6 @@ class CFuncType(CType): return '(%s)' % s def specialize(self, values): - if self.templates is None: - new_templates = None - else: - new_templates = [v.specialize(values) for v in self.templates] - result = CFuncType(self.return_type.specialize(values), [arg.specialize(values) for arg in self.args], has_varargs = self.has_varargs, @@ -2540,7 +2535,7 @@ class CFuncType(CType): with_gil = self.with_gil, is_overridable = self.is_overridable, optional_arg_count = self.optional_arg_count, - templates = new_templates) + templates = self.templates) result.from_fused = self.is_fused return result diff --git a/Cython/Compiler/Symtab.py b/Cython/Compiler/Symtab.py index 7c1d21e..1c59799 100644 --- a/Cython/Compiler/Symtab.py +++ b/Cython/Compiler/Symtab.py @@ -276,7 +276,7 @@ class Scope(object): # qualified_name string "modname" or "modname.classname" # Python strings in this scope # nogil boolean In a nogil section - # directives dict Helper variable for the recursive + # directives dict Helper variable for the recursive # analysis, contains directive values. # is_internal boolean Is only used internally (simpler setup) @@ -2195,7 +2195,7 @@ class CppClassScope(Scope): entry.pos, entry.cname, entry.visibility) - + return scope @@ -2237,3 +2237,8 @@ class CConstScope(Scope): entry = copy.copy(entry) entry.type = PyrexTypes.c_const_type(entry.type) return entry + +class TemplateScope(Scope): + def __init__(self, name, outer_scope): + Scope.__init__(self, name, outer_scope, None) + self.directives = outer_scope.directives -- 2.7.4