From 129177eaf0ecb202a6f44ab8d23ad51fe00d15f6 Mon Sep 17 00:00:00 2001
From: Aart Bik
Date: Tue, 8 Nov 2022 14:58:45 -0800
Subject: [PATCH] [mlir][sparse] add sparse matmul codegen check tests

This found an insertion point bug in access pattern expansion

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D137675
---
 .../Transforms/SparseTensorCodegen.cpp             |   5 +-
 .../SparseTensor/sparse_matmul_codegen.mlir        | 138 +++++++++++++++++++++
 2 files changed, 141 insertions(+), 2 deletions(-)
 create mode 100644 mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index a35f97c..c161fa5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -852,6 +852,8 @@ public:
                index);
       genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, index);
       rewriter.create<scf::YieldOp>(loc, fields);
+      rewriter.setInsertionPointAfter(loop);
+      Value result = genTuple(rewriter, loc, dstType, loop->getResults());
       // Deallocate the buffers on exit of the full loop nest.
       Operation *parent = getTop(op);
       rewriter.setInsertionPointAfter(parent);
@@ -859,8 +861,7 @@ public:
       rewriter.create<memref::DeallocOp>(loc, filled);
       rewriter.create<memref::DeallocOp>(loc, added);
       // Replace operation with resulting memrefs.
-      rewriter.replaceOp(op,
-                         genTuple(rewriter, loc, dstType, loop->getResults()));
+      rewriter.replaceOp(op, result);
       return success();
     }
   };
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
new file mode 100644
index 0000000..daaf04a
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -0,0 +1,138 @@
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+
+// RUN: mlir-opt %s --linalg-generalize-named-ops \
+// RUN:             --sparsification --sparse-tensor-codegen \
+// RUN:             --canonicalize --cse | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (i,j)>
+}>
+
+//
+// Computes C = A x B with all matrices sparse (SpMSpM) in CSR.
+//
+// CHECK-LABEL: func.func @matmul(
+// CHECK-SAME: %[[VAL_0:.*0]]: memref<2xindex>,
+// CHECK-SAME: %[[VAL_1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[VAL_2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_3:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[VAL_5:.*5]]: memref<2xindex>,
+// CHECK-SAME: %[[VAL_6:.*6]]: memref<3xindex>,
+// CHECK-SAME: %[[VAL_7:.*7]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_8:.*8]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_9:.*9]]: memref<?xf64>)
+// CHECK-SAME: -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
+// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_14:.*]] = arith.constant false
+// CHECK-DAG: %[[VAL_15:.*]] = arith.constant true
+// CHECK-DAG: %[[VAL_16:.*]] = memref.alloc() : memref<2xindex>
+// CHECK-DAG: %[[VAL_17:.*]] = memref.alloc() : memref<3xindex>
+// CHECK-DAG: %[[VAL_18:.*]] = memref.alloc() : memref<16xindex>
+// CHECK-DAG: %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<16xindex> to memref<?xindex>
+// CHECK-DAG: %[[VAL_20:.*]] = memref.alloc() : memref<16xindex>
+// CHECK-DAG: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<16xindex> to memref<?xindex>
+// CHECK-DAG: %[[VAL_22:.*]] = memref.alloc() : memref<16xf64>
+// CHECK-DAG: %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<16xf64> to memref<?xf64>
+// CHECK: linalg.fill ins(%[[VAL_12]] : index) outs(%[[VAL_17]] : memref<3xindex>)
+// CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_12]]] : memref<2xindex>
+// CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_13]]] : memref<2xindex>
+// CHECK: %[[VAL_24:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_19]], %[[VAL_12]] {idx = 0 : index} : memref<3xindex>, memref<?xindex>, index
+// CHECK: %[[VAL_25:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_24]], %[[VAL_12]], %[[VAL_10]] {idx = 0 : index} : memref<3xindex>, memref<?xindex>, index, index
+// CHECK: %[[VAL_26:.*]] = memref.alloc() : memref<4xf64>
+// CHECK: %[[VAL_27:.*]] = memref.alloc() : memref<4xi1>
+// CHECK: %[[VAL_28:.*]] = memref.alloc() : memref<4xindex>
+// CHECK: %[[VAL_29:.*]] = memref.cast %[[VAL_28]] : memref<4xindex> to memref<?xindex>
+// CHECK: linalg.fill ins(%[[VAL_11]] : f64) outs(%[[VAL_26]] : memref<4xf64>)
+// CHECK: linalg.fill ins(%[[VAL_14]] : i1) outs(%[[VAL_27]] : memref<4xi1>)
+// CHECK: %[[VAL_30:.*]]:2 = scf.for %[[VAL_31:.*]] = %[[VAL_12]] to %[[VAL_10]] step %[[VAL_13]] iter_args(%[[VAL_32:.*]] = %[[VAL_21]], %[[VAL_33:.*]] = %[[VAL_23]]) -> (memref<?xindex>, memref<?xf64>) {
+// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_31]]] : memref<?xindex>
+// CHECK: %[[VAL_35:.*]] = arith.addi %[[VAL_31]], %[[VAL_13]] : index
+// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_35]]] : memref<?xindex>
+// CHECK: %[[VAL_37:.*]] = scf.for %[[VAL_38:.*]] = %[[VAL_34]] to %[[VAL_36]] step %[[VAL_13]] iter_args(%[[VAL_39:.*]] = %[[VAL_12]]) -> (index) {
+// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_38]]] : memref<?xindex>
+// CHECK: %[[VAL_41:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_38]]] : memref<?xf64>
+// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_40]]] : memref<?xindex>
+// CHECK: %[[VAL_43:.*]] = arith.addi %[[VAL_40]], %[[VAL_13]] : index
+// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_43]]] : memref<?xindex>
+// CHECK: %[[VAL_45:.*]] = scf.for %[[VAL_46:.*]] = %[[VAL_42]] to %[[VAL_44]] step %[[VAL_13]] iter_args(%[[VAL_47:.*]] = %[[VAL_39]]) -> (index) {
+// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_46]]] : memref<?xindex>
+// CHECK: %[[VAL_49:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_48]]] : memref<4xf64>
+// CHECK: %[[VAL_50:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_46]]] : memref<?xf64>
+// CHECK: %[[VAL_51:.*]] = arith.mulf %[[VAL_41]], %[[VAL_50]] : f64
+// CHECK: %[[VAL_52:.*]] = arith.addf %[[VAL_49]], %[[VAL_51]] : f64
+// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_48]]] : memref<4xi1>
+// CHECK: %[[VAL_54:.*]] = arith.cmpi eq, %[[VAL_53]], %[[VAL_14]] : i1
+// CHECK: %[[VAL_55:.*]] = scf.if %[[VAL_54]] -> (index) {
+// CHECK: memref.store %[[VAL_15]], %[[VAL_27]]{{\[}}%[[VAL_48]]] : memref<4xi1>
+// CHECK: memref.store %[[VAL_48]], %[[VAL_28]]{{\[}}%[[VAL_47]]] : memref<4xindex>
+// CHECK: %[[VAL_56:.*]] = arith.addi %[[VAL_47]], %[[VAL_13]] : index
+// CHECK: scf.yield %[[VAL_56]] : index
+// CHECK: } else {
+// CHECK: scf.yield %[[VAL_47]] : index
+// CHECK: }
+// CHECK: memref.store %[[VAL_52]], %[[VAL_26]]{{\[}}%[[VAL_48]]] : memref<4xf64>
+// CHECK: scf.yield %[[VAL_57:.*]] : index
+// CHECK: }
+// CHECK: scf.yield %[[VAL_58:.*]] : index
+// CHECK: }
+// CHECK: sparse_tensor.sort %[[VAL_59:.*]], %[[VAL_29]] : memref<?xindex>
+// CHECK: %[[VAL_60:.*]]:2 = scf.for %[[VAL_61:.*]] = %[[VAL_12]] to %[[VAL_59]] step %[[VAL_13]] iter_args(%[[VAL_62:.*]] = %[[VAL_32]], %[[VAL_63:.*]] = %[[VAL_33]]) -> (memref<?xindex>, memref<?xf64>) {
+// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_61]]] : memref<4xindex>
+// CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_64]]] : memref<4xf64>
+// CHECK: %[[VAL_66:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_31]]] : memref<?xindex>
+// CHECK: %[[VAL_67:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_35]]] : memref<?xindex>
+// CHECK: %[[VAL_68:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_13]]] : memref<3xindex>
+// CHECK: %[[VAL_69:.*]] = arith.subi %[[VAL_67]], %[[VAL_13]] : index
+// CHECK: %[[VAL_70:.*]] = arith.cmpi ult, %[[VAL_66]], %[[VAL_67]] : index
+// CHECK: %[[VAL_71:.*]] = scf.if %[[VAL_70]] -> (i1) {
+// CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_62]]{{\[}}%[[VAL_69]]] : memref<?xindex>
+// CHECK: %[[VAL_73:.*]] = arith.cmpi eq, %[[VAL_72]], %[[VAL_64]] : index
+// CHECK: scf.yield %[[VAL_73]] : i1
+// CHECK: } else {
+// CHECK: memref.store %[[VAL_68]], %[[VAL_25]]{{\[}}%[[VAL_31]]] : memref<?xindex>
+// CHECK: scf.yield %[[VAL_14]] : i1
+// CHECK: }
+// CHECK: %[[VAL_74:.*]] = scf.if %[[VAL_75:.*]] -> (memref<?xindex>) {
+// CHECK: scf.yield %[[VAL_62]] : memref<?xindex>
+// CHECK: } else {
+// CHECK: %[[VAL_76:.*]] = arith.addi %[[VAL_68]], %[[VAL_13]] : index
+// CHECK: memref.store %[[VAL_76]], %[[VAL_25]]{{\[}}%[[VAL_35]]] : memref<?xindex>
+// CHECK: %[[VAL_77:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_62]], %[[VAL_64]] {idx = 1 : index} : memref<3xindex>, memref<?xindex>, index
+// CHECK: scf.yield %[[VAL_77]] : memref<?xindex>
+// CHECK: }
+// CHECK: %[[VAL_78:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_63]], %[[VAL_65]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+// CHECK: memref.store %[[VAL_11]], %[[VAL_26]]{{\[}}%[[VAL_64]]] : memref<4xf64>
+// CHECK: memref.store %[[VAL_14]], %[[VAL_27]]{{\[}}%[[VAL_64]]] : memref<4xi1>
+// CHECK: scf.yield %[[VAL_79:.*]], %[[VAL_78]] : memref<?xindex>, memref<?xf64>
+// CHECK: }
+// CHECK: scf.yield %[[VAL_80:.*]]#0, %[[VAL_80]]#1 : memref<?xindex>, memref<?xf64>
+// CHECK: }
+// CHECK: memref.dealloc %[[VAL_26]] : memref<4xf64>
+// CHECK: memref.dealloc %[[VAL_27]] : memref<4xi1>
+// CHECK: memref.dealloc %[[VAL_28]] : memref<4xindex>
+// CHECK: %[[VAL_81:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_12]]] : memref<3xindex>
+// CHECK: %[[VAL_82:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK: %[[VAL_83:.*]] = scf.for %[[VAL_84:.*]] = %[[VAL_13]] to %[[VAL_81]] step %[[VAL_13]] iter_args(%[[VAL_85:.*]] = %[[VAL_82]]) -> (index) {
+// CHECK: %[[VAL_86:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_84]]] : memref<?xindex>
+// CHECK: %[[VAL_87:.*]] = arith.cmpi eq, %[[VAL_86]], %[[VAL_12]] : index
+// CHECK: %[[VAL_88:.*]] = arith.select %[[VAL_87]], %[[VAL_85]], %[[VAL_86]] : index
+// CHECK: scf.if %[[VAL_87]] {
+// CHECK: memref.store %[[VAL_85]], %[[VAL_25]]{{\[}}%[[VAL_84]]] : memref<?xindex>
+// CHECK: }
+// CHECK: scf.yield %[[VAL_88]] : index
+// CHECK: }
+// CHECK: return %[[VAL_16]], %[[VAL_17]], %[[VAL_25]], %[[VAL_89:.*]]#0, %[[VAL_89]]#1 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK: }
+func.func @matmul(%A: tensor<4x8xf64, #CSR>,
+                  %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
+  %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
+  %D = linalg.matmul
+    ins(%A, %B: tensor<4x8xf64, #CSR>, tensor<8x4xf64, #CSR>)
+    outs(%C: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
+  return %D: tensor<4x4xf64, #CSR>
+}
-- 
2.7.4