[mlir][sparse] add sparse matmul codegen check tests
authorAart Bik <ajcbik@google.com>
Tue, 8 Nov 2022 22:58:45 +0000 (14:58 -0800)
committerAart Bik <ajcbik@google.com>
Tue, 8 Nov 2022 23:58:30 +0000 (15:58 -0800)
This found an insertion point bug in access pattern expansion

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D137675

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir [new file with mode: 0644]

index a35f97c..c161fa5 100644 (file)
@@ -852,6 +852,8 @@ public:
              index);
     genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, index);
     rewriter.create<scf::YieldOp>(loc, fields);
+    rewriter.setInsertionPointAfter(loop);
+    Value result = genTuple(rewriter, loc, dstType, loop->getResults());
     // Deallocate the buffers on exit of the full loop nest.
     Operation *parent = getTop(op);
     rewriter.setInsertionPointAfter(parent);
@@ -859,8 +861,7 @@ public:
     rewriter.create<memref::DeallocOp>(loc, filled);
     rewriter.create<memref::DeallocOp>(loc, added);
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op,
-                       genTuple(rewriter, loc, dstType, loop->getResults()));
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
new file mode 100644 (file)
index 0000000..daaf04a
--- /dev/null
@@ -0,0 +1,138 @@
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+
+// RUN: mlir-opt %s --linalg-generalize-named-ops \
+// RUN:  --sparsification --sparse-tensor-codegen \
+// RUN:  --canonicalize --cse | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (i,j)>
+}>
+
+//
+// Computes C = A x B with all matrices sparse (SpMSpM) in CSR.
+//
+// CHECK-LABEL:   func.func @matmul(
+// CHECK-SAME:      %[[VAL_0:.*0]]: memref<2xindex>,
+// CHECK-SAME:      %[[VAL_1:.*1]]: memref<3xindex>,
+// CHECK-SAME:      %[[VAL_2:.*2]]: memref<?xindex>,
+// CHECK-SAME:      %[[VAL_3:.*3]]: memref<?xindex>,
+// CHECK-SAME:      %[[VAL_4:.*4]]: memref<?xf64>,
+// CHECK-SAME:      %[[VAL_5:.*5]]: memref<2xindex>,
+// CHECK-SAME:      %[[VAL_6:.*6]]: memref<3xindex>,
+// CHECK-SAME:      %[[VAL_7:.*7]]: memref<?xindex>,
+// CHECK-SAME:      %[[VAL_8:.*8]]: memref<?xindex>,
+// CHECK-SAME:      %[[VAL_9:.*9]]: memref<?xf64>)
+// CHECK-SAME: -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
+// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG:       %[[VAL_12:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_13:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_14:.*]] = arith.constant false
+// CHECK-DAG:       %[[VAL_15:.*]] = arith.constant true
+// CHECK-DAG:       %[[VAL_16:.*]] = memref.alloc() : memref<2xindex>
+// CHECK-DAG:       %[[VAL_17:.*]] = memref.alloc() : memref<3xindex>
+// CHECK-DAG:       %[[VAL_18:.*]] = memref.alloc() : memref<16xindex>
+// CHECK-DAG:       %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<16xindex> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_20:.*]] = memref.alloc() : memref<16xindex>
+// CHECK-DAG:       %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<16xindex> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_22:.*]] = memref.alloc() : memref<16xf64>
+// CHECK-DAG:       %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<16xf64> to memref<?xf64>
+// CHECK:           linalg.fill ins(%[[VAL_12]] : index) outs(%[[VAL_17]] : memref<3xindex>)
+// CHECK:           memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_12]]] : memref<2xindex>
+// CHECK:           memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_13]]] : memref<2xindex>
+// CHECK:           %[[VAL_24:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_19]], %[[VAL_12]] {idx = 0 : index} : memref<3xindex>, memref<?xindex>, index
+// CHECK:           %[[VAL_25:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_24]], %[[VAL_12]], %[[VAL_10]] {idx = 0 : index} : memref<3xindex>, memref<?xindex>, index, index
+// CHECK:           %[[VAL_26:.*]] = memref.alloc() : memref<4xf64>
+// CHECK:           %[[VAL_27:.*]] = memref.alloc() : memref<4xi1>
+// CHECK:           %[[VAL_28:.*]] = memref.alloc() : memref<4xindex>
+// CHECK:           %[[VAL_29:.*]] = memref.cast %[[VAL_28]] : memref<4xindex> to memref<?xindex>
+// CHECK:           linalg.fill ins(%[[VAL_11]] : f64) outs(%[[VAL_26]] : memref<4xf64>)
+// CHECK:           linalg.fill ins(%[[VAL_14]] : i1) outs(%[[VAL_27]] : memref<4xi1>)
+// CHECK:           %[[VAL_30:.*]]:2 = scf.for %[[VAL_31:.*]] = %[[VAL_12]] to %[[VAL_10]] step %[[VAL_13]] iter_args(%[[VAL_32:.*]] = %[[VAL_21]], %[[VAL_33:.*]] = %[[VAL_23]]) -> (memref<?xindex>, memref<?xf64>) {
+// CHECK:             %[[VAL_34:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_31]]] : memref<?xindex>
+// CHECK:             %[[VAL_35:.*]] = arith.addi %[[VAL_31]], %[[VAL_13]] : index
+// CHECK:             %[[VAL_36:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_35]]] : memref<?xindex>
+// CHECK:             %[[VAL_37:.*]] = scf.for %[[VAL_38:.*]] = %[[VAL_34]] to %[[VAL_36]] step %[[VAL_13]] iter_args(%[[VAL_39:.*]] = %[[VAL_12]]) -> (index) {
+// CHECK:               %[[VAL_40:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_38]]] : memref<?xindex>
+// CHECK:               %[[VAL_41:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_38]]] : memref<?xf64>
+// CHECK:               %[[VAL_42:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_40]]] : memref<?xindex>
+// CHECK:               %[[VAL_43:.*]] = arith.addi %[[VAL_40]], %[[VAL_13]] : index
+// CHECK:               %[[VAL_44:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_43]]] : memref<?xindex>
+// CHECK:               %[[VAL_45:.*]] = scf.for %[[VAL_46:.*]] = %[[VAL_42]] to %[[VAL_44]] step %[[VAL_13]] iter_args(%[[VAL_47:.*]] = %[[VAL_39]]) -> (index) {
+// CHECK:                 %[[VAL_48:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_46]]] : memref<?xindex>
+// CHECK:                 %[[VAL_49:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_48]]] : memref<4xf64>
+// CHECK:                 %[[VAL_50:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_46]]] : memref<?xf64>
+// CHECK:                 %[[VAL_51:.*]] = arith.mulf %[[VAL_41]], %[[VAL_50]] : f64
+// CHECK:                 %[[VAL_52:.*]] = arith.addf %[[VAL_49]], %[[VAL_51]] : f64
+// CHECK:                 %[[VAL_53:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_48]]] : memref<4xi1>
+// CHECK:                 %[[VAL_54:.*]] = arith.cmpi eq, %[[VAL_53]], %[[VAL_14]] : i1
+// CHECK:                 %[[VAL_55:.*]] = scf.if %[[VAL_54]] -> (index) {
+// CHECK:                   memref.store %[[VAL_15]], %[[VAL_27]]{{\[}}%[[VAL_48]]] : memref<4xi1>
+// CHECK:                   memref.store %[[VAL_48]], %[[VAL_28]]{{\[}}%[[VAL_47]]] : memref<4xindex>
+// CHECK:                   %[[VAL_56:.*]] = arith.addi %[[VAL_47]], %[[VAL_13]] : index
+// CHECK:                   scf.yield %[[VAL_56]] : index
+// CHECK:                 } else {
+// CHECK:                   scf.yield %[[VAL_47]] : index
+// CHECK:                 }
+// CHECK:                 memref.store %[[VAL_52]], %[[VAL_26]]{{\[}}%[[VAL_48]]] : memref<4xf64>
+// CHECK:                 scf.yield %[[VAL_57:.*]] : index
+// CHECK:               }
+// CHECK:               scf.yield %[[VAL_58:.*]] : index
+// CHECK:             }
+// CHECK:             sparse_tensor.sort %[[VAL_59:.*]], %[[VAL_29]] : memref<?xindex>
+// CHECK:             %[[VAL_60:.*]]:2 = scf.for %[[VAL_61:.*]] = %[[VAL_12]] to %[[VAL_59]] step %[[VAL_13]] iter_args(%[[VAL_62:.*]] = %[[VAL_32]], %[[VAL_63:.*]] = %[[VAL_33]]) -> (memref<?xindex>, memref<?xf64>) {
+// CHECK:               %[[VAL_64:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_61]]] : memref<4xindex>
+// CHECK:               %[[VAL_65:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_64]]] : memref<4xf64>
+// CHECK:               %[[VAL_66:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_31]]] : memref<?xindex>
+// CHECK:               %[[VAL_67:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_35]]] : memref<?xindex>
+// CHECK:               %[[VAL_68:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_13]]] : memref<3xindex>
+// CHECK:               %[[VAL_69:.*]] = arith.subi %[[VAL_67]], %[[VAL_13]] : index
+// CHECK:               %[[VAL_70:.*]] = arith.cmpi ult, %[[VAL_66]], %[[VAL_67]] : index
+// CHECK:               %[[VAL_71:.*]] = scf.if %[[VAL_70]] -> (i1) {
+// CHECK:                 %[[VAL_72:.*]] = memref.load %[[VAL_62]]{{\[}}%[[VAL_69]]] : memref<?xindex>
+// CHECK:                 %[[VAL_73:.*]] = arith.cmpi eq, %[[VAL_72]], %[[VAL_64]] : index
+// CHECK:                 scf.yield %[[VAL_73]] : i1
+// CHECK:               } else {
+// CHECK:                 memref.store %[[VAL_68]], %[[VAL_25]]{{\[}}%[[VAL_31]]] : memref<?xindex>
+// CHECK:                 scf.yield %[[VAL_14]] : i1
+// CHECK:               }
+// CHECK:               %[[VAL_74:.*]] = scf.if %[[VAL_75:.*]] -> (memref<?xindex>) {
+// CHECK:                 scf.yield %[[VAL_62]] : memref<?xindex>
+// CHECK:               } else {
+// CHECK:                 %[[VAL_76:.*]] = arith.addi %[[VAL_68]], %[[VAL_13]] : index
+// CHECK:                 memref.store %[[VAL_76]], %[[VAL_25]]{{\[}}%[[VAL_35]]] : memref<?xindex>
+// CHECK:                 %[[VAL_77:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_62]], %[[VAL_64]] {idx = 1 : index} : memref<3xindex>, memref<?xindex>, index
+// CHECK:                 scf.yield %[[VAL_77]] : memref<?xindex>
+// CHECK:               }
+// CHECK:               %[[VAL_78:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_63]], %[[VAL_65]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK:               memref.store %[[VAL_11]], %[[VAL_26]]{{\[}}%[[VAL_64]]] : memref<4xf64>
+// CHECK:               memref.store %[[VAL_14]], %[[VAL_27]]{{\[}}%[[VAL_64]]] : memref<4xi1>
+// CHECK:               scf.yield %[[VAL_79:.*]], %[[VAL_78]] : memref<?xindex>, memref<?xf64>
+// CHECK:             }
+// CHECK:             scf.yield %[[VAL_80:.*]]#0, %[[VAL_80]]#1 : memref<?xindex>, memref<?xf64>
+// CHECK:           }
+// CHECK:           memref.dealloc %[[VAL_26]] : memref<4xf64>
+// CHECK:           memref.dealloc %[[VAL_27]] : memref<4xi1>
+// CHECK:           memref.dealloc %[[VAL_28]] : memref<4xindex>
+// CHECK:           %[[VAL_81:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_12]]] : memref<3xindex>
+// CHECK:           %[[VAL_82:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK:           %[[VAL_83:.*]] = scf.for %[[VAL_84:.*]] = %[[VAL_13]] to %[[VAL_81]] step %[[VAL_13]] iter_args(%[[VAL_85:.*]] = %[[VAL_82]]) -> (index) {
+// CHECK:             %[[VAL_86:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_84]]] : memref<?xindex>
+// CHECK:             %[[VAL_87:.*]] = arith.cmpi eq, %[[VAL_86]], %[[VAL_12]] : index
+// CHECK:             %[[VAL_88:.*]] = arith.select %[[VAL_87]], %[[VAL_85]], %[[VAL_86]] : index
+// CHECK:             scf.if %[[VAL_87]] {
+// CHECK:               memref.store %[[VAL_85]], %[[VAL_25]]{{\[}}%[[VAL_84]]] : memref<?xindex>
+// CHECK:             }
+// CHECK:             scf.yield %[[VAL_88]] : index
+// CHECK:           }
+// CHECK:           return %[[VAL_16]], %[[VAL_17]], %[[VAL_25]], %[[VAL_89:.*]]#0, %[[VAL_89]]#1 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK:         }
+func.func @matmul(%A: tensor<4x8xf64, #CSR>,
+                  %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
+  %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
+  %D = linalg.matmul
+    ins(%A, %B: tensor<4x8xf64, #CSR>, tensor<8x4xf64, #CSR>)
+       outs(%C: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
+  return %D: tensor<4x4xf64, #CSR>
+}