More template parameter deduction.
authorRobert Bradshaw <robertwb@gmail.com>
Fri, 27 Dec 2013 07:29:26 +0000 (23:29 -0800)
committerRobert Bradshaw <robertwb@gmail.com>
Fri, 27 Dec 2013 07:29:26 +0000 (23:29 -0800)
Cython/Compiler/PyrexTypes.py
tests/run/cpp_template_functions.pyx
tests/run/cpp_template_functions_helper.h

index ccfd463..1b5556c 100644 (file)
@@ -1239,6 +1239,9 @@ class CConstType(BaseType):
         else:
             return CConstType(base_type)
 
+    def deduce_template_params(self, actual):
+        return self.const_base_type.deduce_template_params(actual)
+
     def create_to_py_utility_code(self, env):
         if self.const_base_type.create_to_py_utility_code(env):
             self.to_py_function = self.const_base_type.to_py_function
@@ -2178,6 +2181,19 @@ class CArrayType(CPointerBaseType):
     def is_complete(self):
         return self.size is not None
 
+    def specialize(self, values):
+        base_type = self.base_type.specialize(values)
+        if base_type == self.base_type:
+            return self
+        else:
+            return CArrayType(base_type)
+
+    def deduce_template_params(self, actual):
+        if isinstance(actual, CArrayType):
+            return self.base_type.deduce_template_params(actual.base_type)
+        else:
+            return None
+
 
 class CPtrType(CPointerBaseType):
     #  base_type     CType              Reference type
@@ -2239,6 +2255,12 @@ class CPtrType(CPointerBaseType):
         else:
             return CPtrType(base_type)
 
+    def deduce_template_params(self, actual):
+        if isinstance(actual, CPtrType):
+            return self.base_type.deduce_template_params(actual.base_type)
+        else:
+            return None
+
     def invalid_value(self):
         return "1"
 
@@ -2279,6 +2301,9 @@ class CReferenceType(BaseType):
         else:
             return CReferenceType(base_type)
 
+    def deduce_template_params(self, actual):
+        return self.ref_base_type.deduce_template_params(actual)
+
     def __getattr__(self, name):
         return getattr(self.ref_base_type, name)
 
@@ -3083,6 +3108,18 @@ class CppClassType(CType):
             specialized.namespace = self.namespace.specialize(values)
         return specialized
 
+    def deduce_template_params(self, actual):
+        if self == actual:
+            return {}
+        # TODO(robertwb): Actual type equality.
+        elif self.declaration_code("") == actual.template_type.declaration_code(""):
+            return reduce(
+                merge_template_deductions,
+                [formal_param.deduce_template_params(actual_param) for (formal_param, actual_param) in zip(self.templates, actual.templates)],
+                {})
+        else:
+            return None
+
     def declaration_code(self, entity_code,
             for_display = 0, dll_linkage = None, pyrex = 0):
         if self.templates:
@@ -3502,19 +3539,20 @@ def best_match(args, functions, pos=None, env=None):
                 merge_template_deductions,
                 [pattern.type.deduce_template_params(actual) for (pattern, actual) in zip(func_type.args, arg_types)],
                 {})
-            if deductions is not None:
-                if len(deductions) < len(func_type.templates):
-                    errors.append((func, "Unable to deduce type parameter %s" % (
-                        ", ".join([param.name for param in set(func_type.templates) - set(deductions.keys())]))))
-                else:
-                    type_list = [deductions[param] for param in func_type.templates]
-                    from Symtab import Entry
-                    specialization = Entry(
-                        name = func.name + "[%s]" % ",".join([str(t) for t in type_list]),
-                        cname = func.cname + "<%s>" % ",".join([t.declaration_code("") for t in type_list]),
-                        type = func_type.specialize(deductions),
-                        pos = func.pos)
-                    candidates.append((specialization, specialization.type))
+            if deductions is None:
+                errors.append((func, "Unable to deduce type parameters"))
+            elif len(deductions) < len(func_type.templates):
+                errors.append((func, "Unable to deduce type parameter %s" % (
+                    ", ".join([param.name for param in set(func_type.templates) - set(deductions.keys())]))))
+            else:
+                type_list = [deductions[param] for param in func_type.templates]
+                from Symtab import Entry
+                specialization = Entry(
+                    name = func.name + "[%s]" % ",".join([str(t) for t in type_list]),
+                    cname = func.cname + "<%s>" % ",".join([t.declaration_code("") for t in type_list]),
+                    type = func_type.specialize(deductions),
+                    pos = func.pos)
+                candidates.append((specialization, specialization.type))
         else:
             candidates.append((func, func_type))
 
index 361b638..263d0f9 100644 (file)
@@ -8,6 +8,8 @@ cdef extern from "cpp_template_functions_helper.h":
     cdef pair[T, U] two_params[T, U](T, U)
     cdef cppclass A[T]:
         pair[T, U] method[U](T, U)
+    cdef T nested_deduction[T](const T*)
+    pair[T, U] pair_arg[T, U](pair[T, U] a)
 
 def test_no_arg():
     """
@@ -46,3 +48,18 @@ def test_simple_deduction(int x, double y):
     (1, 2.0)
     """
     return one_param(x), one_param(y)
+
+def test_more_deductions(int x, double y):
+    """
+    >>> test_more_deductions(1, 2)
+    (1, 2.0)
+    """
+    return nested_deduction(&x), nested_deduction(&y)
+
+def test_class_deductions(pair[long, double] x):
+    """
+    >>> test_class_deductions((1, 1.5))
+    (1, 1.5)
+    """
+    return pair_arg(x)
+
index f96e667..d5c5aec 100644 (file)
@@ -21,3 +21,13 @@ class A {
             return std::pair<T, U>(a, b);
         }
 };
+
+template <typename T>
+T nested_deduction(const T *a) {
+    return *a;
+}
+
+template <typename T, typename U>
+std::pair<T, U> pair_arg(std::pair<T, U> a) {
+    return a;
+}