From 92cf54f991d11a5eb43176c90688cf2302f9ab79 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 26 Dec 2013 23:00:31 -0800 Subject: [PATCH] simple template deduction --- Cython/Compiler/PyrexTypes.py | 49 ++++++++++++++++++++++++++++++- tests/run/cpp_template_functions.pyx | 15 ++++++++++ tests/run/cpp_template_functions_helper.h | 5 ++++ 3 files changed, 68 insertions(+), 1 deletion(-) diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index c4c5737..ccfd463 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -81,6 +81,18 @@ class BaseType(object): is_fused = property(_get_fused_types, doc="Whether this type or any of its " "subtypes is a fused type") + def deduce_template_params(self, actual): + """ + Deduce any template params in this (argument) type given the actual + argument type. + + http://en.cppreference.com/w/cpp/language/function_template#Template_argument_deduction + """ + if self == actual: + return {} + else: + return None + def __lt__(self, other): """ For sorting. The sorting order should correspond to the preference of @@ -3162,6 +3174,9 @@ class TemplatePlaceholderType(CType): else: return self + def deduce_template_params(self, actual): + return {self: actual} + def same_as_resolved_type(self, other_type): if isinstance(other_type, TemplatePlaceholderType): return self.name == other_type.name @@ -3481,7 +3496,27 @@ def best_match(args, functions, pos=None, env=None): % (expectation, actual_nargs) errors.append((func, error_mesg)) continue - candidates.append((func, func_type)) + if func_type.templates: + arg_types = [arg.type for arg in args] + deductions = reduce( + merge_template_deductions, + [pattern.type.deduce_template_params(actual) for (pattern, actual) in zip(func_type.args, arg_types)], + {}) + if deductions is not None: + if len(deductions) < len(func_type.templates): + errors.append((func, "Unable to deduce type parameter %s" % ( + ", ".join([param.name for param in set(func_type.templates) - set(deductions.keys())])))) + else: + type_list = [deductions[param] for param in func_type.templates] + from Symtab import Entry + specialization = Entry( + name = func.name + "[%s]" % ",".join([str(t) for t in type_list]), + cname = func.cname + "<%s>" % ",".join([t.declaration_code("") for t in type_list]), + type = func_type.specialize(deductions), + pos = func.pos) + candidates.append((specialization, specialization.type)) + else: + candidates.append((func, func_type)) # Optimize the most common case of no overloading... if len(candidates) == 1: @@ -3573,6 +3608,18 @@ def best_match(args, functions, pos=None, env=None): return None +def merge_template_deductions(a, b): + if a is None or b is None: + return None + all = a + for param, value in b.iteritems(): + if param in all: + if a[param] != b[param]: + return None + else: + all[param] = value + return all + def widest_numeric_type(type1, type2): # Given two numeric types, return the narrowest type # encompassing both of them. diff --git a/tests/run/cpp_template_functions.pyx b/tests/run/cpp_template_functions.pyx index d09e802..361b638 100644 --- a/tests/run/cpp_template_functions.pyx +++ b/tests/run/cpp_template_functions.pyx @@ -3,11 +3,19 @@ from libcpp.pair cimport pair cdef extern from "cpp_template_functions_helper.h": + cdef T no_arg[T]() cdef T one_param[T](T) cdef pair[T, U] two_params[T, U](T, U) cdef cppclass A[T]: pair[T, U] method[U](T, U) +def test_no_arg(): + """ + >>> test_no_arg() + 0 + """ + return no_arg[int]() + def test_one_param(int x): """ >>> test_one_param(3) @@ -31,3 +39,10 @@ def test_method(int x, int y): cdef A[double] a_double return a_int.method[float](x, y), a_double.method[int](x, y) # return a_int.method[double](x, y), a_double.method[int](x, y) + +def test_simple_deduction(int x, double y): + """ + >>> test_simple_deduction(1, 2) + (1, 2.0) + """ + return one_param(x), one_param(y) diff --git a/tests/run/cpp_template_functions_helper.h b/tests/run/cpp_template_functions_helper.h index 1b1cead..f96e667 100644 --- a/tests/run/cpp_template_functions_helper.h +++ b/tests/run/cpp_template_functions_helper.h @@ -1,4 +1,9 @@ template +T no_arg() { + return T(); +} + +template T one_param(T value) { return value; } -- 2.7.4