[mlir][sparse] refine heuristic for iteration graph topsort
authorAart Bik <ajcbik@google.com>
Wed, 1 Sep 2021 22:29:52 +0000 (15:29 -0700)
committerAart Bik <ajcbik@google.com>
Fri, 3 Sep 2021 15:37:15 +0000 (08:37 -0700)
The sparse index order must always be satisfied, but this
may give a choice in topsorts for several cases. We broke
ties in favor of any dense index order, since this gives
good locality. However, breaking ties in favor of pushing
unrelated indices into sparse iteration spaces gives better
asymptotic complexity. This revision improves the heuristic.

Note that in the long run, we are really interested in using
ML for ML to find the best loop ordering as a replacement for
such heuristics.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D109100

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/test/Dialect/SparseTensor/sparse_2d.mlir

index f303186..6fb09ca 100644 (file)
@@ -30,6 +30,9 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
+// Iteration graph sorting.
+enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };
+
 // Code generation.
 struct CodeGen {
   CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
@@ -141,7 +144,7 @@ static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
 /// order yields innermost unit-stride access with better spatial locality.
 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                   std::vector<unsigned> &topSort,
-                                  bool sparseOnly) {
+                                  unsigned mask) {
   // Set up an n x n from/to adjacency matrix of the iteration graph
   // for the implicit loop indices i_0 .. i_n-1.
   unsigned n = op.getNumLoops();
@@ -152,8 +155,8 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
     auto map = op.getTiedIndexingMap(t);
     auto enc = getSparseTensorEncoding(t->get().getType());
     assert(map.getNumDims() == n);
-    // Skip dense tensor constraints when sparse only is requested.
-    if (sparseOnly && !enc)
+    // Skip dense tensor constraints when not requested.
+    if (!(mask & SortMask::kIncludeDense) && !enc)
       continue;
     // Each tensor expression and optional dimension ordering (row-major
     // by default) puts an ordering constraint on the loop indices. For
@@ -164,6 +167,16 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
       unsigned t = map.getDimPosition(perm(enc, d));
       adjM[f][t] = true;
     }
+    // Push unrelated loops into sparse iteration space, so these
+    // will be skipped more often.
+    if (mask & SortMask::kIncludeUndef) {
+      unsigned tensor = t->getOperandNumber();
+      for (unsigned i = 0; i < n; i++)
+        if (merger.isDim(tensor, i, Dim::kSparse))
+          for (unsigned j = 0; j < n; j++)
+            if (merger.isDim(tensor, j, Dim::kUndef))
+              adjM[i][j] = true;
+    }
   }
 
   // Topologically sort the iteration graph to determine loop order.
@@ -1134,8 +1147,12 @@ public:
     // This assumes that higher-level passes have already put the
     // tensors in each tensor expression in a feasible order.
     std::vector<unsigned> topSort;
-    if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
-        !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
+    if (!computeIterationGraph(merger, op, topSort,
+                               SortMask::kIncludeUndef |
+                                   SortMask::kIncludeDense) &&
+        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
+        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
+        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
       return failure();
 
     // Builds the tensor expression for the Linalg operation in SSA form.
index 9f4b3ac..ffc6e8d 100644 (file)
@@ -1043,25 +1043,26 @@ func @scale(%arga: tensor<?x?xf64, #Tds>, %argx: tensor<?x?xf64>) -> tensor<?x?x
 // CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_20:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] {
 // CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref<?xindex>
-// CHECK:             scf.for %[[VAL_22:.*]] = %[[VAL_4]] to %[[VAL_12]] step %[[VAL_5]] {
-// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_22]]] : memref<?x?xf32>
-// CHECK:               %[[VAL_24:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xindex>
-// CHECK:               %[[VAL_25:.*]] = addi %[[VAL_20]], %[[VAL_5]] : index
-// CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_25]]] : memref<?xindex>
-// CHECK:               scf.for %[[VAL_27:.*]] = %[[VAL_24]] to %[[VAL_26]] step %[[VAL_5]] {
-// CHECK:                 %[[VAL_28:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_27]]] : memref<?xindex>
-// CHECK:                 %[[VAL_29:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_28]]] : memref<?x?xf32>
-// CHECK:                 %[[VAL_30:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_27]]] : memref<?xf32>
-// CHECK:                 %[[VAL_31:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_22]], %[[VAL_28]]] : memref<?x?xf32>
-// CHECK:                 %[[VAL_32:.*]] = mulf %[[VAL_23]], %[[VAL_31]] : f32
-// CHECK:                 %[[VAL_33:.*]] = mulf %[[VAL_30]], %[[VAL_32]] : f32
-// CHECK:                 %[[VAL_34:.*]] = addf %[[VAL_29]], %[[VAL_33]] : f32
-// CHECK:                 memref.store %[[VAL_34]], %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_28]]] : memref<?x?xf32>
+// CHECK:             %[[VAL_22:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK:             %[[VAL_23:.*]] = addi %[[VAL_20]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_24:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_23]]] : memref<?xindex>
+// CHECK:             scf.for %[[VAL_25:.*]] = %[[VAL_22]] to %[[VAL_24]] step %[[VAL_5]] {
+// CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_25]]] : memref<?xindex>
+// CHECK:               %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref<?xf32>
+// CHECK:               %[[VAL_28:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref<?x?xf32>
+// CHECK:               %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_4]] to %[[VAL_12]] step %[[VAL_5]] iter_args(%[[VAL_31:.*]] = %[[VAL_28]]) -> (f32) {
+// CHECK:                 %[[VAL_32:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_30]]] : memref<?x?xf32>
+// CHECK:                 %[[VAL_33:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_30]], %[[VAL_26]]] : memref<?x?xf32>
+// CHECK:                 %[[VAL_34:.*]] = mulf %[[VAL_32]], %[[VAL_33]] : f32
+// CHECK:                 %[[VAL_35:.*]] = mulf %[[VAL_27]], %[[VAL_34]] : f32
+// CHECK:                 %[[VAL_36:.*]] = addf %[[VAL_31]], %[[VAL_35]] : f32
+// CHECK:                 scf.yield %[[VAL_36]] : f32
 // CHECK:               }
+// CHECK:               memref.store %[[VAL_37:.*]], %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref<?x?xf32>
 // CHECK:             }
 // CHECK:           }
-// CHECK:           %[[VAL_35:.*]] = memref.tensor_load %[[VAL_17]] : memref<?x?xf32>
-// CHECK:           return %[[VAL_35]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_38:.*]] = memref.tensor_load %[[VAL_17]] : memref<?x?xf32>
+// CHECK:           return %[[VAL_38]] : tensor<?x?xf32>
 // CHECK:         }
 func @sampled_dense_dense(%args: tensor<?x?xf32, #Tss>,
                           %arga: tensor<?x?xf32>,