[mlir][sparse] fix crash when sparsifying broadcast operations.
authorPeiming Liu <peiming@google.com>
Thu, 27 Oct 2022 17:12:20 +0000 (17:12 +0000)
committerPeiming Liu <peiming@google.com>
Thu, 27 Oct 2022 17:43:22 +0000 (17:43 +0000)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D136866

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir [new file with mode: 0644]

index d899f0f..b4d986f 100644 (file)
@@ -97,10 +97,11 @@ static Value genIndexAndValueForDense(OpBuilder &builder, Location loc,
 SparseTensorLoopEmitter::SparseTensorLoopEmitter(ValueRange tensors,
                                                  bool hasOutput,
                                                  bool isSparseOut)
-    : hasOutput(hasOutput), tensors(tensors.begin(), tensors.end()),
-      dimTypes(tensors.size()), pidxs(tensors.size()), coord(tensors.size()),
-      highs(tensors.size()), ptrBuffer(tensors.size()),
-      idxBuffer(tensors.size()), valBuffer(tensors.size()), loopStack() {
+    : hasOutput(hasOutput), isSparseOut(isSparseOut),
+      tensors(tensors.begin(), tensors.end()), dimTypes(tensors.size()),
+      pidxs(tensors.size()), coord(tensors.size()), highs(tensors.size()),
+      ptrBuffer(tensors.size()), idxBuffer(tensors.size()),
+      valBuffer(tensors.size()), loopStack() {
   for (size_t tid = 0, e = tensors.size(); tid < e; tid++) {
     auto t = tensors[tid];
     // a scalar or 0-dimension tensors
@@ -246,7 +247,7 @@ Operation *SparseTensorLoopEmitter::enterLoopOverTensorAtDim(
     coord[tid][dim] = iv;
     // generate pidx for dense dim (pidx = i * sz + j)
     auto enc = getSparseTensorEncoding(tensors[tid].getType());
-    if (enc)
+    if (enc && !isSparseOutput(tid))
       pidxs[tid][dim] = genAddress(builder, loc, tid, dim, iv);
   }
 
@@ -353,7 +354,7 @@ Operation *SparseTensorLoopEmitter::enterCoIterationOverTensorsAtDims(
       pidxs[tid][dim] = min;
       // generate pidx for dense dim (pidx = i * sz + j)
       auto enc = getSparseTensorEncoding(tensors[tid].getType());
-      if (enc)
+      if (enc && !isSparseOutput(tid))
         pidxs[tid][dim] = genAddress(builder, loc, tid, dim, min);
     }
     // NOTE: we can also prepares for next dim here in advance
@@ -419,7 +420,7 @@ void SparseTensorLoopEmitter::emitExtraLocalsForTensorsAtDenseDims(
   for (auto [tid, dim] : llvm::zip(tids, dims)) {
     assert(isDenseDLT(dimTypes[tid][dim]));
     auto enc = getSparseTensorEncoding(tensors[tid].getType());
-    if (enc) {
+    if (enc && !isSparseOutput(tid)) {
       bool validPidx = dim == 0 || pidxs[tid][dim - 1];
       if (!validPidx) {
         // We might not find the pidx for the sparse output tensor as it is
index 4fd02ca..0559eed 100644 (file)
@@ -434,6 +434,8 @@ private:
     return hasOutput && tid == tensors.size() - 1;
   }
 
+  bool isSparseOutput(size_t tid) { return isOutputTensor(tid) && isSparseOut; }
+
   /// Setups [lo, hi] for iterating tensor[dim], it assumes that tensor[0
   /// ...dims-1] has already been setup.
   void prepareLoopOverTensorAtDim(OpBuilder &builder, Location loc, size_t tid,
@@ -462,6 +464,7 @@ private:
   // Whether the loop emitter needs to treat the last tensor as the output
   // tensor.
   bool hasOutput;
+  bool isSparseOut;
   /// Input and (optional) output tensors.
   std::vector<Value> tensors;
   /// The dim type array for each tensor.
index 4cea880..a936d98 100644 (file)
@@ -1130,13 +1130,13 @@ static void translateBitsToTidDimPairs(Merger &merger, CodeGen &codegen,
       assert(all.test(b));
       assert(merger.index(b) == idx);
       if (isUndefDLT(merger.getDimLevelType(b))) {
-        // This could be a synthetic tensor (for invariants and sparse output
-        // tensor).
-        // In both cases, we mean to generate loops over output tensor.
-        // e.g.,
-        // out[i][j] = invariant;
-        if (merger.getSynTensorID() == tid)
-          tid = merger.getOutTensorID();
+        // An undefined dlt in the lattices, we probably mean to iterate based
+        // on the dim of output tensor.
+        // E.g., this could be a synthetic tensor (for invariants and sparse
+        // output tensor).
+        // out[i][j] = invariant; or a broadcast
+        // out[i][j] = in[i] (j is undef for input)
+        tid = merger.getOutTensorID();
       }
       auto dim = codegen.loopIdxToDim[tid][idx];
       if (dim != INVALID_ID) {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir b/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir
new file mode 100644 (file)
index 0000000..425b984
--- /dev/null
@@ -0,0 +1,52 @@
+// RUN: mlir-opt %s --sparsification --canonicalize --cse | FileCheck %s
+
+#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+#SparseTensor = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(d0, d1, d2) -> (d0, d2)>,
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+  ],
+  iterator_types = ["parallel", "parallel", "parallel"]
+}
+
+// CHECK-LABEL:  @main(
+//  CHECK-SAME:    %[[TMP_arg0:.*]]: tensor<4x5xi32, 
+//   CHECK-DAG:  %[[TMP_c3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:  %[[TMP_c0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:  %[[TMP_c1:.*]] = arith.constant 1 : index
+//       CHECK:  %[[TMP_0:.*]] = bufferization.alloc_tensor()
+//       CHECK:  %[[TMP_1:.*]] = sparse_tensor.pointers %[[TMP_arg0]] {dimension = 0 : index}
+//       CHECK:  %[[TMP_2:.*]] = sparse_tensor.indices %[[TMP_arg0]] {dimension = 0 : index}
+//       CHECK:  %[[TMP_3:.*]] = sparse_tensor.pointers %[[TMP_arg0]] {dimension = 1 : index}
+//       CHECK:  %[[TMP_4:.*]] = sparse_tensor.indices %[[TMP_arg0]] {dimension = 1 : index}
+//       CHECK:  %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]]
+//       CHECK:  %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
+//       CHECK:  %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
+//       CHECK:  scf.for %[[TMP_arg1:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]] {
+//       CHECK:    %[[TMP_9:.*]] = memref.load %[[TMP_2]][%[[TMP_arg1]]] : memref<?xindex>
+//       CHECK:    scf.for %[[TMP_arg2:.*]] = %[[TMP_c0]] to %[[TMP_c3]] step %[[TMP_c1]] {
+//       CHECK:      %[[TMP_10:.*]] = memref.load %[[TMP_3]][%[[TMP_arg1]]] : memref<?xindex>
+//       CHECK:      %[[TMP_11:.*]] = arith.addi %[[TMP_arg1]], %[[TMP_c1]] : index
+//       CHECK:      %[[TMP_12:.*]] = memref.load %[[TMP_3]][%[[TMP_11]]] : memref<?xindex>
+//       CHECK:      scf.for %[[TMP_arg3:.*]] = %[[TMP_10]] to %[[TMP_12]] step %[[TMP_c1]] {
+//       CHECK:        %[[TMP_13:.*]] = memref.load %[[TMP_4]][%[[TMP_arg3]]] : memref<?xindex>
+//       CHECK:        %[[TMP_14:.*]] = memref.load %[[TMP_5]][%[[TMP_arg3]]] : memref<?xi32>
+//       CHECK:        %[[TMP_15:.*]] = sparse_tensor.insert %[[TMP_14]] into %[[TMP_0]][%[[TMP_9]], %[[TMP_arg2]], %[[TMP_13]]]
+//       CHECK:      }
+//       CHECK:    }
+//       CHECK:  }
+//       CHECK:  %[[TMP_8:.*]] = sparse_tensor.load %[[TMP_0]] hasInserts
+//       CHECK:  return %[[TMP_8]] 
+module @func_sparse {
+  func.func public @main(%arg0: tensor<4x5xi32, #DCSR>) -> tensor<4x3x5xi32, #SparseTensor> {
+    %0 = bufferization.alloc_tensor() : tensor<4x3x5xi32, #SparseTensor>
+    %1 = linalg.generic #trait
+    ins(%arg0 : tensor<4x5xi32, #DCSR>) outs(%0 : tensor<4x3x5xi32, #SparseTensor>) {
+    ^bb0(%in: i32, %out: i32):
+      linalg.yield %in : i32
+    } -> tensor<4x3x5xi32, #SparseTensor>
+    return %1 : tensor<4x3x5xi32, #SparseTensor>
+  }
+}