simple template deduction
authorRobert Bradshaw <robertwb@gmail.com>
Fri, 27 Dec 2013 07:00:31 +0000 (23:00 -0800)
committerRobert Bradshaw <robertwb@gmail.com>
Fri, 27 Dec 2013 07:00:31 +0000 (23:00 -0800)
Cython/Compiler/PyrexTypes.py
tests/run/cpp_template_functions.pyx
tests/run/cpp_template_functions_helper.h

index c4c5737..ccfd463 100644 (file)
@@ -81,6 +81,18 @@ class BaseType(object):
     is_fused = property(_get_fused_types, doc="Whether this type or any of its "
                                              "subtypes is a fused type")
 
+    def deduce_template_params(self, actual):
+        """
+        Deduce any template params in this (argument) type given the actual
+        argument type.
+
+        http://en.cppreference.com/w/cpp/language/function_template#Template_argument_deduction
+        """
+        if self == actual:
+            return {}
+        else:
+            return None
+
     def __lt__(self, other):
         """
         For sorting. The sorting order should correspond to the preference of
@@ -3162,6 +3174,9 @@ class TemplatePlaceholderType(CType):
         else:
             return self
 
+    def deduce_template_params(self, actual):
+        return {self: actual}
+
     def same_as_resolved_type(self, other_type):
         if isinstance(other_type, TemplatePlaceholderType):
             return self.name == other_type.name
@@ -3481,7 +3496,27 @@ def best_match(args, functions, pos=None, env=None):
                          % (expectation, actual_nargs)
             errors.append((func, error_mesg))
             continue
-        candidates.append((func, func_type))
+        if func_type.templates:
+            arg_types = [arg.type for arg in args]
+            deductions = reduce(
+                merge_template_deductions,
+                [pattern.type.deduce_template_params(actual) for (pattern, actual) in zip(func_type.args, arg_types)],
+                {})
+            if deductions is not None:
+                if len(deductions) < len(func_type.templates):
+                    errors.append((func, "Unable to deduce type parameter %s" % (
+                        ", ".join([param.name for param in set(func_type.templates) - set(deductions.keys())]))))
+                else:
+                    type_list = [deductions[param] for param in func_type.templates]
+                    from Symtab import Entry
+                    specialization = Entry(
+                        name = func.name + "[%s]" % ",".join([str(t) for t in type_list]),
+                        cname = func.cname + "<%s>" % ",".join([t.declaration_code("") for t in type_list]),
+                        type = func_type.specialize(deductions),
+                        pos = func.pos)
+                    candidates.append((specialization, specialization.type))
+        else:
+            candidates.append((func, func_type))
 
     # Optimize the most common case of no overloading...
     if len(candidates) == 1:
@@ -3573,6 +3608,18 @@ def best_match(args, functions, pos=None, env=None):
 
     return None
 
+def merge_template_deductions(a, b):
+    if a is None or b is None:
+        return None
+    all = a
+    for param, value in b.iteritems():
+        if param in all:
+            if a[param] != b[param]:
+                return None
+        else:
+            all[param] = value
+    return all
+
 def widest_numeric_type(type1, type2):
     # Given two numeric types, return the narrowest type
     # encompassing both of them.
index d09e802..361b638 100644 (file)
@@ -3,11 +3,19 @@
 from libcpp.pair cimport pair
 
 cdef extern from "cpp_template_functions_helper.h":
+    cdef T no_arg[T]()
     cdef T one_param[T](T)
     cdef pair[T, U] two_params[T, U](T, U)
     cdef cppclass A[T]:
         pair[T, U] method[U](T, U)
 
+def test_no_arg():
+    """
+    >>> test_no_arg()
+    0
+    """
+    return no_arg[int]()
+
 def test_one_param(int x):
     """
     >>> test_one_param(3)
@@ -31,3 +39,10 @@ def test_method(int x, int y):
     cdef A[double] a_double
     return a_int.method[float](x, y), a_double.method[int](x, y)
 #    return a_int.method[double](x, y), a_double.method[int](x, y)
+
+def test_simple_deduction(int x, double y):
+    """
+    >>> test_simple_deduction(1, 2)
+    (1, 2.0)
+    """
+    return one_param(x), one_param(y)
index 1b1cead..f96e667 100644 (file)
@@ -1,4 +1,9 @@
 template <typename T>
+T no_arg() {
+    return T();
+}
+
+template <typename T>
 T one_param(T value) {
     return value;
 }