Make SLP root stmt a vector
author Richard Biener <rguenther@suse.de>
Wed, 2 Jun 2021 11:25:59 +0000 (13:25 +0200)
committer Richard Biener <rguenther@suse.de>
Tue, 8 Jun 2021 13:09:18 +0000 (15:09 +0200)
This fixes a TODO noticed when adding vectorization of
BIT_INSERT_EXPRs, and the change is now useful for vectorization of
BB reductions.

2021-06-08  Richard Biener  <rguenther@suse.de>

* tree-vectorizer.h (_slp_instance::root_stmt): Change to...
(_slp_instance::root_stmts): ... a vector.
(SLP_INSTANCE_ROOT_STMT): Rename to ...
(SLP_INSTANCE_ROOT_STMTS): ... this.
(slp_root::root): Change to...
(slp_root::roots): ... a vector.
(slp_root::slp_root): Adjust.
* tree-vect-slp.c (_slp_instance::location): Adjust.
(vect_free_slp_instance): Release the root stmt vector.
(vect_build_slp_instance): Adjust.
(vect_analyze_slp): Likewise.
(_bb_vec_info::~_bb_vec_info): Likewise.
(vect_slp_analyze_operations): Likewise.
(vect_bb_vectorization_profitable_p): Likewise.  Adjust
costs for the root stmt.
(vect_slp_check_for_constructors): Gather all BIT_INSERT_EXPRs
as root stmts.
(vect_slp_analyze_bb_1): Simplify by marking all root stmts
as pure_slp.
(vectorize_slp_instance_root_stmt): Adjust.
(vect_schedule_slp): Likewise.

gcc/tree-vect-slp.c
gcc/tree-vectorizer.h

index ca1539e..cc734e0 100644 (file)
@@ -164,8 +164,8 @@ vect_free_slp_tree (slp_tree node)
 dump_user_location_t
 _slp_instance::location () const
 {
-  if (root_stmt)
-    return root_stmt->stmt;
+  if (!root_stmts.is_empty ())
+    return root_stmts[0]->stmt;
   else
     return SLP_TREE_SCALAR_STMTS (root)[0]->stmt;
 }
@@ -178,6 +178,7 @@ vect_free_slp_instance (slp_instance instance)
 {
   vect_free_slp_tree (SLP_INSTANCE_TREE (instance));
   SLP_INSTANCE_LOADS (instance).release ();
+  SLP_INSTANCE_ROOT_STMTS (instance).release ();
   instance->subgraph_entries.release ();
   instance->cost_vec.release ();
   free (instance);
@@ -2503,7 +2504,7 @@ static bool
 vect_build_slp_instance (vec_info *vinfo,
                         slp_instance_kind kind,
                         vec<stmt_vec_info> &scalar_stmts,
-                        stmt_vec_info root_stmt_info,
+                        vec<stmt_vec_info> &root_stmt_infos,
                         unsigned max_tree_size, unsigned *limit,
                         scalar_stmts_to_slp_tree_map_t *bst_map,
                         /* ???  We need stmt_info for group splitting.  */
@@ -2564,7 +2565,7 @@ vect_build_slp_instance (vec_info *vinfo,
          SLP_INSTANCE_TREE (new_instance) = node;
          SLP_INSTANCE_UNROLLING_FACTOR (new_instance) = unrolling_factor;
          SLP_INSTANCE_LOADS (new_instance) = vNULL;
-         SLP_INSTANCE_ROOT_STMT (new_instance) = root_stmt_info;
+         SLP_INSTANCE_ROOT_STMTS (new_instance) = root_stmt_infos;
          SLP_INSTANCE_KIND (new_instance) = kind;
          new_instance->reduc_phis = NULL;
          new_instance->cost_vec = vNULL;
@@ -2836,13 +2837,20 @@ vect_analyze_slp_instance (vec_info *vinfo,
   else
     gcc_unreachable ();
 
+  vec<stmt_vec_info> roots = vNULL;
+  if (kind == slp_inst_kind_ctor)
+    {
+      roots.create (1);
+      roots.quick_push (stmt_info);
+    }
   /* Build the tree for the SLP instance.  */
   bool res = vect_build_slp_instance (vinfo, kind, scalar_stmts,
-                                     kind == slp_inst_kind_ctor
-                                     ? stmt_info : NULL,
+                                     roots,
                                      max_tree_size, limit, bst_map,
                                      kind == slp_inst_kind_store
                                      ? stmt_info : NULL);
+  if (!res)
+    roots.release ();
 
   /* ???  If this is slp_inst_kind_store and the above succeeded here's
      where we should do store group splitting.  */
@@ -2878,12 +2886,15 @@ vect_analyze_slp (vec_info *vinfo, unsigned max_tree_size)
     {
       for (unsigned i = 0; i < bb_vinfo->roots.length (); ++i)
        {
-         vect_location = bb_vinfo->roots[i].root->stmt;
+         vect_location = bb_vinfo->roots[i].roots[0]->stmt;
          if (vect_build_slp_instance (bb_vinfo, bb_vinfo->roots[i].kind,
                                       bb_vinfo->roots[i].stmts,
-                                      bb_vinfo->roots[i].root,
+                                      bb_vinfo->roots[i].roots,
                                       max_tree_size, &limit, bst_map, NULL))
-           bb_vinfo->roots[i].stmts = vNULL;
+           {
+             bb_vinfo->roots[i].stmts = vNULL;
+             bb_vinfo->roots[i].roots = vNULL;
+           }
        }
     }
 
@@ -3741,7 +3752,10 @@ _bb_vec_info::~_bb_vec_info ()
     }
 
   for (unsigned i = 0; i < roots.length (); ++i)
-    roots[i].stmts.release ();
+    {
+      roots[i].stmts.release ();
+      roots[i].roots.release ();
+    }
   roots.release ();
 }
 
@@ -4154,7 +4168,8 @@ vect_slp_analyze_operations (vec_info *vinfo)
                                             &cost_vec)
          /* Instances with a root stmt require vectorized defs for the
             SLP tree root.  */
-         || (SLP_INSTANCE_ROOT_STMT (instance)
+         /* ???  Do inst->kind check instead.  */
+         || (!SLP_INSTANCE_ROOT_STMTS (instance).is_empty ()
              && (SLP_TREE_DEF_TYPE (SLP_INSTANCE_TREE (instance))
                  != vect_internal_def)))
         {
@@ -4460,9 +4475,11 @@ vect_bb_vectorization_profitable_p (bb_vec_info bb_vinfo,
       auto_vec<bool, 20> life;
       life.safe_grow_cleared (SLP_TREE_LANES (SLP_INSTANCE_TREE (instance)),
                              true);
-      if (SLP_INSTANCE_ROOT_STMT (instance))
-       record_stmt_cost (&scalar_costs, 1, scalar_stmt,
-                         SLP_INSTANCE_ROOT_STMT (instance), 0, vect_body);
+      if (!SLP_INSTANCE_ROOT_STMTS (instance).is_empty ())
+       record_stmt_cost (&scalar_costs,
+                         SLP_INSTANCE_ROOT_STMTS (instance).length (),
+                         scalar_stmt,
+                         SLP_INSTANCE_ROOT_STMTS (instance)[0], 0, vect_body);
       vect_bb_slp_scalar_cost (bb_vinfo,
                               SLP_INSTANCE_TREE (instance),
                               &life, &scalar_costs, visited);
@@ -4691,6 +4708,8 @@ vect_slp_check_for_constructors (bb_vec_info bb_vinfo)
          unsigned lanes_found = 1;
          /* Start with the use chains, the last stmt will be the root.  */
          stmt_vec_info last = bb_vinfo->lookup_stmt (assign);
+         vec<stmt_vec_info> roots = vNULL;
+         roots.safe_push (last);
          do
            {
              use_operand_p use_p;
@@ -4710,9 +4729,12 @@ vect_slp_check_for_constructors (bb_vec_info bb_vinfo)
              lane_defs.quick_push (std::make_pair
                                     (this_lane, gimple_assign_rhs2 (use_ass)));
              last = bb_vinfo->lookup_stmt (use_ass);
+             roots.safe_push (last);
              def = gimple_assign_lhs (use_ass);
            }
          while (lanes_found < nlanes);
+         if (roots.length () > 1)
+           std::swap(roots[0], roots[roots.length () - 1]);
          if (lanes_found < nlanes)
            {
              /* Now search the def chain.  */
@@ -4736,6 +4758,7 @@ vect_slp_check_for_constructors (bb_vec_info bb_vinfo)
                  lane_defs.quick_push (std::make_pair
                                          (this_lane,
                                           gimple_assign_rhs2 (def_stmt)));
+                 roots.safe_push (bb_vinfo->lookup_stmt (def_stmt));
                  def = gimple_assign_rhs1 (def_stmt);
                }
              while (lanes_found < nlanes);
@@ -4749,8 +4772,10 @@ vect_slp_check_for_constructors (bb_vec_info bb_vinfo)
              for (unsigned i = 0; i < nlanes; ++i)
                stmts.quick_push (bb_vinfo->lookup_def (lane_defs[i].second));
              bb_vinfo->roots.safe_push (slp_root (slp_inst_kind_ctor,
-                                                  stmts, last));
+                                                  stmts, roots));
            }
+         else
+           roots.release ();
        }
     }
 }
@@ -4905,22 +4930,11 @@ vect_slp_analyze_bb_1 (bb_vec_info bb_vinfo, int n_stmts, bool &fatal,
         relevant.  */
       vect_mark_slp_stmts (SLP_INSTANCE_TREE (instance));
       vect_mark_slp_stmts_relevant (SLP_INSTANCE_TREE (instance));
-      if (stmt_vec_info root = SLP_INSTANCE_ROOT_STMT (instance))
-       {
-         STMT_SLP_TYPE (root) = pure_slp;
-         if (is_gimple_assign (root->stmt)
-             && gimple_assign_rhs_code (root->stmt) == BIT_INSERT_EXPR)
-           {
-             /* ???  We should probably record the whole vector of
-                root stmts so we do not have to back-track here...  */
-             for (unsigned n = SLP_TREE_LANES (SLP_INSTANCE_TREE (instance));
-                  n != 1; --n)
-               {
-                 root = bb_vinfo->lookup_def (gimple_assign_rhs1 (root->stmt));
-                 STMT_SLP_TYPE (root) = pure_slp;
-               }
-           }
-       }
+      unsigned j;
+      stmt_vec_info root;
+      /* Likewise consider instance root stmts as vectorized.  */
+      FOR_EACH_VEC_ELT (SLP_INSTANCE_ROOT_STMTS (instance), j, root)
+       STMT_SLP_TYPE (root) = pure_slp;
 
       i++;
     }
@@ -6357,47 +6371,50 @@ vectorize_slp_instance_root_stmt (slp_tree node, slp_instance instance)
 {
   gassign *rstmt = NULL;
 
-  if (SLP_TREE_NUMBER_OF_VEC_STMTS (node) == 1)
+  if (instance->kind == slp_inst_kind_ctor)
     {
-      gimple *child_stmt;
-      int j;
-
-      FOR_EACH_VEC_ELT (SLP_TREE_VEC_STMTS (node), j, child_stmt)
+      if (SLP_TREE_NUMBER_OF_VEC_STMTS (node) == 1)
        {
-         tree vect_lhs = gimple_get_lhs (child_stmt);
-         tree root_lhs = gimple_get_lhs (instance->root_stmt->stmt);
-         if (!useless_type_conversion_p (TREE_TYPE (root_lhs),
-                                         TREE_TYPE (vect_lhs)))
-           vect_lhs = build1 (VIEW_CONVERT_EXPR, TREE_TYPE (root_lhs),
-                              vect_lhs);
-         rstmt = gimple_build_assign (root_lhs, vect_lhs);
-         break;
-       }
-    }
-  else if (SLP_TREE_NUMBER_OF_VEC_STMTS (node) > 1)
-    {
-      int nelts = SLP_TREE_NUMBER_OF_VEC_STMTS (node);
-      gimple *child_stmt;
-      int j;
-      vec<constructor_elt, va_gc> *v;
-      vec_alloc (v, nelts);
+         gimple *child_stmt;
+         int j;
 
-      FOR_EACH_VEC_ELT (SLP_TREE_VEC_STMTS (node), j, child_stmt)
+         FOR_EACH_VEC_ELT (SLP_TREE_VEC_STMTS (node), j, child_stmt)
+           {
+             tree vect_lhs = gimple_get_lhs (child_stmt);
+             tree root_lhs = gimple_get_lhs (instance->root_stmts[0]->stmt);
+             if (!useless_type_conversion_p (TREE_TYPE (root_lhs),
+                                             TREE_TYPE (vect_lhs)))
+               vect_lhs = build1 (VIEW_CONVERT_EXPR, TREE_TYPE (root_lhs),
+                                  vect_lhs);
+             rstmt = gimple_build_assign (root_lhs, vect_lhs);
+             break;
+           }
+       }
+      else if (SLP_TREE_NUMBER_OF_VEC_STMTS (node) > 1)
        {
-         CONSTRUCTOR_APPEND_ELT (v,
-                                 NULL_TREE,
-                                 gimple_get_lhs (child_stmt));
+         int nelts = SLP_TREE_NUMBER_OF_VEC_STMTS (node);
+         gimple *child_stmt;
+         int j;
+         vec<constructor_elt, va_gc> *v;
+         vec_alloc (v, nelts);
+
+         FOR_EACH_VEC_ELT (SLP_TREE_VEC_STMTS (node), j, child_stmt)
+           CONSTRUCTOR_APPEND_ELT (v, NULL_TREE,
+                                   gimple_get_lhs (child_stmt));
+         tree lhs = gimple_get_lhs (instance->root_stmts[0]->stmt);
+         tree rtype
+           = TREE_TYPE (gimple_assign_rhs1 (instance->root_stmts[0]->stmt));
+         tree r_constructor = build_constructor (rtype, v);
+         rstmt = gimple_build_assign (lhs, r_constructor);
        }
-      tree lhs = gimple_get_lhs (instance->root_stmt->stmt);
-      tree rtype = TREE_TYPE (gimple_assign_rhs1 (instance->root_stmt->stmt));
-      tree r_constructor = build_constructor (rtype, v);
-      rstmt = gimple_build_assign (lhs, r_constructor);
     }
+  else
+    gcc_unreachable ();
 
-    gcc_assert (rstmt);
+  gcc_assert (rstmt);
 
-    gimple_stmt_iterator rgsi = gsi_for_stmt (instance->root_stmt->stmt);
-    gsi_replace (&rgsi, rstmt, true);
+  gimple_stmt_iterator rgsi = gsi_for_stmt (instance->root_stmts[0]->stmt);
+  gsi_replace (&rgsi, rstmt, true);
 }
 
 struct slp_scc_info
@@ -6567,9 +6584,10 @@ vect_schedule_slp (vec_info *vinfo, vec<slp_instance> slp_instances)
        {
          dump_printf_loc (MSG_NOTE, vect_location,
                           "Vectorizing SLP tree:\n");
-         if (SLP_INSTANCE_ROOT_STMT (instance))
+         /* ???  Dump all?  */
+         if (!SLP_INSTANCE_ROOT_STMTS (instance).is_empty ())
            dump_printf_loc (MSG_NOTE, vect_location, "Root stmt: %G",
-                        SLP_INSTANCE_ROOT_STMT (instance)->stmt);
+                        SLP_INSTANCE_ROOT_STMTS (instance)[0]->stmt);
          vect_print_slp_graph (MSG_NOTE, vect_location,
                                SLP_INSTANCE_TREE (instance));
        }
@@ -6579,7 +6597,7 @@ vect_schedule_slp (vec_info *vinfo, vec<slp_instance> slp_instances)
       if (!scc_info.get (node))
        vect_schedule_scc (vinfo, node, instance, scc_info, maxdfs, stack);
 
-      if (SLP_INSTANCE_ROOT_STMT (instance))
+      if (!SLP_INSTANCE_ROOT_STMTS (instance).is_empty ())
        vectorize_slp_instance_root_stmt (node, instance);
 
       if (dump_enabled_p ())
index 7dcb4cd..06d20c7 100644 (file)
@@ -197,7 +197,7 @@ public:
 
   /* For vector constructors, the constructor stmt that the SLP tree is built
      from, NULL otherwise.  */
-  stmt_vec_info root_stmt;
+  vec<stmt_vec_info> root_stmts;
 
   /* The unrolling factor required to vectorized this SLP instance.  */
   poly_uint64 unrolling_factor;
@@ -226,7 +226,7 @@ public:
 #define SLP_INSTANCE_TREE(S)                     (S)->root
 #define SLP_INSTANCE_UNROLLING_FACTOR(S)         (S)->unrolling_factor
 #define SLP_INSTANCE_LOADS(S)                    (S)->loads
-#define SLP_INSTANCE_ROOT_STMT(S)                (S)->root_stmt
+#define SLP_INSTANCE_ROOT_STMTS(S)               (S)->root_stmts
 #define SLP_INSTANCE_KIND(S)                     (S)->kind
 
 #define SLP_TREE_CHILDREN(S)                     (S)->children
@@ -861,11 +861,11 @@ loop_vec_info_for_loop (class loop *loop)
 struct slp_root
 {
   slp_root (slp_instance_kind kind_, vec<stmt_vec_info> stmts_,
-           stmt_vec_info root_)
-    : kind(kind_), stmts(stmts_), root(root_) {}
+           vec<stmt_vec_info> roots_)
+    : kind(kind_), stmts(stmts_), roots(roots_) {}
   slp_instance_kind kind;
   vec<stmt_vec_info> stmts;
-  stmt_vec_info root;
+  vec<stmt_vec_info> roots;
 };
 
 typedef class _bb_vec_info : public vec_info