From 86b22d312053f38c7ea94af49dd0e93c660ffec8 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Fri, 26 Aug 2022 13:49:07 -0700 Subject: [PATCH] [mlir][sparse] start a sparse codegen conversion pass This new pass provides an alternative to the current conversion pass that converts sparse tensor types and sparse primitives to opaque pointers and calls into a runtime support library. This pass will map sparse tensor types to actual data structures and primitives to actual code. In the long run, this new pass will remove our dependence on the support library, avoid the need to link in fully templated and expanded code, and provide much better opportunities for optimization on the generated code. Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D132766 --- .../mlir/Dialect/SparseTensor/Transforms/Passes.h | 28 ++++++-- .../mlir/Dialect/SparseTensor/Transforms/Passes.td | 40 ++++++++--- .../Dialect/SparseTensor/Transforms/CMakeLists.txt | 1 + .../Transforms/SparseTensorCodegen.cpp | 82 ++++++++++++++++++++++ .../Transforms/SparseTensorConversion.cpp | 29 ++++++-- .../SparseTensor/Transforms/SparseTensorPasses.cpp | 65 +++++++++++++---- mlir/test/Dialect/SparseTensor/codegen.mlir | 14 ++++ mlir/test/Dialect/SparseTensor/conversion.mlir | 7 ++ 8 files changed, 233 insertions(+), 33 deletions(-) create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp create mode 100644 mlir/test/Dialect/SparseTensor/codegen.mlir diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h index 71afddc..2d4bdb3 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -11,8 +11,6 @@ // In general, this file takes the approach of keeping "mechanism" (the // actual steps of applying a transformation) completely separate from // "policy" (heuristics for when and where to apply transformations). -// The only exception is in `SparseToSparseConversionStrategy`; for which, -// see further discussion there. // //===----------------------------------------------------------------------===// @@ -21,15 +19,13 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace bufferization { struct OneShotBufferizationOptions; } // namespace bufferization -// Forward. -class TypeConverter; - //===----------------------------------------------------------------------===// // The Sparsification pass. //===----------------------------------------------------------------------===// @@ -95,6 +91,12 @@ createSparsificationPass(const SparsificationOptions &options); // The SparseTensorConversion pass. //===----------------------------------------------------------------------===// +/// Sparse tensor type converter into an opaque pointer. +class SparseTensorTypeToPtrConverter : public TypeConverter { +public: + SparseTensorTypeToPtrConverter(); +}; + /// Defines a strategy for implementing sparse-to-sparse conversion. /// `kAuto` leaves it up to the compiler to automatically determine /// the method used. `kViaCOO` converts the source tensor to COO and @@ -139,6 +141,22 @@ std::unique_ptr createSparseTensorConversionPass(const SparseTensorConversionOptions &options); //===----------------------------------------------------------------------===// +// The SparseTensorCodegen pass. +//===----------------------------------------------------------------------===// + +/// Sparse tensor type converter into an actual buffer. +class SparseTensorTypeToBufferConverter : public TypeConverter { +public: + SparseTensorTypeToBufferConverter(); +}; + +/// Sets up sparse tensor conversion rules. +void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns); + +std::unique_ptr createSparseTensorCodegenPass(); + +//===----------------------------------------------------------------------===// // Other rewriting rules and passes. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td index 6e36259..4ca224b 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -77,16 +77,16 @@ def Sparsification : Pass<"sparsification", "ModuleOp"> { } def SparseTensorConversion : Pass<"sparse-tensor-conversion", "ModuleOp"> { - let summary = "Apply conversion rules to sparse tensor primitives and types"; + let summary = "Convert sparse tensors and primitives to library calls"; let description = [{ - A pass that converts sparse tensor primitives to calls into a runtime - support library. All sparse tensor types are converted into opaque - pointers to the underlying sparse storage schemes. + A pass that converts sparse tensor primitives into calls into a runtime + support library. Sparse tensor types are converted into opaque pointers + to the underlying sparse storage schemes. - Note that this is a current implementation choice to keep the conversion - relatively simple. In principle, these primitives could also be - converted to actual elaborate IR code that implements the primitives - on the selected sparse tensor storage schemes. + The use of opaque pointers together with runtime support library keeps + the conversion relatively simple, but at the expense of IR opacity, + which obscures opportunities for subsequent optimization of the IR. + An alternative is provided by the SparseTensorCodegen pass. Example of the conversion: @@ -122,4 +122,28 @@ def SparseTensorConversion : Pass<"sparse-tensor-conversion", "ModuleOp"> { ]; } +def SparseTensorCodegen : Pass<"sparse-tensor-codegen", "ModuleOp"> { + let summary = "Convert sparse tensors and primitives to actual code"; + let description = [{ + A pass that converts sparse tensor types and primitives to actual + compiler visible buffers and compiler IR that implements these + primitives on the selected sparse tensor storage schemes. + + This pass provides an alternative to the SparseTensorConversion pass, + eliminating the dependence on a runtime support library, and providing + much more opportunities for subsequent compiler optimization of the + generated code. + + Example of the conversion: + + ```mlir + TBD + ``` + }]; + let constructor = "mlir::createSparseTensorCodegenPass()"; + let dependentDialects = [ + "sparse_tensor::SparseTensorDialect", + ]; +} + #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt index 9d99d2f..640ee67 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms CodegenUtils.cpp DenseBufferizationPass.cpp Sparsification.cpp + SparseTensorCodegen.cpp SparseTensorConversion.cpp SparseTensorPasses.cpp SparseTensorRewriting.cpp diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp new file mode 100644 index 0000000..8666926 --- /dev/null +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -0,0 +1,82 @@ +//===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// A pass that converts sparse tensor types and primitives to actual compiler +// visible buffers and actual compiler IR that implements these primitives on +// the selected sparse tensor storage schemes. This pass provides an alternative +// to the SparseTensorConversion pass, eliminating the dependence on a runtime +// support library, and providing much more opportunities for subsequent +// compiler optimization of the generated code. +// +//===----------------------------------------------------------------------===// + +#include "CodegenUtils.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace mlir::sparse_tensor; + +namespace { + +//===----------------------------------------------------------------------===// +// Helper methods. +//===----------------------------------------------------------------------===// + +/// Maps each sparse tensor type to the appropriate buffer. +static Optional convertSparseTensorTypes(Type type) { + if (getSparseTensorEncoding(type) != nullptr) { + // TODO: this is just a dummy rule to get the ball rolling.... + RankedTensorType rTp = type.cast(); + return MemRefType::get({ShapedType::kDynamicSize}, rTp.getElementType()); + } + return llvm::None; +} + +//===----------------------------------------------------------------------===// +// Conversion rules. +//===----------------------------------------------------------------------===// + +/// Sparse conversion rule for returns. +class SparseReturnConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Sparse tensor type conversion into an actual buffer. +//===----------------------------------------------------------------------===// + +mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { + addConversion([](Type type) { return type; }); + addConversion(convertSparseTensorTypes); +} + +//===----------------------------------------------------------------------===// +// Public method for populating conversion rules. +//===----------------------------------------------------------------------===// + +/// Populates the given patterns list with conversion rules required for +/// the sparsification of linear algebra operations. +void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add(typeConverter, patterns.getContext()); +} diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index 7161015..5bd10e8 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -6,11 +6,13 @@ // //===----------------------------------------------------------------------===// // -// Convert sparse tensor primitives to calls into a runtime support library. -// Note that this is a current implementation choice to keep the conversion -// simple. In principle, these primitives could also be converted to actual -// elaborate IR code that implements the primitives on the selected sparse -// tensor storage schemes. +// A pass that converts sparse tensor primitives into calls into a runtime +// support library. Sparse tensor types are converted into opaque pointers +// to the underlying sparse storage schemes. The use of opaque pointers +// together with runtime support library keeps the conversion relatively +// simple, but at the expense of IR opacity, which obscures opportunities +// for subsequent optimization of the IR. An alternative is provided by +// the SparseTensorCodegen pass. // //===----------------------------------------------------------------------===// @@ -48,6 +50,13 @@ static Type getOpaquePointerType(OpBuilder &builder) { return LLVM::LLVMPointerType::get(builder.getI8Type()); } +/// Maps each sparse tensor type to an opaque pointer. +static Optional convertSparseTensorTypes(Type type) { + if (getSparseTensorEncoding(type) != nullptr) + return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8)); + return llvm::None; +} + /// Returns a function reference (first hit also inserts into module). Sets /// the "_emit_c_interface" on the function declaration when requested, /// so that LLVM lowering generates a wrapper function that takes care @@ -1345,6 +1354,7 @@ public: return success(); } }; + /// Sparse conversion rule for the output operator. class SparseTensorOutConverter : public OpConversionPattern { public: @@ -1388,6 +1398,15 @@ public: } // namespace //===----------------------------------------------------------------------===// +// Sparse tensor type conversion into opaque pointer. +//===----------------------------------------------------------------------===// + +mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() { + addConversion([](Type type) { return type; }); + addConversion(convertSparseTensorTypes); +} + +//===----------------------------------------------------------------------===// // Public method for populating conversion rules. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp index 2014781..643cff9 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -67,20 +67,6 @@ struct SparsificationPass : public SparsificationBase { } }; -class SparseTensorTypeConverter : public TypeConverter { -public: - SparseTensorTypeConverter() { - addConversion([](Type type) { return type; }); - addConversion(convertSparseTensorTypes); - } - // Maps each sparse tensor type to an opaque pointer. - static Optional convertSparseTensorTypes(Type type) { - if (getSparseTensorEncoding(type) != nullptr) - return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8)); - return llvm::None; - } -}; - struct SparseTensorConversionPass : public SparseTensorConversionBase { @@ -93,7 +79,7 @@ struct SparseTensorConversionPass void runOnOperation() override { auto *ctx = &getContext(); RewritePatternSet patterns(ctx); - SparseTensorTypeConverter converter; + SparseTensorTypeToPtrConverter converter; ConversionTarget target(*ctx); // Everything in the sparse dialect must go! target.addIllegalDialect(); @@ -158,8 +144,49 @@ struct SparseTensorConversionPass } }; +struct SparseTensorCodegenPass + : public SparseTensorCodegenBase { + + SparseTensorCodegenPass() = default; + SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default; + + void runOnOperation() override { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + SparseTensorTypeToBufferConverter converter; + ConversionTarget target(*ctx); + // Everything in the sparse dialect must go! + target.addIllegalDialect(); + // All dynamic rules below accept new function, call, return, and various + // tensor and bufferization operations as legal output of the rewriting. + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return converter.isSignatureLegal(op.getFunctionType()); + }); + target.addDynamicallyLegalOp([&](func::CallOp op) { + return converter.isSignatureLegal(op.getCalleeType()); + }); + target.addDynamicallyLegalOp([&](func::ReturnOp op) { + return converter.isLegal(op.getOperandTypes()); + }); + // Populate with rules and apply rewriting rules. + populateFunctionOpInterfaceTypeConversionPattern(patterns, + converter); + populateCallOpTypeConversionPattern(patterns, converter); + scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, + target); + populateSparseTensorCodegenPatterns(converter, patterns); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; + } // namespace +//===----------------------------------------------------------------------===// +// Strategy flag methods. +//===----------------------------------------------------------------------===// + SparseParallelizationStrategy mlir::sparseParallelizationStrategy(int32_t flag) { switch (flag) { @@ -199,6 +226,10 @@ mlir::sparseToSparseConversionStrategy(int32_t flag) { } } +//===----------------------------------------------------------------------===// +// Pass creation methods. +//===----------------------------------------------------------------------===// + std::unique_ptr mlir::createSparsificationPass() { return std::make_unique(); } @@ -216,3 +247,7 @@ std::unique_ptr mlir::createSparseTensorConversionPass( const SparseTensorConversionOptions &options) { return std::make_unique(options); } + +std::unique_ptr mlir::createSparseTensorCodegenPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir new file mode 100644 index 0000000..a3cecaf --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/codegen.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s + +#SparseVector = #sparse_tensor.encoding<{ + dimLevelType = ["compressed"] +}> + +// TODO: just a dummy memref rewriting to get the ball rolling.... + +// CHECK-LABEL: func @sparse_nop( +// CHECK-SAME: %[[A:.*]]: memref) -> memref { +// CHECK: return %[[A]] : memref +func.func @sparse_nop(%arg0: tensor) -> tensor { + return %arg0 : tensor +} diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir index 336d815..4de4021 100644 --- a/mlir/test/Dialect/SparseTensor/conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/conversion.mlir @@ -25,6 +25,13 @@ dimOrdering = affine_map<(i,j,k) -> (k,i,j)> }> +// CHECK-LABEL: func @sparse_nop( +// CHECK-SAME: %[[A:.*]]: !llvm.ptr) -> !llvm.ptr +// CHECK: return %[[A]] : !llvm.ptr +func.func @sparse_nop(%arg0: tensor) -> tensor { + return %arg0 : tensor +} + // CHECK-LABEL: func @sparse_dim1d( // CHECK-SAME: %[[A:.*]]: !llvm.ptr) // CHECK: %[[C:.*]] = arith.constant 0 : index -- 2.7.4