From 8a583bd53dcb723b7a2b5e950e9d78da31d0e6cc Mon Sep 17 00:00:00 2001
From: bixia1
Date: Wed, 7 Sep 2022 14:34:04 -0700
Subject: [PATCH] [mlir][sparse] Add codegen for expand op.
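
For illustration, with a CSR source whose innermost stored dimension has
static size 8, the op

  %values, %filled, %added, %count = sparse_tensor.expand %0
    : tensor<4x8xf64, #CSR> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index

now codegens to direct buffer allocations and linalg.fill resets along the
lines below (a simplified sketch assembled from the codegen.mlir checks in
this patch; the %f0 and %false constant names are illustrative, and the
memref.cast ops back to the dynamic result types are omitted):

  %values = memref.alloc() : memref<8xf64>
  %filled = memref.alloc() : memref<8xi1>
  %added  = memref.alloc() : memref<8xindex>
  %count  = arith.constant 0 : index
  linalg.fill ins(%f0 : f64) outs(%values : memref<8xf64>)
  linalg.fill ins(%false : i1) outs(%filled : memref<8xi1>)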
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133454
---
 .../mlir/Dialect/SparseTensor/Transforms/Passes.td |  1 +
 .../Transforms/SparseTensorCodegen.cpp             | 62 ++++++++++++++++++++--
 .../SparseTensor/Transforms/SparseTensorPasses.cpp |  4 +-
 mlir/test/Dialect/SparseTensor/codegen.mlir        | 50 +++++++++++++++++
 4 files changed, 112 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index f7f4a39..8d2e069 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -169,6 +169,7 @@ def SparseTensorCodegen : Pass<"sparse-tensor-codegen", "ModuleOp"> {
   let dependentDialects = [
     "arith::ArithmeticDialect",
     "bufferization::BufferizationDialect",
+    "linalg::LinalgDialect",
     "memref::MemRefDialect",
     "scf::SCFDialect",
     "sparse_tensor::SparseTensorDialect",
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 5eca3c8..9ad37bf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -474,6 +475,58 @@ public:
   }
 };
 
+/// Sparse codegen rule for the expand op.
+class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
+    Type eltType = srcType.getElementType();
+    Type boolType = rewriter.getIntegerType(1);
+    Type idxType = rewriter.getIndexType();
+    // All initialization should be done on entry of the loop nest.
+    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
+    // Determine the size for access expansion (always the innermost stored
+    // dimension size, translated back to original dimension). Note that we
+    // recursively rewrite the new DimOp on the **original** tensor.
+    auto enc = getSparseTensorEncoding(srcType);
+    unsigned innerDim = srcType.getRank() - 1;
+    if (AffineMap p = enc.getDimOrdering())
+      innerDim = p.getDimPosition(innerDim);
+    Value sz = rewriter.create<tensor::DimOp>(loc, op.getTensor(), innerDim);
+    // Generate a memref for `sz` elements of type `t`.
+    auto genAlloc = [&](Type t) {
+      auto memTp = MemRefType::get({ShapedType::kDynamicSize}, t);
+      return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
+    };
+    // Allocate temporary buffers for values, filled-switch, and indices.
+    // We do not use stack buffers for this, since the expanded size may
+    // be rather large (as it envelops a single expanded dense dimension).
+    Value values = genAlloc(eltType);
+    Value filled = genAlloc(boolType);
+    Value indices = genAlloc(idxType);
+    Value zero = constantZero(rewriter, loc, idxType);
+    // Reset the values/filled-switch to all-zero/false. Note that this
+    // introduces an O(N) operation into the computation, but this reset
+    // operation is amortized over the innermost loops for the access
+    // pattern expansion. As noted in the operation doc, we would like
+    // to amortize this setup cost even between kernels.
+    rewriter.create<linalg::FillOp>(
+        loc, ValueRange{constantZero(rewriter, loc, eltType)},
+        ValueRange{values});
+    rewriter.create<linalg::FillOp>(
+        loc, ValueRange{constantZero(rewriter, loc, boolType)},
+        ValueRange{filled});
+    // Replace expansion op with these buffers and initial index.
+    assert(op.getNumResults() == 4);
+    rewriter.replaceOp(op, {values, filled, indices, zero});
+    return success();
+  }
+};
+
 /// Sparse codegen rule for pointer accesses.
 class SparseToPointersConverter
     : public SparseGetterOpConverter<ToPointersOp, SparseToPointersConverter> {
@@ -533,8 +586,9 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
   patterns.add<SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
-               SparseCastConverter, SparseTensorAllocConverter,
-               SparseTensorDeallocConverter, SparseToPointersConverter,
-               SparseToIndicesConverter, SparseToValuesConverter,
-               SparseTensorLoadConverter>(typeConverter, patterns.getContext());
+               SparseCastConverter, SparseExpandConverter,
+               SparseTensorAllocConverter, SparseTensorDeallocConverter,
+               SparseToPointersConverter, SparseToIndicesConverter,
+               SparseToValuesConverter, SparseTensorLoadConverter>(
+      typeConverter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index fee4222..ebb6993 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -175,7 +175,9 @@ struct SparseTensorCodegenPass
         [&](bufferization::DeallocTensorOp op) {
           return converter.isLegal(op.getTensor().getType());
         });
-    // Legal dialects may occur in generated code.
+    // The following operations and dialects may be introduced by the
+    // codegen rules, and are therefore marked as legal.
+    target.addLegalOp<linalg::FillOp>();
     target.addLegalDialect<arith::ArithmeticDialect,
                            bufferization::BufferizationDialect,
                            memref::MemRefDialect, scf::SCFDialect>();
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 667a5e1..a2bd754 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -286,3 +286,53 @@ func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> {
   %1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D>
   return %1 : tensor<10x20x30xf64, #Dense3D>
 }
+
+// CHECK-LABEL: func.func @sparse_expansion1()
+// CHECK: %[[A:.*]] = memref.alloc() : memref<8xf64>
+// CHECK: %[[B:.*]] = memref.alloc() : memref<8xi1>
+// CHECK: %[[C:.*]] = memref.alloc() : memref<8xindex>
+// CHECK: %[[D:.*]] = memref.cast %[[C]] : memref<8xindex> to memref<?xindex>
+// CHECK-DAG: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<8xf64>)
+// CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<8xi1>)
+// CHECK: return %[[D]] : memref<?xindex>
+func.func @sparse_expansion1() -> memref<?xindex> {
+  %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSR>
+  %values, %filled, %added, %count = sparse_tensor.expand %0
+    : tensor<4x8xf64, #CSR> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+  return %added : memref<?xindex>
+}
+
+// CHECK-LABEL: func.func @sparse_expansion2()
+// CHECK: %[[A:.*]] = memref.alloc() : memref<4xf64>
+// CHECK: %[[B:.*]] = memref.alloc() : memref<4xi1>
+// CHECK: %[[C:.*]] = memref.alloc() : memref<4xindex>
+// CHECK: %[[D:.*]] = memref.cast %[[C]] : memref<4xindex> to memref<?xindex>
+// CHECK-DAG: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<4xf64>)
+// CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<4xi1>)
+// CHECK: return %[[D]] : memref<?xindex>
+func.func @sparse_expansion2() -> memref<?xindex> {
+  %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSC>
+  %values, %filled, %added, %count = sparse_tensor.expand %0
+    : tensor<4x8xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+  return %added : memref<?xindex>
+}
+
+// CHECK-LABEL: func.func @sparse_expansion3(
+// CHECK-SAME: %[[D0:.*]]: index,
+// CHECK-SAME: %{{.*}}: index) -> memref<?xindex> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[S0:.*]] = memref.alloc() : memref<2xindex>
+// CHECK: memref.store %[[D0]], %[[S0]]{{\[}}%[[C1]]] : memref<2xindex>
+// CHECK: %[[D1:.*]] = memref.load %[[S0]]{{\[}}%[[C1]]] : memref<2xindex>
+// CHECK: %[[V:.*]] = memref.alloc(%[[D1]]) : memref<?xf64>
+// CHECK: %[[B:.*]] = memref.alloc(%[[D1]]) : memref<?xi1>
+// CHECK: %[[D:.*]] = memref.alloc(%[[D1]]) : memref<?xindex>
+// CHECK: linalg.fill ins(%{{.*}} : f64) outs(%[[V]] : memref<?xf64>)
+// CHECK: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<?xi1>)
+// CHECK: return %[[D]] : memref<?xindex>
+func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
+  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #CSC>
+  %values, %filled, %added, %count = sparse_tensor.expand %0
+    : tensor<?x?xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+  return %added : memref<?xindex>
+}
-- 
2.7.4
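
Usage sketch (illustrative, not part of the patch): sparse_tensor.expand is
meant to be paired with sparse_tensor.compress, which moves the collected
entries back into the sparse tensor and resets the buffers in a sparse
fashion. A hypothetical kernel fragment, with %A, %prefix, and %new_count
as assumed names and the compress operand order taken from the op
documentation of this era:

  %values, %filled, %added, %count = sparse_tensor.expand %A
    : tensor<4x8xf64, #CSR> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
  // ... the innermost loop scatters values, flips %filled switches, and
  // records new indices in %added, yielding the updated %new_count ...
  sparse_tensor.compress %A, %prefix, %values, %filled, %added, %new_count
    : tensor<4x8xf64, #CSR>, memref<?xindex>, memref<?xf64>,
      memref<?xi1>, memref<?xindex>, index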