move major part of SimplifyCalls transform into type analysis in GeneralCallNode
authorStefan Behnel <stefan_ml@behnel.de>
Wed, 6 Feb 2013 07:34:40 +0000 (08:34 +0100)
committerStefan Behnel <stefan_ml@behnel.de>
Wed, 6 Feb 2013 07:34:40 +0000 (08:34 +0100)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Optimize.py
Cython/Compiler/Pipeline.py

index 4f0fd72..2c3bbde 100755 (executable)
@@ -4299,6 +4299,7 @@ class GeneralCallNode(CallNode):
     #  keyword_args     ExprNode or None  Dict of keyword arguments
 
     type = py_object_type
+    is_simple_call = False
 
     subexprs = ['function', 'positional_args', 'keyword_args']
 
@@ -4324,17 +4325,23 @@ class GeneralCallNode(CallNode):
         if self.analyse_as_type_constructor(env):
             return
         self.function.analyse_types(env)
-        self.positional_args.analyse_types(env)
-        if self.keyword_args:
-            self.keyword_args.analyse_types(env)
         if not self.function.type.is_pyobject:
             if self.function.type.is_error:
                 self.type = error_type
                 return
-            if hasattr(self.function, 'entry') and not self.function.entry.as_variable:
-                error(self.pos, "Keyword and starred arguments not allowed in cdef functions.")
+            if hasattr(self.function, 'entry'):
+                self.map_keywords_to_posargs()
+                if not self.is_simple_call:
+                    if self.function.entry.as_variable:
+                        self.function = self.function.coerce_to_pyobject(env)
+                    else:
+                        error(self.pos, "Keyword and starred arguments "
+                                        "not allowed in cdef functions.")
             else:
                 self.function = self.function.coerce_to_pyobject(env)
+        if self.keyword_args:
+            self.keyword_args.analyse_types(env)
+        self.positional_args.analyse_types(env)
         self.positional_args = \
             self.positional_args.coerce_to_pyobject(env)
         function = self.function
@@ -4348,6 +4355,54 @@ class GeneralCallNode(CallNode):
             self.type = py_object_type
         self.is_temp = 1
 
+    def map_keywords_to_posargs(self):
+        if not isinstance(self.positional_args, TupleNode):
+            # has starred argument
+            return
+        if not isinstance(self.keyword_args, DictNode):
+            # nothing to do here
+            return
+        function = self.function
+        entry = getattr(function, 'entry', None)
+        if not entry or not entry.is_cfunction:
+            return
+
+        args = self.positional_args.args
+        kwargs = self.keyword_args
+        declared_args = entry.type.args
+        if len(declared_args) < len(args):
+            # will lead to an error elsewhere
+            error(self.pos, "function call got too many positional arguments, "
+                            "expected %d, got %s" % (len(declared_args), len(args)))
+            return
+
+        matched_pos_args = set([arg.name for arg in declared_args[:len(args)]])
+        unmatched_args = declared_args[len(args):]
+        matched_kwargs = set()
+        args = list(args)
+        # TODO: match keywords out-of-order and move values
+        #       into ordered temps if necessary
+        for decl_arg, arg in zip(unmatched_args, kwargs.key_value_pairs):
+            name = arg.key.value
+            if name in matched_pos_args:
+                error(arg.pos, "keyword argument '%s' passed twice" % name)
+                return
+            if decl_arg.name == name:
+                matched_kwargs.add(name)
+                args.append(arg.value)
+            else:
+                break
+        if not matched_kwargs:
+            return
+        self.positional_args.args = args
+        if len(kwargs.key_value_pairs) == len(matched_kwargs):
+            self.keyword_args = None
+            self.is_simple_call = True
+        else:
+            kwargs.key_value_pairs = [
+                item for item in kwargs.key_value_pairs
+                if item.key.value not in matched_kwargs ]
+
     def generate_result_code(self, code):
         if self.type.is_error: return
         if self.keyword_args:
index cc86ac2..80d6609 100644 (file)
@@ -1175,59 +1175,22 @@ class DropRefcountingTransform(Visitor.VisitorTransform):
 
 class SimplifyCalls(Visitor.EnvTransform):
     """
-    Statically map keyword arguments in C calls to positional arguments.
+    Replace GeneralCallNode by SimpleCallNode if possible.
     """
     def visit_GeneralCallNode(self, node):
         self.visitchildren(node)
-        arg_tuple = node.positional_args
-        if not isinstance(arg_tuple, ExprNodes.TupleNode):
-            return node
-        args = arg_tuple.args
-        kwargs = node.keyword_args
-
-        function = node.function
-        if not function.is_name:
-            # TODO: optimise methods as well
+        if not node.is_simple_call:
             return node
-        entry = function.entry
-        if not entry:
-            entry = self.current_env().lookup(function.name)
-            if not entry:
-                return node
-        if not entry.is_cfunction or not entry.type:
-            return node
-
-        declared_args = entry.type.args
-        if len(declared_args) < len(args):
-            # will lead to an error elsewhere
-            return node
-        matched_pos_args = set([arg.name for arg in declared_args[:len(args)]])
-        unmatched_args = declared_args[len(args):]
-        matched_kwargs = set()
-        args = list(args)
-        # TODO: match keywords out-of-order and move values
-        #       into ordered temps if necessary
-        for decl_arg, arg in zip(unmatched_args, kwargs.key_value_pairs):
-            name = arg.key.value
-            if name in matched_pos_args:
-                # keyword argument passed twice => should fail elsewhere
-                return node
-            if decl_arg.name == name:
-                matched_kwargs.add(name)
-                args.append(arg.value)
-            else:
-                break
-        if not matched_kwargs:
-            return node
-        if len(kwargs.key_value_pairs) == len(matched_kwargs):
-            return ExprNodes.SimpleCallNode(
-                node.pos,
-                function=function, args=args)
-        arg_tuple.args = args
-        kwargs.key_value_pairs = [
-            item for item in kwargs.key_value_pairs
-            if item.key.value not in matched_kwargs ]
-        return node
+        args = [ unwrap_coerced_node(arg)
+                 for arg in node.positional_args.args ]
+        call_node = ExprNodes.SimpleCallNode(
+            node.pos,
+            function=node.function,
+            args=args)
+        call_node.analyse_types(self.current_env())
+        if node.type != call_node.type:
+            call_node = call_node.coerce_to(node.type)
+        return call_node
 
 
 class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
index 8ec43a2..55588fb 100644 (file)
@@ -182,7 +182,6 @@ def create_pipeline(context, mode, exclude_classes=()):
         AnalyseDeclarationsTransform(context),
         AutoTestDictTransform(context),
         EmbedSignature(context),
-        SimplifyCalls(context),
         EarlyReplaceBuiltinCalls(context),  ## Necessary?
         TransformBuiltinMethods(context),  ## Necessary?
         MarkParallelAssignments(context),
@@ -194,6 +193,7 @@ def create_pipeline(context, mode, exclude_classes=()):
         _check_c_declarations,
         InlineDefNodeCalls(context),
         AnalyseExpressionsTransform(context),
+        SimplifyCalls(context),
         FindInvalidUseOfFusedTypes(context),
         CreateClosureClasses(context),  ## After all lookups and type inference
         ExpandInplaceOperators(context),