[mlir][sparse] support constant affine expression on dense dimension
authorPeiming Liu <peiming@google.com>
Wed, 16 Nov 2022 23:18:16 +0000 (23:18 +0000)
committerPeiming Liu <peiming@google.com>
Tue, 22 Nov 2022 22:34:42 +0000 (22:34 +0000)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D138170

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/test/Dialect/SparseTensor/sparse_affine.mlir

index 55c9463..c9a1525 100644 (file)
@@ -997,7 +997,6 @@ static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                          ArrayRef<size_t> extraTids,
                          ArrayRef<size_t> extraDims) {
   Location loc = op.getLoc();
-  auto iteratorTypes = op.getIteratorTypesArray();
   bool isSparse = isCompressedDLT(merger.getDimLevelType(tid, idx)) ||
                   isSingletonDLT(merger.getDimLevelType(tid, idx));
   bool isParallel = isParallelFor(codegen, isOuter, isSparse);
@@ -1189,6 +1188,42 @@ static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   return false;
 }
 
+static void genConstantDenseAddressFromLevel(CodeGen &codegen,
+                                             OpBuilder &builder,
+                                             linalg::GenericOp op, unsigned tid,
+                                             unsigned lvl) {
+  // TODO: Handle affine expression on output tensor.
+  assert(tid < op.getNumDpsInputs());
+
+  OpOperand *input = op.getDpsInputOperands()[tid];
+  ArrayRef<AffineExpr> affines = op.getMatchingIndexingMap(input).getResults();
+  auto enc = getSparseTensorEncoding(input->get().getType());
+  if (enc) {
+    for (unsigned i = lvl, e = affines.size(); i < e; i++) {
+      AffineExpr affine = affines[toOrigDim(enc, i)];
+      if (isDenseDLT(getDimLevelType(enc, i)) &&
+          affine.isa<AffineConstantExpr>()) {
+        codegen.loopEmitter.genDenseAffineAddressAtCurLevel(
+            builder, op.getLoc(), input->getOperandNumber(), i, affine);
+      } else {
+        // Breaks on first non-dense non-constant level.
+        return;
+      }
+    }
+  }
+}
+
+static void genInitConstantDenseAddress(CodeGen &codegen,
+                                        RewriterBase &rewriter,
+                                        linalg::GenericOp op) {
+  // We can generates address for constant affine expression before any loops
+  // starting from the first level as they do not depend on any thing.
+  // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
+  // levels can be determined before loops.
+  for (unsigned tid = 0, e = op.getNumDpsInputs(); tid < e; tid++)
+    genConstantDenseAddressFromLevel(codegen, rewriter, op, tid, 0);
+}
+
 static void translateBitsToTidDimPairs(
     Merger &merger, CodeGen &codegen, linalg::GenericOp op, unsigned li,
     unsigned idx, SmallVectorImpl<size_t> &condTids,
@@ -1244,30 +1279,21 @@ static void translateBitsToTidDimPairs(
         if (exp.isa<AffineDimExpr>() || !isDenseDLT(getDimLevelType(enc, i)))
           continue;
 
-        // Constant affine expressions on dense level required to be generated
-        // when
-        // 1. The previous level is an (at-level) invariant compound dense
-        // affine (with no corresponding loop idx); or
-        // 2. The previous level is being generated right now.
-        if (exp.isa<AffineConstantExpr>()) {
-          // TODO:  Should we come up with a more adhersive way to handle
-          // constant expression? We now requires two (somehow ad-hoc) code for
-          // it.
-          assert(false && "do not support constant affine");
-        }
-
-        bool atLevel = false;
-        if (isInvariantAffine(codegen, exp, idx, atLevel) && atLevel) {
-          // If the compound affine is invariant and we are right at the
-          // level. We need to generate the address according to the affine
-          // expression. This is also the best place we can do it to avoid
-          // putting it inside inner loops.
-          // NOTE: It assumes that the levels of the input tensor are
-          // initialized in order, another more admissible approach might be
-          // accepting out-of-order access between consecutive dense levels.
-          affineTids.push_back(tid);
-          affineDims.push_back(i);
-          exps.push_back(exp);
+        // Constant affine expression are handled in genLoop
+        if (!exp.isa<AffineConstantExpr>()) {
+          bool atLevel = false;
+          if (isInvariantAffine(codegen, exp, idx, atLevel) && atLevel) {
+            // If the compound affine is invariant and we are right at the
+            // level. We need to generate the address according to the affine
+            // expression. This is also the best place we can do it to avoid
+            // putting it inside inner loops.
+            // NOTE: It assumes that the levels of the input tensor are
+            // initialized in order, another more admissible approach might be
+            // accepting out-of-order access between consecutive dense levels.
+            affineTids.push_back(tid);
+            affineDims.push_back(i);
+            exps.push_back(exp);
+          }
         }
       }
     }
@@ -1310,6 +1336,17 @@ static Operation *startLoop(Merger &merger, CodeGen &codegen,
     codegen.loopEmitter.genDenseAffineAddressAtCurLevel(builder, op.getLoc(),
                                                         tid, dim, exp);
   }
+
+  // Until now, we have entered every <tid, dim> pair in {cond, extra,
+  // affine}Tids/Dims. The addresses of the upcoming levels which are dependent
+  // on constant affines expression may now be determined.
+  auto allTids = llvm::concat<size_t>(condTids, extraTids, affineTids);
+  auto allDims = llvm::concat<size_t>(condDims, extraDims, affineDims);
+  for (auto [tid, dim] : llvm::zip(allTids, allDims)) {
+    if (tid != merger.getOutTensorID())
+      genConstantDenseAddressFromLevel(codegen, builder, op, tid, dim + 1);
+  }
+
   return loop;
 }
 
@@ -1437,7 +1474,6 @@ static void genResult(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
 //===----------------------------------------------------------------------===//
 
 namespace {
-
 /// Sparse rewriting rule for generic Lingalg operation.
 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
 public:
@@ -1505,6 +1541,7 @@ public:
     CodeGen codegen(options, op.getContext(), tensors, numTensors, numLoops,
                     sparseOut, outerParNest, topSort);
     genBuffers(merger, codegen, rewriter, op);
+    genInitConstantDenseAddress(codegen, rewriter, op);
     genStmt(merger, codegen, rewriter, op, exp, 0);
     genResult(merger, codegen, rewriter, op);
     return success();
index e312a90..1fa2df5 100644 (file)
@@ -225,3 +225,64 @@ func.func @mul_affine_dense_dim_2d(%arga: tensor<34x16xf64, #CSR>,
   } -> tensor<32x16xf64>
   return %0 : tensor<32x16xf64>
 }
+
+#trait5 = {
+  indexing_maps = [
+    affine_map<(i,j) -> (2,j)>,  // a
+    affine_map<(i,j) -> (i,3)>,  // b
+    affine_map<(i,j) -> (i,j)>     // x (out)
+  ],
+  iterator_types = ["parallel","parallel"],
+  doc = "x(i,j) += a(2,j) * b(i,3)"
+}
+
+// CHECK-LABEL:   func.func @mul_const_affine_dense_dim_2d(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: tensor<34x16xf64, 
+// CHECK-SAME:                                             %[[VAL_1:.*]]: tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: tensor<32x16xf64>) -> tensor<32x16xf64> {
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 19 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<34x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<34x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<34x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 0 : index} : tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index} : tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK:           %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf64>
+// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_17:.*]] = %[[VAL_15]] to %[[VAL_16]] step %[[VAL_6]] {
+// CHECK:             %[[VAL_18:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_17]]] : memref<?xindex>
+// CHECK:             %[[VAL_19:.*]] = arith.muli %[[VAL_17]], %[[VAL_3]] : index
+// CHECK:             %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_7]] : index
+// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_20]]] : memref<?xf64>
+// CHECK:             %[[VAL_22:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK:             scf.for %[[VAL_24:.*]] = %[[VAL_22]] to %[[VAL_23]] step %[[VAL_6]] {
+// CHECK:               %[[VAL_25:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_24]]] : memref<?xindex>
+// CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_18]], %[[VAL_25]]] : memref<32x16xf64>
+// CHECK:               %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_24]]] : memref<?xf64>
+// CHECK:               %[[VAL_28:.*]] = arith.mulf %[[VAL_27]], %[[VAL_21]] : f64
+// CHECK:               %[[VAL_29:.*]] = arith.addf %[[VAL_26]], %[[VAL_28]] : f64
+// CHECK:               memref.store %[[VAL_29]], %[[VAL_14]]{{\[}}%[[VAL_18]], %[[VAL_25]]] : memref<32x16xf64>
+// CHECK:             } 
+// CHECK:           } 
+// CHECK:           %[[VAL_30:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<32x16xf64>
+// CHECK:           return %[[VAL_30]] : tensor<32x16xf64>
+// CHECK:         }
+func.func @mul_const_affine_dense_dim_2d(%arga: tensor<34x16xf64, #CSR>,
+                                         %argb: tensor<32x19xf64, #Row>,
+                                         %argx: tensor<32x16xf64>) -> tensor<32x16xf64> {
+  %0 = linalg.generic #trait5
+     ins(%arga, %argb: tensor<34x16xf64, #CSR>, tensor<32x19xf64, #Row>)
+    outs(%argx: tensor<32x16xf64>) {
+      ^bb(%a: f64, %b: f64, %x: f64):
+        %0 = arith.mulf %a, %b : f64
+        %1 = arith.addf %x, %0 : f64
+        linalg.yield %1 : f64
+  } -> tensor<32x16xf64>
+  return %0 : tensor<32x16xf64>
+}