From dc46d5c979101e73fd61c1b6de942e17a2e8e480 Mon Sep 17 00:00:00 2001
From: Aart Bik
Date: Thu, 8 Sep 2022 11:41:29 -0700
Subject: [PATCH] [mlir][sparse] improve dimop rewriting during conversion

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D133512
---
 .../Transforms/SparseTensorConversion.cpp          |   7 +-
 mlir/test/Dialect/SparseTensor/conversion.mlir     |  93 +++++---
 mlir/test/Dialect/SparseTensor/sparse_expand.mlir  |  16 +-
 .../Dialect/SparseTensor/sparse_fill_zero.mlir     | 237 +++++++++++----------
 4 files changed, 197 insertions(+), 156 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index df7a7de..e758a9b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -1167,14 +1167,13 @@ public:
     // All initialization should be done on entry of the loop nest.
     rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
     // Determine the size for access expansion (always the innermost stored
-    // dimension size, but we need to translate it back to the original
-    // dimension since the dim size utility applies dimension ordering).
+    // dimension size, translated back to original dimension). Note that we
+    // recursively rewrite the new DimOp on the **original** tensor.
     auto enc = getSparseTensorEncoding(srcType);
-    Value src = adaptor.getOperands()[0];
     unsigned innerDim = srcType.getRank() - 1;
     if (AffineMap p = enc.getDimOrdering())
       innerDim = p.getDimPosition(innerDim);
-    Value sz = genDimSizeCall(rewriter, loc, enc, src, innerDim);
+    Value sz = rewriter.create<tensor::DimOp>(loc, op.getTensor(), innerDim);
     // Allocate temporary buffers for values, filled-switch, and indices.
     // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
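Note (illustrative sketch, not part of the original patch): the change above sizes the
access expansion with a new tensor.dim on the original sparse tensor instead of calling
the runtime size query directly. Because that DimOp is itself rewritten recursively by
the same conversion, a statically known dimension now folds to a constant and the
expansion buffers get static sizes; only truly dynamic dimensions still lower to a
sparseDimSize call. A minimal before/after sketch, with assumed value names (%t, %c1)
and an innermost stored dimension of size 8:

  // Before: always a runtime query, even for a static shape.
  %s = call @sparseDimSize(%t, %c1) : (!llvm.ptr, index) -> index
  %v = memref.alloc(%s) : memref

  // After: tensor.dim on the original tensor folds, yielding a static buffer.
  %v = memref.alloc() : memref<8xf64>

The updated FileCheck expectations below (e.g. memref<8xf64> in @sparse_expansion1 and
the %[[C4]]/%[[C8]] allocations in sparse_expand.mlir) reflect exactly this effect.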
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 23bec35..99dab30 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -16,10 +16,15 @@
   indexBitWidth = 32
 }>
 
-#SparseMatrix = #sparse_tensor.encoding<{
-  dimLevelType = ["dense", "compressed"]
-}>
+#CSR = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "compressed"]
+}>
+
+#CSC = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "compressed"],
+  dimOrdering = affine_map<(i,j) -> (j,i)>
+}>
 
 #SparseTensor = #sparse_tensor.encoding<{
   dimLevelType = ["dense", "compressed", "compressed"],
   dimOrdering = affine_map<(i,j,k) -> (k,i,j)>
@@ -97,9 +102,9 @@ func.func @sparse_new1d(%arg0: !llvm.ptr) -> tensor<128xf64, #SparseVector>
 // CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<2xindex> to memref
 // CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromFile]], %[[A]])
 // CHECK: return %[[T]] : !llvm.ptr
-func.func @sparse_new2d(%arg0: !llvm.ptr) -> tensor {
-  %0 = sparse_tensor.new %arg0 : !llvm.ptr to tensor
-  return %0 : tensor
+func.func @sparse_new2d(%arg0: !llvm.ptr) -> tensor {
+  %0 = sparse_tensor.new %arg0 : !llvm.ptr to tensor
+  return %0 : tensor
 }
 
 // CHECK-LABEL: func @sparse_new3d(
@@ -135,10 +140,10 @@ func.func @sparse_new3d(%arg0: !llvm.ptr) -> tensor
 // CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[Empty]], %[[NP]])
 // CHECK: return %[[T]] : !llvm.ptr
-func.func @sparse_init(%arg0: index, %arg1: index) -> tensor {
-  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor
-  %1 = sparse_tensor.load %0 : tensor
-  return %1 : tensor
+func.func @sparse_init(%arg0: index, %arg1: index) -> tensor {
+  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor
+  %1 = sparse_tensor.load %0 : tensor
+  return %1 : tensor
 }
 
 // CHECK-LABEL: func @sparse_release(
@@ -277,9 +282,9 @@ func.func @sparse_convert_1d_ss(%arg0: tensor) -> tensor
 // CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
 // CHECK: call @delSparseTensorCOOF64(%[[C]])
 // CHECK: return %[[T]] : !llvm.ptr
-func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix> {
-  %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #SparseMatrix>
-  return %0 : tensor<2x4xf64, #SparseMatrix>
+func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
+  %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #CSR>
+  return %0 : tensor<2x4xf64, #CSR>
 }
 
 // CHECK-LABEL: func @sparse_constant() -> !llvm.ptr {
@@ -309,12 +314,12 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseM
 // CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
 // CHECK: call @delSparseTensorCOOF32(%[[C]])
 // CHECK: return %[[T]] : !llvm.ptr
-func.func @sparse_constant() -> tensor<8x7xf32, #SparseMatrix>{
+func.func @sparse_constant() -> tensor<8x7xf32, #CSR>{
   // Initialize a tensor.
   %0 = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>
   // Convert the tensor to a sparse tensor.
-  %1 = sparse_tensor.convert %0 : tensor<8x7xf32> to tensor<8x7xf32, #SparseMatrix>
-  return %1 : tensor<8x7xf32, #SparseMatrix>
+  %1 = sparse_tensor.convert %0 : tensor<8x7xf32> to tensor<8x7xf32, #CSR>
+  return %1 : tensor<8x7xf32, #CSR>
 }
 
 // CHECK-LABEL: func @sparse_convert_3d(
@@ -493,20 +498,52 @@ func.func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>,
   return
 }
 
-// CHECK-LABEL: func @sparse_expansion()
-// CHECK-DAG: %[[C:.*]] = arith.constant 1 : index
+// CHECK-LABEL: func @sparse_expansion1()
+// CHECK: %[[N:.*]] = call @newSparseTensor
+// CHECK: %[[A:.*]] = memref.alloc() : memref<8xf64>
+// CHECK: %[[B:.*]] = memref.alloc() : memref<8xi1>
+// CHECK: %[[C:.*]] = memref.alloc() : memref<8xindex>
+// CHECK: %[[D:.*]] = memref.cast %[[C]] : memref<8xindex> to memref
+// CHECK-DAG: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<8xf64>)
+// CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<8xi1>)
+// CHECK: return %[[D]] : memref
+func.func @sparse_expansion1() -> memref {
+  %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSR>
+  %values, %filled, %added, %count = sparse_tensor.expand %0
+    : tensor<4x8xf64, #CSR> to memref, memref, memref, index
+  return %added : memref
+}
+
+// CHECK-LABEL: func @sparse_expansion2()
+// CHECK: %[[N:.*]] = call @newSparseTensor
+// CHECK: %[[A:.*]] = memref.alloc() : memref<4xf64>
+// CHECK: %[[B:.*]] = memref.alloc() : memref<4xi1>
+// CHECK: %[[C:.*]] = memref.alloc() : memref<4xindex>
+// CHECK: %[[D:.*]] = memref.cast %[[C]] : memref<4xindex> to memref
+// CHECK-DAG: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<4xf64>)
+// CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<4xi1>)
+// CHECK: return %[[D]] : memref
+func.func @sparse_expansion2() -> memref {
+  %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSC>
+  %values, %filled, %added, %count = sparse_tensor.expand %0
+    : tensor<4x8xf64, #CSC> to memref, memref, memref, index
+  return %added : memref
+}
+
+// CHECK-LABEL: func @sparse_expansion3(
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[N:.*]] = call @newSparseTensor
-// CHECK: %[[S:.*]] = call @sparseDimSize(%[[N]], %[[C]]) : (!llvm.ptr, index) -> index
+// CHECK: %[[S:.*]] = call @sparseDimSize(%[[N]], %c1) : (!llvm.ptr, index) -> index
 // CHECK: %[[A:.*]] = memref.alloc(%[[S]]) : memref
 // CHECK: %[[B:.*]] = memref.alloc(%[[S]]) : memref
 // CHECK: %[[C:.*]] = memref.alloc(%[[S]]) : memref
 // CHECK-DAG: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref)
 // CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref)
 // CHECK: return %[[C]] : memref
-func.func @sparse_expansion() -> memref {
-  %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #SparseMatrix>
+func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref {
+  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor
   %values, %filled, %added, %count = sparse_tensor.expand %0
-    : tensor<4x8xf64, #SparseMatrix> to memref, memref, memref, index
+    : tensor to memref, memref, memref, index
   return %added : memref
 }
 
@@ -521,11 +558,11 @@ func.func @sparse_expansion() -> memref {
 // CHECK-DAG: memref.dealloc %[[D]] : memref
 // CHECK-DAG: memref.dealloc %[[E]] : memref
 // CHECK: return
-func.func @sparse_compression(%arg0: tensor<8x8xf64, #SparseMatrix>,
+func.func @sparse_compression(%arg0: tensor<8x8xf64, #CSR>,
                               %arg1: memref, %arg2: memref, %arg3: memref,
                               %arg4: memref, %arg5: index) {
   sparse_tensor.compress %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
-    : tensor<8x8xf64, #SparseMatrix>, memref, memref, memref,
-      memref, index
+    : tensor<8x8xf64, #CSR>, memref, memref, memref, memref, index
   return
 }
 
@@ -538,8 +575,8 @@ func.func @sparse_compression(%arg0: tensor<8x8xf64, #SparseMatrix>,
 // CHECK: call @outSparseTensorF64(%[[COO]], %[[B]], %[[Sort]]) : (!llvm.ptr, !llvm.ptr, i1) -> ()
 // CHECK: call @delSparseTensorCOOF64(%[[COO]])
 // CHECK: return
-func.func @sparse_out1(%arg0: tensor, %arg1: !llvm.ptr) {
-  sparse_tensor.out %arg0, %arg1 : tensor, !llvm.ptr
+func.func @sparse_out1(%arg0: tensor, %arg1: !llvm.ptr) {
+  sparse_tensor.out %arg0, %arg1 : tensor, !llvm.ptr
   return
 }
 
@@ -562,9 +599,9 @@ func.func @sparse_out2(%arg0: tensor, %arg1: !llvm.ptr
 // CHECK: %[[D:.*]] = bufferization.alloc_tensor
 // CHECK: return %[[S]], %[[D]] : !llvm.ptr, tensor
 func.func @sparse_and_dense_init(%arg0: index, %arg1: index)
-  -> (tensor, tensor) {
-  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor
-  %1 = sparse_tensor.load %0 : tensor
+  -> (tensor, tensor) {
+  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor
+  %1 = sparse_tensor.load %0 : tensor
   %2 = bufferization.alloc_tensor(%arg0, %arg1) : tensor
-  return %1, %2 : tensor, tensor
+  return %1, %2 : tensor, tensor
 }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
index d266d4b..fd4e297 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_expand.mlir
@@ -98,12 +98,12 @@ func.func @kernel(%arga: tensor) -> tensor {
 // CHECK-CONVERT-LABEL: func @matmul1(
 // CHECK-CONVERT-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-CONVERT-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-CONVERT-DAG: %[[C4:.*]] = arith.constant 4 : index
 // CHECK-CONVERT-DAG: %[[C8:.*]] = arith.constant 8 : index
 // CHECK-CONVERT: %[[N:.*]] = call @newSparseTensor
-// CHECK-CONVERT: %[[S:.*]] = call @sparseDimSize(%[[N]], %[[C1]])
-// CHECK-CONVERT: %[[A:.*]] = memref.alloc(%[[S]]) : memref
-// CHECK-CONVERT: %[[B:.*]] = memref.alloc(%[[S]]) : memref
-// CHECK-CONVERT: %[[C:.*]] = memref.alloc(%[[S]]) : memref
+// CHECK-CONVERT: %[[A:.*]] = memref.alloc(%[[C4]]) : memref
+// CHECK-CONVERT: %[[B:.*]] = memref.alloc(%[[C4]]) : memref
+// CHECK-CONVERT: %[[C:.*]] = memref.alloc(%[[C4]]) : memref
 // CHECK-CONVERT: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref)
 // CHECK-CONVERT: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref)
 // CHECK-CONVERT: scf.for %{{.*}} = %[[C0]] to %[[C8]] step %[[C1]] {
@@ -147,11 +147,11 @@ func.func @matmul1(%A: tensor<8x2xf64, #CSR>,
 // CHECK-CONVERT-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-CONVERT-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK-CONVERT-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-CONVERT-DAG: %[[C8:.*]] = arith.constant 8 : index
 // CHECK-CONVERT: %[[N:.*]] = call @newSparseTensor
-// CHECK-CONVERT: %[[S:.*]] = call @sparseDimSize(%[[N]], %[[C1]])
-// CHECK-CONVERT: %[[A:.*]] = memref.alloc(%[[S]]) : memref
-// CHECK-CONVERT: %[[B:.*]] = memref.alloc(%[[S]]) : memref
-// CHECK-CONVERT: %[[C:.*]] = memref.alloc(%[[S]]) : memref
+// CHECK-CONVERT: %[[A:.*]] = memref.alloc(%[[C8]]) : memref
+// CHECK-CONVERT: %[[B:.*]] = memref.alloc(%[[C8]]) : memref
+// CHECK-CONVERT: %[[C:.*]] = memref.alloc(%[[C8]]) : memref
 // CHECK-CONVERT: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref)
 // CHECK-CONVERT: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref)
 // CHECK-CONVERT: scf.for %{{.*}} = %[[C0]] to %[[C4]] step %[[C1]] {
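Note (illustrative sketch, not part of the original patch): the expansion buffer is sized
by the innermost *stored* dimension. For the 4x8 matrices in the conversion.mlir tests
above, #CSR stores rows of length 8 while #CSC (dimOrdering (i,j) -> (j,i)) stores columns
of length 4, which is why @sparse_expansion1 checks for size-8 buffers and
@sparse_expansion2 for size-4 ones. A sketch of the translation the C++ hunk performs for
the CSC case (rank 2):

  innerDim = rank - 1 = 1             // innermost stored dimension
  innerDim = p.getDimPosition(1) = 0  // mapped back through (i,j) -> (j,i)
  // so tensor.dim queries dimension 0 of the original 4x8 tensor, i.e. 4

@sparse_expansion3, which allocates a tensor with dynamic sizes, keeps the runtime
@sparseDimSize call in its expectations.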
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index f2812cd..7b87411 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -1,122 +1,127 @@
 // RUN: mlir-opt %s --linalg-generalize-named-ops --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
 
 #DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
 
-// CHECK-LABEL: func.func @fill_zero_after_alloc
-// CHECK-SAME: %[[TMP_arg0:.*]]: !llvm.ptr,
-// CHECK-SAME: %[[TMP_arg1:.*]]: !llvm.ptr
-// CHECK: %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK: %[[TMP_c1_i32:.*]] = arith.constant 1 : i32
-// CHECK: %[[TMP_c0_i32:.*]] = arith.constant 0 : i32
-// CHECK: %[[TMP_c0:.*]] = arith.constant 0 : index
-// CHECK: %[[TMP_c1:.*]] = arith.constant 1 : index
-// CHECK: %[[TMP_false:.*]] = arith.constant false
-// CHECK: %[[TMP_true:.*]] = arith.constant true
-// CHECK: %[[TMP_c100:.*]] = arith.constant 100 : index
-// CHECK: %[[TMP_c1_i8:.*]] = arith.constant 1 : i8
-// CHECK: %[[TMP_0:.*]] = memref.alloca() : memref<2xi8>
-// CHECK: %[[TMP_1:.*]] = memref.cast %[[TMP_0]] : memref<2xi8> to memref
-// CHECK: memref.store %[[TMP_c1_i8]], %[[TMP_0]][%[[TMP_c0]]] : memref<2xi8>
-// CHECK: memref.store %[[TMP_c1_i8]], %[[TMP_0]][%[[TMP_c1]]] : memref<2xi8>
-// CHECK: %[[TMP_2:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[TMP_3:.*]] = memref.cast %[[TMP_2]] : memref<2xindex> to memref
-// CHECK: memref.store %[[TMP_c100]], %[[TMP_2]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK: memref.store %[[TMP_c100]], %[[TMP_2]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[TMP_4:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[TMP_5:.*]] = memref.cast %[[TMP_4]] : memref<2xindex> to memref
-// CHECK: memref.store %[[TMP_c0]], %[[TMP_4]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK: memref.store %[[TMP_c1]], %[[TMP_4]][%[[TMP_c1]]] : memref<2xindex>
-// CHECK: %[[TMP_6:.*]] = llvm.mlir.null : !llvm.ptr
-// CHECK: %[[TMP_7:.*]] = call @newSparseTensor(%[[TMP_1]], %[[TMP_3]], %[[TMP_5]], %[[TMP_c0_i32]], %[[TMP_c0_i32]], %[[TMP_c1_i32]], %[[TMP_c0_i32]], %[[TMP_6]])
-// CHECK: %[[TMP_8:.*]] = call @sparseDimSize(%[[TMP_7]], %[[TMP_c1]])
-// CHECK: %[[TMP_9:.*]] = memref.alloc(%[[TMP_8]]) : memref
-// CHECK: %[[TMP_10:.*]] = memref.alloc(%[[TMP_8]]) : memref
-// CHECK: %[[TMP_11:.*]] = memref.alloc(%[[TMP_8]]) : memref
-// CHECK: linalg.fill ins(%[[TMP_cst]] : f64) outs(%[[TMP_9]] : memref)
-// CHECK: linalg.fill ins(%[[TMP_false]] : i1) outs(%[[TMP_10]] : memref)
-// CHECK: %[[TMP_12:.*]] = call @sparsePointers0(%[[TMP_arg0]], %[[TMP_c0]])
-// CHECK: %[[TMP_13:.*]] = call @sparseIndices0(%[[TMP_arg0]], %[[TMP_c0]])
-// CHECK: %[[TMP_14:.*]] = call @sparsePointers0(%[[TMP_arg0]], %[[TMP_c1]])
-// CHECK: %[[TMP_15:.*]] = call @sparseIndices0(%[[TMP_arg0]], %[[TMP_c1]])
-// CHECK: %[[TMP_16:.*]] = call @sparseValuesF64(%[[TMP_arg0]])
-// CHECK: %[[TMP_17:.*]] = call @sparsePointers0(%[[TMP_arg1]], %[[TMP_c0]])
-// CHECK: %[[TMP_18:.*]] = call @sparseIndices0(%[[TMP_arg1]], %[[TMP_c0]])
-// CHECK: %[[TMP_19:.*]] = call @sparsePointers0(%[[TMP_arg1]], %[[TMP_c1]])
-// CHECK: %[[TMP_20:.*]] = call @sparseIndices0(%[[TMP_arg1]], %[[TMP_c1]])
-// CHECK: %[[TMP_21:.*]] = call @sparseValuesF64(%[[TMP_arg1]])
-// CHECK: %[[TMP_22:.*]] = memref.alloca() : memref<2xindex>
-// CHECK: %[[TMP_23:.*]] = memref.cast %[[TMP_22]] : memref<2xindex> to memref
-// CHECK: %[[TMP_24:.*]] = memref.load %[[TMP_12]][%[[TMP_c0]]] : memref
-// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_12]][%[[TMP_c1]]] : memref
-// CHECK: scf.for %[[TMP_arg2:.*]] = %[[TMP_24]] to %[[TMP_25]] step %[[TMP_c1]] {
-// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_13]][%[[TMP_arg2]]] : memref
-// CHECK: memref.store %[[TMP_26]], %[[TMP_22]][%[[TMP_c0]]] : memref<2xindex>
-// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_14]][%[[TMP_arg2]]] : memref
-// CHECK: %[[TMP_28:.*]] = arith.addi %[[TMP_arg2]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_29:.*]] = memref.load %[[TMP_14]][%[[TMP_28]]] : memref
-// CHECK: %[[TMP_30:.*]] = memref.load %[[TMP_17]][%[[TMP_c0]]] : memref
-// CHECK: %[[TMP_31:.*]] = memref.load %[[TMP_17]][%[[TMP_c1]]] : memref
-// CHECK: %[[TMP_32:.*]]:3 = scf.while (%[[TMP_arg3:.*]] = %[[TMP_27]], %[[TMP_arg4:.*]] = %[[TMP_30]], %[[TMP_arg5:.*]] = %[[TMP_c0]]) : (index, index, index) -> (index, index, index) {
-// CHECK: %[[TMP_33:.*]] = arith.cmpi ult, %[[TMP_arg3]], %[[TMP_29]] : index
-// CHECK: %[[TMP_34:.*]] = arith.cmpi ult, %[[TMP_arg4]], %[[TMP_31]] : index
-// CHECK: %[[TMP_35:.*]] = arith.andi %[[TMP_33]], %[[TMP_34]] : i1
-// CHECK: scf.condition(%[[TMP_35]]) %[[TMP_arg3]], %[[TMP_arg4]], %[[TMP_arg5]] : index, index, index
-// CHECK: } do {
-// CHECK: ^bb0(%[[TMP_arg3:.*]]: index, %[[TMP_arg4:.*]]: index, %[[TMP_arg5:.*]]: index):
-// CHECK: %[[TMP_33:.*]] = memref.load %[[TMP_15]][%[[TMP_arg3]]] : memref
-// CHECK: %[[TMP_34:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref
-// CHECK: %[[TMP_35:.*]] = arith.cmpi ult, %[[TMP_34]], %[[TMP_33]] : index
-// CHECK: %[[TMP_36:.*]] = arith.select %[[TMP_35]], %[[TMP_34]], %[[TMP_33]] : index
-// CHECK: %[[TMP_37:.*]] = arith.cmpi eq, %[[TMP_33]], %[[TMP_36]] : index
-// CHECK: %[[TMP_38:.*]] = arith.cmpi eq, %[[TMP_34]], %[[TMP_36]] : index
-// CHECK: %[[TMP_39:.*]] = arith.andi %[[TMP_37]], %[[TMP_38]] : i1
-// CHECK: %[[TMP_40:.*]] = scf.if %[[TMP_39]] -> (index) {
-// CHECK: %[[TMP_45:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref
-// CHECK: %[[TMP_46:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref
-// CHECK: %[[TMP_47:.*]] = arith.addi %[[TMP_arg4]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_48:.*]] = memref.load %[[TMP_19]][%[[TMP_47]]] : memref
-// CHECK: %[[TMP_49:.*]] = scf.for %[[TMP_arg6:.*]] = %[[TMP_46]] to %[[TMP_48]] step %[[TMP_c1]] iter_args(%[[TMP_arg7:.*]] = %[[TMP_arg5]]) -> (index) {
-// CHECK: %[[TMP_50:.*]] = memref.load %[[TMP_20]][%[[TMP_arg6]]] : memref
-// CHECK: %[[TMP_51:.*]] = memref.load %[[TMP_9]][%[[TMP_50]]] : memref
-// CHECK: %[[TMP_52:.*]] = memref.load %[[TMP_21]][%[[TMP_arg6]]] : memref
-// CHECK: %[[TMP_53:.*]] = arith.mulf %[[TMP_45]], %[[TMP_52]] : f64
-// CHECK: %[[TMP_54:.*]] = arith.addf %[[TMP_51]], %[[TMP_53]] : f64
-// CHECK: %[[TMP_55:.*]] = memref.load %[[TMP_10]][%[[TMP_50]]] : memref
-// CHECK: %[[TMP_56:.*]] = arith.cmpi eq, %[[TMP_55]], %[[TMP_false]] : i1
-// CHECK: %[[TMP_57:.*]] = scf.if %[[TMP_56]] -> (index) {
-// CHECK: memref.store %[[TMP_true]], %[[TMP_10]][%[[TMP_50]]] : memref
-// CHECK: memref.store %[[TMP_50]], %[[TMP_11]][%[[TMP_arg7]]] : memref
-// CHECK: %[[TMP_58:.*]] = arith.addi %[[TMP_arg7]], %[[TMP_c1]] : index
-// CHECK: scf.yield %[[TMP_58]] : index
-// CHECK: } else {
-// CHECK: scf.yield %[[TMP_arg7]] : index
-// CHECK: }
-// CHECK: memref.store %[[TMP_54]], %[[TMP_9]][%[[TMP_50]]] : memref
-// CHECK: scf.yield %[[TMP_57]] : index
-// CHECK: }
-// CHECK: scf.yield %[[TMP_49]] : index
-// CHECK: } else {
-// CHECK: scf.yield %[[TMP_arg5]] : index
-// CHECK: }
-// CHECK: %[[TMP_41:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_42:.*]] = arith.select %[[TMP_37]], %[[TMP_41]], %[[TMP_arg3]] : index
-// CHECK: %[[TMP_43:.*]] = arith.addi %[[TMP_arg4]], %[[TMP_c1]] : index
-// CHECK: %[[TMP_44:.*]] = arith.select %[[TMP_38]], %[[TMP_43]], %[[TMP_arg4]] : index
-// CHECK: scf.yield %[[TMP_42]], %[[TMP_44]], %[[TMP_40]] : index, index, index
-// CHECK: }
-// CHECK: func.call @expInsertF64(%[[TMP_7]], %[[TMP_23]], %[[TMP_9]], %[[TMP_10]], %[[TMP_11]], %[[TMP_32]]#2)
-// CHECK: }
-// CHECK: memref.dealloc %[[TMP_9]] : memref
-// CHECK: memref.dealloc %[[TMP_10]] : memref
-// CHECK: memref.dealloc %[[TMP_11]] : memref
-// CHECK: call @endInsert(%[[TMP_7]]) : (!llvm.ptr) -> ()
-// CHECK: return %[[TMP_7]] : !llvm.ptr
-func.func @fill_zero_after_alloc(%arg0: tensor<100x100xf64, #DCSR>,
-                                 %arg1: tensor<100x100xf64, #DCSR>) -> tensor<100x100xf64, #DCSR> {
-  %0 = bufferization.alloc_tensor() : tensor<100x100xf64, #DCSR>
+
+// CHECK-LABEL: func.func @fill_zero_after_alloc(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr,
+// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr) -> !llvm.ptr {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant false
+// CHECK: %[[VAL_8:.*]] = arith.constant true
+// CHECK: %[[VAL_9:.*]] = arith.constant 100 : index
+// CHECK: %[[VAL_10:.*]] = arith.constant 300 : index
+// CHECK: %[[VAL_11:.*]] = arith.constant 1 : i8
+// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi8>
+// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi8> to memref
+// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi8>
+// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi8>
+// CHECK: %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref
+// CHECK: memref.store %[[VAL_9]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex>
+// CHECK: memref.store %[[VAL_10]], %[[VAL_14]]{{\[}}%[[VAL_6]]] : memref<2xindex>
+// CHECK: %[[VAL_16:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[VAL_17:.*]] = memref.cast %[[VAL_16]] : memref<2xindex> to memref
+// CHECK: memref.store %[[VAL_5]], %[[VAL_16]]{{\[}}%[[VAL_5]]] : memref<2xindex>
+// CHECK: memref.store %[[VAL_6]], %[[VAL_16]]{{\[}}%[[VAL_6]]] : memref<2xindex>
+// CHECK: %[[VAL_18:.*]] = llvm.mlir.null : !llvm.ptr
+// CHECK: %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_13]], %[[VAL_15]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[VAL_18]]) : (memref, memref, memref, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
+// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<300xf64>
+// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<300xf64> to memref
+// CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<300xi1>
+// CHECK: %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<300xi1> to memref
+// CHECK: %[[VAL_24:.*]] = memref.alloc() : memref<300xindex>
+// CHECK: %[[VAL_25:.*]] = memref.cast %[[VAL_24]] : memref<300xindex> to memref
+// CHECK: linalg.fill ins(%[[VAL_2]] : f64) outs(%[[VAL_20]] : memref<300xf64>)
+// CHECK: linalg.fill ins(%[[VAL_7]] : i1) outs(%[[VAL_22]] : memref<300xi1>)
+// CHECK: %[[VAL_26:.*]] = call @sparsePointers0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref
+// CHECK: %[[VAL_27:.*]] = call @sparseIndices0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref
+// CHECK: %[[VAL_28:.*]] = call @sparsePointers0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref
+// CHECK: %[[VAL_29:.*]] = call @sparseIndices0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref
+// CHECK: %[[VAL_30:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref
+// CHECK: %[[VAL_31:.*]] = call @sparsePointers0(%[[VAL_1]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref
+// CHECK: %[[VAL_32:.*]] = call @sparseIndices0(%[[VAL_1]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref
+// CHECK: %[[VAL_33:.*]] = call @sparsePointers0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref
+// CHECK: %[[VAL_34:.*]] = call @sparseIndices0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref
+// CHECK: %[[VAL_35:.*]] = call @sparseValuesF64(%[[VAL_1]]) : (!llvm.ptr) -> memref
+// CHECK: %[[VAL_36:.*]] = memref.alloca() : memref<2xindex>
+// CHECK: %[[VAL_37:.*]] = memref.cast %[[VAL_36]] : memref<2xindex> to memref
+// CHECK: %[[VAL_38:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_5]]] : memref
+// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_6]]] : memref
+// CHECK: scf.for %[[VAL_40:.*]] = %[[VAL_38]] to %[[VAL_39]] step %[[VAL_6]] {
+// CHECK: %[[VAL_41:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_40]]] : memref
+// CHECK: memref.store %[[VAL_41]], %[[VAL_36]]{{\[}}%[[VAL_5]]] : memref<2xindex>
+// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_40]]] : memref
+// CHECK: %[[VAL_43:.*]] = arith.addi %[[VAL_40]], %[[VAL_6]] : index
+// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_43]]] : memref
+// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_5]]] : memref
+// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_6]]] : memref
+// CHECK: %[[VAL_47:.*]]:3 = scf.while (%[[VAL_48:.*]] = %[[VAL_42]], %[[VAL_49:.*]] = %[[VAL_45]], %[[VAL_50:.*]] = %[[VAL_5]]) : (index, index, index) -> (index, index, index) {
+// CHECK: %[[VAL_51:.*]] = arith.cmpi ult, %[[VAL_48]], %[[VAL_44]] : index
+// CHECK: %[[VAL_52:.*]] = arith.cmpi ult, %[[VAL_49]], %[[VAL_46]] : index
+// CHECK: %[[VAL_53:.*]] = arith.andi %[[VAL_51]], %[[VAL_52]] : i1
+// CHECK: scf.condition(%[[VAL_53]]) %[[VAL_48]], %[[VAL_49]], %[[VAL_50]] : index, index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_54:.*]]: index, %[[VAL_55:.*]]: index, %[[VAL_56:.*]]: index):
+// CHECK: %[[VAL_57:.*]] = memref.load %[[VAL_29]]{{\[}}%[[VAL_54]]] : memref
+// CHECK: %[[VAL_58:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_55]]] : memref
+// CHECK: %[[VAL_59:.*]] = arith.cmpi ult, %[[VAL_58]], %[[VAL_57]] : index
+// CHECK: %[[VAL_60:.*]] = arith.select %[[VAL_59]], %[[VAL_58]], %[[VAL_57]] : index
+// CHECK: %[[VAL_61:.*]] = arith.cmpi eq, %[[VAL_57]], %[[VAL_60]] : index
+// CHECK: %[[VAL_62:.*]] = arith.cmpi eq, %[[VAL_58]], %[[VAL_60]] : index
+// CHECK: %[[VAL_63:.*]] = arith.andi %[[VAL_61]], %[[VAL_62]] : i1
+// CHECK: %[[VAL_64:.*]] = scf.if %[[VAL_63]] -> (index) {
+// CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_54]]] : memref
+// CHECK: %[[VAL_66:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_55]]] : memref
+// CHECK: %[[VAL_67:.*]] = arith.addi %[[VAL_55]], %[[VAL_6]] : index
+// CHECK: %[[VAL_68:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_67]]] : memref
+// CHECK: %[[VAL_69:.*]] = scf.for %[[VAL_70:.*]] = %[[VAL_66]] to %[[VAL_68]] step %[[VAL_6]] iter_args(%[[VAL_71:.*]] = %[[VAL_56]]) -> (index) {
+// CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_34]]{{\[}}%[[VAL_70]]] : memref
+// CHECK: %[[VAL_73:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_72]]] : memref<300xf64>
+// CHECK: %[[VAL_74:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_70]]] : memref
+// CHECK: %[[VAL_75:.*]] = arith.mulf %[[VAL_65]], %[[VAL_74]] : f64
+// CHECK: %[[VAL_76:.*]] = arith.addf %[[VAL_73]], %[[VAL_75]] : f64
+// CHECK: %[[VAL_77:.*]] = memref.load %[[VAL_22]]{{\[}}%[[VAL_72]]] : memref<300xi1>
+// CHECK: %[[VAL_78:.*]] = arith.cmpi eq, %[[VAL_77]], %[[VAL_7]] : i1
+// CHECK: %[[VAL_79:.*]] = scf.if %[[VAL_78]] -> (index) {
+// CHECK: memref.store %[[VAL_8]], %[[VAL_22]]{{\[}}%[[VAL_72]]] : memref<300xi1>
+// CHECK: memref.store %[[VAL_72]], %[[VAL_24]]{{\[}}%[[VAL_71]]] : memref<300xindex>
+// CHECK: %[[VAL_80:.*]] = arith.addi %[[VAL_71]], %[[VAL_6]] : index
+// CHECK: scf.yield %[[VAL_80]] : index
+// CHECK: } else {
+// CHECK: scf.yield %[[VAL_71]] : index
+// CHECK: }
+// CHECK: memref.store %[[VAL_76]], %[[VAL_20]]{{\[}}%[[VAL_72]]] : memref<300xf64>
+// CHECK: scf.yield %[[VAL_81:.*]] : index
+// CHECK: }
+// CHECK: scf.yield %[[VAL_82:.*]] : index
+// CHECK: } else {
+// CHECK: scf.yield %[[VAL_56]] : index
+// CHECK: }
+// CHECK: %[[VAL_83:.*]] = arith.addi %[[VAL_54]], %[[VAL_6]] : index
+// CHECK: %[[VAL_84:.*]] = arith.select %[[VAL_61]], %[[VAL_83]], %[[VAL_54]] : index
+// CHECK: %[[VAL_85:.*]] = arith.addi %[[VAL_55]], %[[VAL_6]] : index
+// CHECK: %[[VAL_86:.*]] = arith.select %[[VAL_62]], %[[VAL_85]], %[[VAL_55]] : index
+// CHECK: scf.yield %[[VAL_84]], %[[VAL_86]], %[[VAL_87:.*]] : index, index, index
+// CHECK: }
+// CHECK: func.call @expInsertF64(%[[VAL_19]], %[[VAL_37]], %[[VAL_21]], %[[VAL_23]], %[[VAL_25]], %[[VAL_88:.*]]#2) : (!llvm.ptr, memref, memref, memref, memref, index) -> ()
+// CHECK: }
+// CHECK: memref.dealloc %[[VAL_20]] : memref<300xf64>
+// CHECK: memref.dealloc %[[VAL_22]] : memref<300xi1>
+// CHECK: memref.dealloc %[[VAL_24]] : memref<300xindex>
+// CHECK: call @endInsert(%[[VAL_19]]) : (!llvm.ptr) -> ()
+// CHECK: return %[[VAL_19]] : !llvm.ptr
+// CHECK: }
+func.func @fill_zero_after_alloc(%arg0: tensor<100x200xf64, #DCSR>,
+                                 %arg1: tensor<200x300xf64, #DCSR>) -> tensor<100x300xf64, #DCSR> {
+  %0 = bufferization.alloc_tensor() : tensor<100x300xf64, #DCSR>
   %cst = arith.constant 0.000000e+00 : f64
   %1 = linalg.fill ins(%cst : f64)
-    outs(%0 : tensor<100x100xf64, #DCSR>) -> tensor<100x100xf64, #DCSR>
-  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<100x100xf64, #DCSR>, tensor<100x100xf64, #DCSR>)
-    outs(%1 : tensor<100x100xf64, #DCSR>) -> tensor<100x100xf64, #DCSR>
-  return %2 : tensor<100x100xf64, #DCSR>
+    outs(%0 : tensor<100x300xf64, #DCSR>) -> tensor<100x300xf64, #DCSR>
+  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<100x200xf64, #DCSR>, tensor<200x300xf64, #DCSR>)
+    outs(%1 : tensor<100x300xf64, #DCSR>) -> tensor<100x300xf64, #DCSR>
+  return %2 : tensor<100x300xf64, #DCSR>
 }
-- 
2.7.4
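Note (illustrative usage, not part of the original patch): the updated expectations can
be exercised with the RUN line embedded in each test file, for example for the last test
above:

  mlir-opt sparse_fill_zero.mlir --linalg-generalize-named-ops --sparsification \
    --sparse-tensor-conversion --canonicalize --cse | FileCheck sparse_fill_zero.mlir

which is simply the "// RUN: ... | FileCheck %s" line at the top of that file with the
%s substitution spelled out.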