is_fused = property(_get_fused_types, doc="Whether this type or any of its "
"subtypes is a fused type")
+ def deduce_template_params(self, actual):
+ """
+ Deduce any template params in this (argument) type given the actual
+ argument type.
+
+ http://en.cppreference.com/w/cpp/language/function_template#Template_argument_deduction
+ """
+ if self == actual:
+ return {}
+ else:
+ return None
+
def __lt__(self, other):
"""
For sorting. The sorting order should correspond to the preference of
else:
return self
+ def deduce_template_params(self, actual):
+ return {self: actual}
+
def same_as_resolved_type(self, other_type):
if isinstance(other_type, TemplatePlaceholderType):
return self.name == other_type.name
% (expectation, actual_nargs)
errors.append((func, error_mesg))
continue
- candidates.append((func, func_type))
+ if func_type.templates:
+ arg_types = [arg.type for arg in args]
+ deductions = reduce(
+ merge_template_deductions,
+ [pattern.type.deduce_template_params(actual) for (pattern, actual) in zip(func_type.args, arg_types)],
+ {})
+ if deductions is not None:
+ if len(deductions) < len(func_type.templates):
+ errors.append((func, "Unable to deduce type parameter %s" % (
+ ", ".join([param.name for param in set(func_type.templates) - set(deductions.keys())]))))
+ else:
+ type_list = [deductions[param] for param in func_type.templates]
+ from Symtab import Entry
+ specialization = Entry(
+ name = func.name + "[%s]" % ",".join([str(t) for t in type_list]),
+ cname = func.cname + "<%s>" % ",".join([t.declaration_code("") for t in type_list]),
+ type = func_type.specialize(deductions),
+ pos = func.pos)
+ candidates.append((specialization, specialization.type))
+ else:
+ candidates.append((func, func_type))
# Optimize the most common case of no overloading...
if len(candidates) == 1:
return None
+def merge_template_deductions(a, b):
+ if a is None or b is None:
+ return None
+ all = a
+ for param, value in b.iteritems():
+ if param in all:
+ if a[param] != b[param]:
+ return None
+ else:
+ all[param] = value
+ return all
+
def widest_numeric_type(type1, type2):
# Given two numeric types, return the narrowest type
# encompassing both of them.
from libcpp.pair cimport pair
cdef extern from "cpp_template_functions_helper.h":
+ cdef T no_arg[T]()
cdef T one_param[T](T)
cdef pair[T, U] two_params[T, U](T, U)
cdef cppclass A[T]:
pair[T, U] method[U](T, U)
+def test_no_arg():
+ """
+ >>> test_no_arg()
+ 0
+ """
+ return no_arg[int]()
+
def test_one_param(int x):
"""
>>> test_one_param(3)
cdef A[double] a_double
return a_int.method[float](x, y), a_double.method[int](x, y)
# return a_int.method[double](x, y), a_double.method[int](x, y)
+
+def test_simple_deduction(int x, double y):
+ """
+ >>> test_simple_deduction(1, 2)
+ (1, 2.0)
+ """
+ return one_param(x), one_param(y)