From 652b39b46f85ad826a20d3e0cec5d0db91b43daf Mon Sep 17 00:00:00 2001
From: Aart Bik
Date: Wed, 23 Feb 2022 16:22:42 -0800
Subject: [PATCH] [mlir][sparse][linalg] add linalg rewriting specific to sparse tensors

Now that sparse tensor types are first-class citizens and the sparse
compiler is taking shape, it is time to make sure other compiler
optimizations compose well with sparse tensors. Mostly, this should be
completely transparent (i.e., dense and sparse take the same path).
However, in some cases, optimizations only make sense in the context of
sparse tensors. This is a first example of such an optimization, where
fusing a sampled elt-wise multiplication only makes sense when the
resulting kernel has a potentially lower asymptotic complexity due to
the sparsity.

As an extreme example, running SDDMM with 1024x1024 matrices and a
sparse sampling matrix with only two elements runs in 463.55ms in the
unfused case but just 0.032ms in the fused case, with a speedup of
14485x that is only possible in the exciting world of sparse
computations!

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D120429
---
 .../mlir/Dialect/Linalg/Transforms/Transforms.h    |   3 +
 mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Linalg/Transforms/ElementwiseOpFusion.cpp      |   3 +-
 .../Linalg/Transforms/SparseTensorRewriting.cpp    | 213 +++++++++++++++++++++
 .../SparseTensor/CPU/sparse_sampled_mm_fusion.mlir |  96 ++++++++--
 .../Dialect/SparseTensor/taco/test_SDDMM.py        |   5 +-
 6 files changed, 305 insertions(+), 16 deletions(-)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 80ec20ac..a8cd374 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -59,6 +59,9 @@ void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
 /// parallel loops.
 void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
 
+/// Populate patterns that are only useful in the context of sparse tensors.
+void populateSparseTensorRewriting(RewritePatternSet &patterns);
+
 /// Function type which is used to control when to stop fusion. It is expected
 /// that OpOperand is not modified in the callback. The OpOperand is not marked
 /// as const to allow callers to use non-const methods.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index f758546..ec8c8c4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   NamedOpConversions.cpp
   PadOpInterchange.cpp
   Promotion.cpp
+  SparseTensorRewriting.cpp
   Tiling.cpp
   Transforms.cpp
   Vectorization.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 7e0e857..3493f4e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -49,7 +49,7 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
   AffineMap invProducerResultIndexMap =
       inversePermutation(producerResultIndexMap);
   assert(invProducerResultIndexMap &&
-         "expected producer result indexig map to be invertible");
+         "expected producer result indexing map to be invertible");
 
   LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
   // argMap is a map from producer loop -> producer arg tensor index.
@@ -2264,6 +2264,7 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
                FoldConstantTranspose>(context,
                                       options.controlElementwiseOpsFusionFn);
   patterns.add<RemoveOutsDependency>(context);
+  populateSparseTensorRewriting(patterns);
   populateFoldReshapeOpsByExpansionPatterns(patterns,
                                             options.controlFoldingReshapesFn);
   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
new file mode 100644
index 0000000..3958ab3
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
@@ -0,0 +1,213 @@
+//===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements linalg dialect rewriting specific to sparse tensors.
+//
+// Sparsity should be mostly transparent to the linalg dialect optimizations
+// (i.e., dense and sparse take the same path). However, in some cases,
+// optimizations only make sense in the context of sparse tensors. This file
+// implements such sparsity specific rewriting rules.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::sparse_tensor;
+
+//===---------------------------------------------------------------------===//
+// Helper methods for the actual rewriting rules.
+//===---------------------------------------------------------------------===//
+
+// Helper to detect a sparse tensor type operand.
+static bool isSparseTensor(OpOperand *op) {
+  if (auto enc = getSparseTensorEncoding(op->get().getType())) {
+    ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes =
+        enc.getDimLevelType();
+    for (unsigned i = 0, e = dimTypes.size(); i < e; i++)
+      if (dimTypes[i] == SparseTensorEncodingAttr::DimLevelType::Compressed)
+        return true; // at least one compressed
+  }
+  return false;
+}
+
+// Helper method to find zero or empty initialization.
+static bool isEmptyInit(OpOperand *op) {
+  Value val = op->get();
+  if (matchPattern(val, m_Zero()))
+    return true;
+  if (matchPattern(val, m_AnyZeroFloat()))
+    return true;
+  if (val.getDefiningOp<InitTensorOp>())
+    return true;
+  if (val.getDefiningOp<InitOp>())
+    return true;
+  return false;
+}
+
+// Helper to detect sampling operation.
+static bool isSampling(GenericOp op) {
+  auto yieldOp = cast<linalg::YieldOp>(op.region().front().getTerminator());
+  if (auto def = yieldOp.getOperand(0).getDefiningOp()) {
+    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
+      // Both scalar input arguments used exactly once.
+      Value s1 = op.getBlock()->getArgument(0);
+      Value s2 = op.getBlock()->getArgument(1);
+      return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
+             (def->getOperand(1) == s1 && def->getOperand(0) == s2);
+    }
+  }
+  return false;
+}
+
+// Helper to detect chain of multiplications that do not involve x.
+static bool isMulChain(Value val, Value x) {
+  if (auto arg = val.dyn_cast<BlockArgument>())
+    return arg != x;
+  if (auto def = val.getDefiningOp()) {
+    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
+      return isMulChain(def->getOperand(0), x) &&
+             isMulChain(def->getOperand(1), x);
+  }
+  return false;
+}
+
+// Helper to detect x = x + <multiplication chain>.
+static bool isSumOfMul(GenericOp op) {
+  auto yieldOp = cast<linalg::YieldOp>(op.region().front().getTerminator());
+  if (auto def = yieldOp.getOperand(0).getDefiningOp()) {
+    if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
+      Value x = op.getBlock()->getArguments().back();
+      return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
+             (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
+    }
+  }
+  return false;
+}
+
+//===---------------------------------------------------------------------===//
+// The actual sparse tensor rewriting rules.
+//===---------------------------------------------------------------------===//
+
+namespace {
+/// Rewriting rule that converts two kernels:
+///
+///      T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
+///      X(i,j) = S(i,j) * T(i,j)
+///
+/// into a single kernel, using distributive law:
+///
+///      X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
+///
+/// This kind of fusion (merging two ops into one but using arithmetic
+/// equalities that may not hold for floating-point computations) would
+/// be undesirable in the dense case, since we distribute the multiplication
+/// into the reduction loop. However, for sparse sampling tensor S, such
+/// a fusion may actually reduce the asymptotic complexity of the kernel,
+/// since intermediate results may be nullified.
+struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    // Check consumer.
+    if (!op.hasTensorSemantics() || op.getNumInputs() != 2 ||
+        op.getNumResults() != 1)
+      return failure();
+    if (op.getNumParallelLoops() != op.getNumLoops())
+      return failure();
+    if (!op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
+        !op.getTiedIndexingMap(op.getInputOperand(0)).isIdentity() ||
+        !op.getTiedIndexingMap(op.getInputOperand(1)).isIdentity())
+      return failure();
+    // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
+    // operand can be sparse or dense, since the point of this rewriting rule
+    // is detecting a situation in which *more* sparsity is introduced into
+    // a computation, be it already sparse or still dense.
+    unsigned other = 0;
+    if (isSparseTensor(op.getInputOperand(0)))
+      other = 1;
+    else if (!isSparseTensor(op.getInputOperand(1)))
+      return failure();
+    // Check producer.
+    auto prod = dyn_cast_or_null<GenericOp>(
+        op.getInputOperand(other)->get().getDefiningOp());
+    if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1)
+      return failure();
+    if (!prod.getResult(0).hasOneUse())
+      return failure();
+    // Sampling consumer and sum of multiplication chain producer.
+    if (!isEmptyInit(op.getOutputOperand(0)) ||
+        !isEmptyInit(prod.getOutputOperand(0)))
+      return failure();
+    if (!isSampling(op) || !isSumOfMul(prod))
+      return failure();
+    // Modify operand structure of producer and consumer.
+    Location loc = prod.getLoc();
+    SmallVector<Value> inputOps = prod.getInputOperands();
+    SmallVector<Value> outputOps = op.getOutputOperands();
+    SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMaps();
+    inputOps.push_back(op.getInputOperand(1 - other)->get());
+    fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
+    // Fuse producer and consumer into a new generic op.
+    auto fusedOp = rewriter.create<GenericOp>(
+        loc, op.getResult(0).getType(), inputOps, outputOps,
+        rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.iterator_types(),
+        /*doc=*/nullptr, /*library_call=*/nullptr);
+    Block &prodBlock = prod.region().front();
+    Block &consBlock = op.region().front();
+    BlockAndValueMapping mapper;
+    Block *fusedBlock = new Block();
+    fusedOp.region().push_back(fusedBlock);
+    unsigned num = prodBlock.getNumArguments();
+    for (unsigned i = 0; i < num - 1; i++)
+      addArg(mapper, fusedBlock, prodBlock.getArgument(i));
+    addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
+    addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
+    // Clone bodies of the producer and consumer in new evaluation order.
+    auto acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
+    auto sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
+    rewriter.setInsertionPointToStart(fusedBlock);
+    Value last;
+    for (auto &op : prodBlock.without_terminator())
+      if (&op != acc) {
+        last = op.getResult(0);
+        rewriter.clone(op, mapper);
+      }
+    mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
+    mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
+    last = rewriter.clone(*acc, mapper)->getResult(0);
+    rewriter.create<linalg::YieldOp>(loc, last);
+    // Replace consumer with fused operation. Old producer
+    // and consumer ops will be removed by DCE.
+    rewriter.replaceOp(op, fusedOp->getResults());
+    return success();
+  }
+
+private:
+  // Helper to add argument and record the mapping.
+  static void addArg(BlockAndValueMapping &mapper, Block *b, BlockArgument a) {
+    mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
+  }
+};
+
+} // namespace
+
+//===---------------------------------------------------------------------===//
+// Methods that add patterns described in this file to a pattern list.
+//===---------------------------------------------------------------------===//
+
+void mlir::linalg::populateSparseTensorRewriting(RewritePatternSet &patterns) {
+  auto *context = patterns.getContext();
+  patterns.add<FuseSparseMultiplyOverAdd>(context);
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
index 0179503..0cbdd7a 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
@@ -5,7 +5,7 @@
 //
 // Do the same run, but now with SIMDization as well. This should not change the outcome.
 //
-// RUN: mlir-opt %s -sparse-compiler="vectorization-strategy=2 vl=8" | \
+// RUN: mlir-opt %s --sparse-compiler="vectorization-strategy=2 vl=8" | \
 // RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
@@ -46,7 +46,8 @@
 //
 module {
   //
-  // A kernel that computes a direct sampled matrix matrix multiplication.
+  // A kernel that computes a direct sampled matrix matrix multiplication
+  // (with dense result).
   //
   func @sampled_dd(%args: tensor<8x8xf64, #SM>,
                    %arga: tensor<8x8xf64>,
@@ -66,11 +67,13 @@ module {
   }
 
   //
-  // A kernel that computes an unfused sampled matrix matrix multiplication.
+  // A kernel that computes an unfused sampled matrix matrix multiplication
+  // (with dense result).
   //
   func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
                            %arga: tensor<8x8xf64>,
-                           %argb: tensor<8x8xf64>) -> (tensor<8x8xf64>, tensor<8x8xf64>) {
+                           %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {
+    // Perform dense-dense matrix matrix multiplication.
     %1 = arith.constant dense<0.0> : tensor<8x8xf64>
     %2 = linalg.generic #trait_matmul
       ins(%arga, %argb : tensor<8x8xf64>, tensor<8x8xf64>)
       outs(%1 : tensor<8x8xf64>) {
@@ -80,17 +83,68 @@ module {
         %q = arith.addf %x, %p : f64
         linalg.yield %q : f64
     } -> tensor<8x8xf64>
-
-    %3 = arith.constant dense<0.0> : tensor<8x8xf64>
-    %4 = linalg.generic #trait_scale
+    // Sample the result with element-wise multiplication with sparse matrix.
+    %3 = linalg.generic #trait_scale
       ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
-      outs(%3 : tensor<8x8xf64>) {
+      outs(%1 : tensor<8x8xf64>) {
        ^bb0(%t: f64, %s: f64, %x: f64):
          %r = arith.mulf %t, %s : f64
          linalg.yield %r : f64
     } -> tensor<8x8xf64>
+    return %3 : tensor<8x8xf64>
+  }
-
-    return %4, %2 : tensor<8x8xf64>, tensor<8x8xf64>
+  //
+  // A kernel that computes a direct sampled matrix matrix multiplication
+  // (with sparse result).
+  //
+  func @sparse_sampled_dd(%args: tensor<8x8xf64, #SM>,
+                          %arga: tensor<8x8xf64>,
+                          %argb: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
+    %c8 = arith.constant 8 : index
+    %1 = sparse_tensor.init [%c8, %c8] : tensor<8x8xf64, #SM>
+    %2 = linalg.generic #trait_sampled_dense_dense
+      ins(%args, %arga, %argb: tensor<8x8xf64, #SM>,
+                               tensor<8x8xf64>, tensor<8x8xf64>)
+      outs(%1: tensor<8x8xf64, #SM>) {
+        ^bb(%s: f64, %a: f64, %b: f64, %x: f64):
+          %p = arith.mulf %a, %b : f64
+          %q = arith.mulf %s, %p : f64
+          %r = arith.addf %x, %q : f64
+          linalg.yield %r : f64
+    } -> tensor<8x8xf64, #SM>
+    return %2 : tensor<8x8xf64, #SM>
+  }
+
+  //
+  // A kernel that computes an unfused sampled matrix matrix multiplication
+  // (with sparse result).
+  //
+  func @sparse_sampled_dd_unfused(
+      %args: tensor<8x8xf64, #SM>,
+      %arga: tensor<8x8xf64>,
+      %argb: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
+    // Perform dense-dense matrix matrix multiplication.
+    %1 = arith.constant dense<0.0> : tensor<8x8xf64>
+    %2 = linalg.generic #trait_matmul
+      ins(%arga, %argb : tensor<8x8xf64>, tensor<8x8xf64>)
+      outs(%1 : tensor<8x8xf64>) {
+        ^bb0(%a: f64, %b: f64, %x: f64):
+          %p = arith.mulf %a, %b : f64
+          %q = arith.addf %x, %p : f64
+          linalg.yield %q : f64
+    } -> tensor<8x8xf64>
+    // Sample the result with element-wise multiplication with sparse matrix.
+    %c8 = arith.constant 8 : index
+    %3 = sparse_tensor.init [%c8, %c8] : tensor<8x8xf64, #SM>
+    %4 = linalg.generic #trait_scale
+      ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
+      outs(%3 : tensor<8x8xf64, #SM>) {
+        ^bb0(%t: f64, %s: f64, %x: f64):
+          %r = arith.mulf %t, %s : f64
+          linalg.yield %r : f64
+    } -> tensor<8x8xf64, #SM>
+    return %4 : tensor<8x8xf64, #SM>
   }
   //
   //
@@ -112,9 +166,15 @@ module {
     %0 = call @sampled_dd(%s, %a, %b)
       : (tensor<8x8xf64, #SM>,
         tensor<8x8xf64>, tensor<8x8xf64>) -> tensor<8x8xf64>
-    %1, %2 = call @sampled_dd_unfused(%s, %a, %b)
+    %1 = call @sampled_dd_unfused(%s, %a, %b)
       : (tensor<8x8xf64, #SM>,
-        tensor<8x8xf64>, tensor<8x8xf64>) -> (tensor<8x8xf64>, tensor<8x8xf64>)
+        tensor<8x8xf64>, tensor<8x8xf64>) -> tensor<8x8xf64>
+    %2 = call @sparse_sampled_dd(%s, %a, %b)
+      : (tensor<8x8xf64, #SM>,
+        tensor<8x8xf64>, tensor<8x8xf64>) -> tensor<8x8xf64, #SM>
+    %3 = call @sparse_sampled_dd_unfused(%s, %a, %b)
+      : (tensor<8x8xf64, #SM>,
+        tensor<8x8xf64>, tensor<8x8xf64>) -> tensor<8x8xf64, #SM>
 
     // Verify the outputs.
     //
@@ -128,21 +188,31 @@ module {
     // CHECK-SAME: ( 0, 0, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 0 ),
    // CHECK-SAME: ( 0, 0, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 192 ) )
     //
+    // CHECK-NEXT: ( 96, 192, 0, 0 )
+    //
+    // CHECK-NEXT: ( 96, 192, 0, 0 )
+    //
     %m0 = bufferization.to_memref %0 : memref<8x8xf64>
     %m1 = bufferization.to_memref %1 : memref<8x8xf64>
-    %m2 = bufferization.to_memref %2 : memref<8x8xf64>
+    %m2 = sparse_tensor.values %2 : tensor<8x8xf64, #SM> to memref<?xf64>
+    %m3 = sparse_tensor.values %3 : tensor<8x8xf64, #SM> to memref<?xf64>
     %v0 = vector.transfer_read %m0[%c0, %c0], %d0 : memref<8x8xf64>, vector<8x8xf64>
     %v1 = vector.transfer_read %m1[%c0, %c0], %d0 : memref<8x8xf64>, vector<8x8xf64>
+    %v2 = vector.transfer_read %m2[%c0], %d0 : memref<?xf64>, vector<4xf64>
+    %v3 = vector.transfer_read %m3[%c0], %d0 : memref<?xf64>, vector<4xf64>
     vector.print %v0 : vector<8x8xf64>
     vector.print %v1 : vector<8x8xf64>
+    vector.print %v2 : vector<4xf64>
+    vector.print %v3 : vector<4xf64>
 
     // Release the resources.
     sparse_tensor.release %s : tensor<8x8xf64, #SM>
     memref.dealloc %m0 : memref<8x8xf64>
     memref.dealloc %m1 : memref<8x8xf64>
-    memref.dealloc %m2 : memref<8x8xf64>
+    sparse_tensor.release %2 : tensor<8x8xf64, #SM>
+    sparse_tensor.release %3 : tensor<8x8xf64, #SM>
 
     return
   }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py
index 876e6bd..9f017ad1 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py
@@ -33,8 +33,9 @@ X[i, j] = S[i, j] * A[i, k] * B[k, j]
 
 # Alternative way to define SDDMM kernel. Since this performs the reduction as
 # sum(k, A[i, k] * B[k, j]) * S[i, j]
-# the MLIR lowering results in two separate tensor index expressions that
-# need to be fused properly to guarantee proper asymptotic complexity.
+# the MLIR lowering results in two separate tensor index expressions that are
+# fused prior to running the sparse compiler in order to guarantee proper
+# asymptotic complexity.
 Y[i, j] = A[i, k] * B[k, j] * S[i, j]
 
 expected = """; extended FROSTT format
-- 
2.7.4
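
Usage note (illustrative sketch, not part of the patch): downstream users normally
pick up the new FuseSparseMultiplyOverAdd pattern automatically through
-linalg-fuse-elementwise-ops, since populateElementwiseOpsFusionPatterns now calls
populateSparseTensorRewriting (see the ElementwiseOpFusion.cpp hunk above). A
minimal, hypothetical standalone pass that pulls in only the sparse-specific
patterns could look as follows; the pass name SparseFusionOnlyPass is an
assumption for illustration, not an API added by this patch.

  // Hypothetical driver: applies only the sparse-specific rewriting patterns
  // declared in Linalg/Transforms/Transforms.h to each function.
  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/Pass/Pass.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  namespace {
  struct SparseFusionOnlyPass
      : public mlir::PassWrapper<SparseFusionOnlyPass,
                                 mlir::OperationPass<mlir::FuncOp>> {
    void runOnOperation() override {
      mlir::RewritePatternSet patterns(&getContext());
      // Currently this adds just FuseSparseMultiplyOverAdd.
      mlir::linalg::populateSparseTensorRewriting(patterns);
      if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
                                                          std::move(patterns))))
        signalPassFailure();
    }
  };
  } // namespace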