[Polly][Unroll] Fix unroll_double test.
authorMichael Kruse <llvm-project@meinersbur.de>
Tue, 16 Mar 2021 13:50:37 +0000 (08:50 -0500)
committerMichael Kruse <llvm-project@meinersbur.de>
Tue, 16 Mar 2021 14:00:42 +0000 (09:00 -0500)
We enumerated the cross product Domain x Scatter, but sorted only be the
scatter key. In case there are are multiple statement instances per
scatter value, the order between statement instances of the same loop
iteration was undefined.

Propertly enumerate and sort only by the scatter value, and group the
domains using the scatter dimension again.

Thanks to Leonard Chan for the report.

polly/lib/Transform/ScheduleTreeTransform.cpp
polly/test/ScheduleOptimizer/ManualOptimization/unroll_double.ll

index c6f9c32..32cef0f 100644 (file)
@@ -533,13 +533,13 @@ isl::schedule polly::applyFullUnroll(isl::schedule_node BandToUnroll) {
   PartialSchedUAff = PartialSchedUAff.intersect_domain(Domain);
   isl::union_map PartialSchedUMap = isl::union_map(PartialSchedUAff);
 
-  // Make consumable for the following code.
-  // Schedule at the beginning so it is at coordinate 0.
-  isl::union_set PartialSchedUSet = PartialSchedUMap.reverse().wrap();
+  // Enumerator only the scatter elements.
+  isl::union_set ScatterList = PartialSchedUMap.range();
 
-  SmallVector<isl::point, 16> Elts;
+  // Enumerate all loop iterations.
   // TODO: Diagnose if not enumerable or depends on a parameter.
-  PartialSchedUSet.foreach_point([&Elts](isl::point P) -> isl::stat {
+  SmallVector<isl::point, 16> Elts;
+  ScatterList.foreach_point([&Elts](isl::point P) -> isl::stat {
     Elts.push_back(P);
     return isl::stat::ok();
   });
@@ -554,12 +554,10 @@ isl::schedule polly::applyFullUnroll(isl::schedule_node BandToUnroll) {
   // Convert the points to a sequence of filters.
   isl::union_set_list List = isl::union_set_list::alloc(Ctx, Elts.size());
   for (isl::point P : Elts) {
-    isl::basic_set AsSet{P};
-
-    // Throw away the scatter dimension.
-    AsSet = AsSet.unwrap().range();
+    // Determine the domains that map this scatter element.
+    isl::union_set DomainFilter = PartialSchedUMap.intersect_range(P).domain();
 
-    List = List.add(AsSet);
+    List = List.add(DomainFilter);
   }
 
   // Replace original band with unrolled sequence.
index a957727..dcd65b3 100644 (file)
@@ -38,15 +38,25 @@ return:
 
 
 ; CHECK-LABEL: Printing analysis 'Polly - Optimize schedule of SCoP' for region: 'for => return' in function 'func':
+; CHECK: domain: "{ Stmt_body[i0] : 0 <= i0 <= 11 }"
+; CHECK    sequence:
+; CHECK:   - filter: "{ Stmt_body[i0] : 0 <= i0 <= 3 }"
+; CHECK        sequence:
 ; CHECK:       - filter: "{ Stmt_body[0] }"
-; CHECK:       - filter: "{ Stmt_body[1] }"
-; CHECK:       - filter: "{ Stmt_body[2] }"
-; CHECK:       - filter: "{ Stmt_body[3] }"
+; CHECK:       - filter: "{ Stmt_body[i0] : (-1 + i0) mod 4 = 0 }"
+; CHECK:       - filter: "{ Stmt_body[i0] : (2 + i0) mod 4 = 0 }"
+; CHECK:       - filter: "{ Stmt_body[i0] : (1 + i0) mod 4 = 0 }"
+; CHECK    sequence:
+; CHECK:   - filter: "{ Stmt_body[i0] : 4 <= i0 <= 7 }"
+; CHECK        sequence:
 ; CHECK:       - filter: "{ Stmt_body[4] }"
-; CHECK:       - filter: "{ Stmt_body[5] }"
-; CHECK:       - filter: "{ Stmt_body[6] }"
-; CHECK:       - filter: "{ Stmt_body[7] }"
+; CHECK:       - filter: "{ Stmt_body[i0] : (-1 + i0) mod 4 = 0 }"
+; CHECK:       - filter: "{ Stmt_body[i0] : (2 + i0) mod 4 = 0 }"
+; CHECK:       - filter: "{ Stmt_body[i0] : (1 + i0) mod 4 = 0 }"
+; CHECK    sequence:
+; CHECK:   - filter: "{ Stmt_body[i0] : 8 <= i0 <= 11 }"
+; CHECK        sequence:
 ; CHECK:       - filter: "{ Stmt_body[8] }"
-; CHECK:       - filter: "{ Stmt_body[9] }"
-; CHECK:       - filter: "{ Stmt_body[10] }"
-; CHECK:       - filter: "{ Stmt_body[11] }"
+; CHECK:       - filter: "{ Stmt_body[i0] : (-1 + i0) mod 4 = 0 }"
+; CHECK:       - filter: "{ Stmt_body[i0] : (2 + i0) mod 4 = 0 }"
+; CHECK:       - filter: "{ Stmt_body[i0] : (1 + i0) mod 4 = 0 }"