[mlir][sparse] extend foreach operation to iterator over sparse constant.
authorPeiming Liu <peiming@google.com>
Wed, 9 Nov 2022 00:43:27 +0000 (00:43 +0000)
committerPeiming Liu <peiming@google.com>
Wed, 9 Nov 2022 01:50:34 +0000 (01:50 +0000)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D137679

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir

index d0613c0..829b045 100644 (file)
@@ -696,11 +696,49 @@ public:
 
     auto loc = op.getLoc();
     Value input = op.getTensor();
+    SmallVector<Value> reduc = op.getInitArgs();
     auto rtp = input.getType().cast<RankedTensorType>();
     int64_t rank = rtp.getRank();
-    auto enc = getSparseTensorEncoding(rtp);
 
-    SmallVector<Value> reduc = op.getInitArgs();
+    // Special-case: for each over a sparse constant uses its own rewriting
+    // rule.
+    if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
+      if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
+        // Foreach on constant.
+        DenseElementsAttr indicesAttr = attr.getIndices();
+        DenseElementsAttr valuesAttr = attr.getValues();
+
+        SmallVector<Value> args;
+        for (int i = 0, e = valuesAttr.size(); i < e; i++) {
+          auto valAttr = valuesAttr.getValues<TypedAttr>()[i];
+          for (int j = 0; j < rank; j++) {
+            auto coordAttr = indicesAttr.getValues<IntegerAttr>()[i * rank + j];
+            auto coord = rewriter.create<arith::ConstantIndexOp>(
+                loc, coordAttr.getInt());
+            // Remaps coordinates.
+            args.push_back(coord);
+          }
+          // Remaps value.
+          auto val = rewriter.create<arith::ConstantOp>(loc, valAttr);
+          args.push_back(val);
+          // Remaps iteration args.
+          args.append(reduc);
+          auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
+          Operation *yield = cloned.getBody()->getTerminator();
+          rewriter.mergeBlockBefore(cloned.getBody(), op, args);
+          // clean up
+          args.clear();
+          rewriter.eraseOp(cloned);
+          reduc = yield->getOperands();
+          rewriter.eraseOp(yield);
+        }
+        rewriter.replaceOp(op, reduc);
+        return success();
+      }
+    }
+
+    // Otherwise, use loop emitter to generate loops.
+    auto enc = getSparseTensorEncoding(rtp);
 
     // 1. Generates loop for the sparse input.
     SparseTensorLoopEmitter loopEmitter(ValueRange{input});
index aeb63a0..fed119e 100644 (file)
 }>
 
 module {
+  /// uses foreach operator to print coords and values.
+  func.func @foreach_print_const() {
+    // Initialize a tensor.
+    %0 = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>
+    sparse_tensor.foreach in %0 : tensor<8x7xf32> do {
+      ^bb0(%1: index, %2: index, %v: f32) :
+        vector.print %1: index
+        vector.print %2: index
+        vector.print %v: f32
+     }
+     return
+  }
 
   /// uses foreach operator to print coords and values.
   func.func @foreach_print_1(%arg0: tensor<2x2xf64, #Row>) {
@@ -111,6 +123,13 @@ module {
     // CHECK: 0
     // CHECK-NEXT: 0
     // CHECK-NEXT: 1
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 6
+    // CHECK-NEXT: 5
+    call @foreach_print_const() : () -> ()
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 1
     // CHECK-NEXT: 0
     // CHECK-NEXT: 1
     // CHECK-NEXT: 2