// In general, this file takes the approach of keeping "mechanism" (the
// actual steps of applying a transformation) completely separate from
// "policy" (heuristics for when and where to apply transformations).
-// The only exception is in `SparseToSparseConversionStrategy`; for which,
-// see further discussion there.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace bufferization {
struct OneShotBufferizationOptions;
} // namespace bufferization
-// Forward.
-class TypeConverter;
-
//===----------------------------------------------------------------------===//
// The Sparsification pass.
//===----------------------------------------------------------------------===//
// The SparseTensorConversion pass.
//===----------------------------------------------------------------------===//
+/// Sparse tensor type converter into an opaque pointer. Used by the
+/// SparseTensorConversion pass: each sparse tensor type is rewritten to an
+/// opaque pointer (!llvm.ptr<i8>) into the underlying runtime storage
+/// scheme, while all other types are left unchanged.
+class SparseTensorTypeToPtrConverter : public TypeConverter {
+public:
+ // Registers the identity fallback and the sparse-tensor-to-pointer rule.
+ SparseTensorTypeToPtrConverter();
+};
+
/// Defines a strategy for implementing sparse-to-sparse conversion.
/// `kAuto` leaves it up to the compiler to automatically determine
/// the method used. `kViaCOO` converts the source tensor to COO and
createSparseTensorConversionPass(const SparseTensorConversionOptions &options);
//===----------------------------------------------------------------------===//
+// The SparseTensorCodegen pass.
+//===----------------------------------------------------------------------===//
+
+/// Sparse tensor type converter into an actual buffer. Used by the
+/// SparseTensorCodegen pass: each sparse tensor type is rewritten to a
+/// compiler-visible buffer (memref), while all other types are left
+/// unchanged.
+class SparseTensorTypeToBufferConverter : public TypeConverter {
+public:
+ // Registers the identity fallback and the sparse-tensor-to-buffer rule.
+ SparseTensorTypeToBufferConverter();
+};
+
+/// Sets up sparse tensor conversion rules for the SparseTensorCodegen pass,
+/// i.e. rules that lower sparse tensor primitives to buffer-based code.
+void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
+ RewritePatternSet &patterns);
+
+/// Creates a pass instance that converts sparse tensor types and primitives
+/// to actual compiler-visible buffers and IR (see SparseTensorCodegen).
+std::unique_ptr<Pass> createSparseTensorCodegenPass();
+
+//===----------------------------------------------------------------------===//
// Other rewriting rules and passes.
//===----------------------------------------------------------------------===//
}
def SparseTensorConversion : Pass<"sparse-tensor-conversion", "ModuleOp"> {
- let summary = "Apply conversion rules to sparse tensor primitives and types";
+ let summary = "Convert sparse tensors and primitives to library calls";
let description = [{
- A pass that converts sparse tensor primitives to calls into a runtime
- support library. All sparse tensor types are converted into opaque
- pointers to the underlying sparse storage schemes.
+ A pass that converts sparse tensor primitives into calls into a runtime
+ support library. Sparse tensor types are converted into opaque pointers
+ to the underlying sparse storage schemes.
- Note that this is a current implementation choice to keep the conversion
- relatively simple. In principle, these primitives could also be
- converted to actual elaborate IR code that implements the primitives
- on the selected sparse tensor storage schemes.
+ The use of opaque pointers together with a runtime support library keeps
+ the conversion relatively simple, but at the expense of IR opacity,
+ which obscures opportunities for subsequent optimization of the IR.
+ An alternative is provided by the SparseTensorCodegen pass.
Example of the conversion:
];
}
+def SparseTensorCodegen : Pass<"sparse-tensor-codegen", "ModuleOp"> {
+ let summary = "Convert sparse tensors and primitives to actual code";
+ let description = [{
+ A pass that converts sparse tensor types and primitives to actual
+ compiler visible buffers and compiler IR that implements these
+ primitives on the selected sparse tensor storage schemes.
+
+ This pass provides an alternative to the SparseTensorConversion pass,
+ eliminating the dependence on a runtime support library, and providing
+ many more opportunities for subsequent compiler optimization of the
+ generated code.
+
+ Example of the conversion:
+
+ ```mlir
+ TBD
+ ```
+ }];
+ let constructor = "mlir::createSparseTensorCodegenPass()";
+ let dependentDialects = [
+ "sparse_tensor::SparseTensorDialect",
+ ];
+}
+
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
CodegenUtils.cpp
DenseBufferizationPass.cpp
Sparsification.cpp
+ SparseTensorCodegen.cpp
SparseTensorConversion.cpp
SparseTensorPasses.cpp
SparseTensorRewriting.cpp
--- /dev/null
+//===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// A pass that converts sparse tensor types and primitives to actual compiler
+// visible buffers and actual compiler IR that implements these primitives on
+// the selected sparse tensor storage schemes. This pass provides an alternative
+// to the SparseTensorConversion pass, eliminating the dependence on a runtime
+// support library, and providing many more opportunities for subsequent
+// compiler optimization of the generated code.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helper methods.
+//===----------------------------------------------------------------------===//
+
+/// Maps each sparse tensor type to the appropriate buffer.
+static Optional<Type> convertSparseTensorTypes(Type type) {
+ if (getSparseTensorEncoding(type) != nullptr) {
+ // TODO: this is just a dummy rule to get the ball rolling....
+ RankedTensorType rTp = type.cast<RankedTensorType>();
+ return MemRefType::get({ShapedType::kDynamicSize}, rTp.getElementType());
+ }
+ return llvm::None;
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion rules.
+//===----------------------------------------------------------------------===//
+
+/// Sparse conversion rule for returns.
+/// Sparse conversion rule for returns: rebuilds func.return with the
+/// type-converted operands supplied by the adaptor.
+class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // The adaptor operands were already rewritten by the type converter,
+ // so simply forward them into a fresh return operation.
+ rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
+ return success();
+ }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Sparse tensor type conversion into an actual buffer.
+//===----------------------------------------------------------------------===//
+
+/// Registers the conversion rules. NOTE(review): MLIR's TypeConverter is
+/// documented to try conversion callbacks in reverse registration order,
+/// so the sparse rule below takes precedence over the identity fallback.
+mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
+ // Fallback: keep all non-sparse types as-is.
+ addConversion([](Type type) { return type; });
+ // Map each sparse tensor type to an actual buffer.
+ addConversion(convertSparseTensorTypes);
+}
+
+//===----------------------------------------------------------------------===//
+// Public method for populating conversion rules.
+//===----------------------------------------------------------------------===//
+
+/// Populates the given patterns list with conversion rules required to
+/// lower sparse tensor types and primitives to actual buffer-based code.
+void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
+ RewritePatternSet &patterns) {
+ // For now only the return rewriting rule; more rules will follow.
+ patterns.add<SparseReturnConverter>(typeConverter, patterns.getContext());
+}
//
//===----------------------------------------------------------------------===//
//
-// Convert sparse tensor primitives to calls into a runtime support library.
-// Note that this is a current implementation choice to keep the conversion
-// simple. In principle, these primitives could also be converted to actual
-// elaborate IR code that implements the primitives on the selected sparse
-// tensor storage schemes.
+// A pass that converts sparse tensor primitives into calls into a runtime
+// support library. Sparse tensor types are converted into opaque pointers
+// to the underlying sparse storage schemes. The use of opaque pointers
+// together with a runtime support library keeps the conversion relatively
+// simple, but at the expense of IR opacity, which obscures opportunities
+// for subsequent optimization of the IR. An alternative is provided by
+// the SparseTensorCodegen pass.
//
//===----------------------------------------------------------------------===//
return LLVM::LLVMPointerType::get(builder.getI8Type());
}
+/// Maps each sparse tensor type to an opaque pointer (!llvm.ptr<i8>) into
+/// the underlying runtime storage scheme; all other types are left
+/// unchanged by returning llvm::None.
+static Optional<Type> convertSparseTensorTypes(Type type) {
+ if (getSparseTensorEncoding(type) != nullptr)
+ return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
+ return llvm::None;
+}
+
/// Returns a function reference (first hit also inserts into module). Sets
/// the "_emit_c_interface" on the function declaration when requested,
/// so that LLVM lowering generates a wrapper function that takes care
return success();
}
};
+
/// Sparse conversion rule for the output operator.
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
public:
} // namespace
//===----------------------------------------------------------------------===//
+// Sparse tensor type conversion into opaque pointer.
+//===----------------------------------------------------------------------===//
+
+/// Registers the conversion rules. NOTE(review): MLIR's TypeConverter is
+/// documented to try conversion callbacks in reverse registration order,
+/// so the sparse rule below takes precedence over the identity fallback.
+mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
+ // Fallback: keep all non-sparse types as-is.
+ addConversion([](Type type) { return type; });
+ // Map each sparse tensor type to an opaque pointer.
+ addConversion(convertSparseTensorTypes);
+}
+
+//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//
}
};
-class SparseTensorTypeConverter : public TypeConverter {
-public:
- SparseTensorTypeConverter() {
- addConversion([](Type type) { return type; });
- addConversion(convertSparseTensorTypes);
- }
- // Maps each sparse tensor type to an opaque pointer.
- static Optional<Type> convertSparseTensorTypes(Type type) {
- if (getSparseTensorEncoding(type) != nullptr)
- return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
- return llvm::None;
- }
-};
-
struct SparseTensorConversionPass
: public SparseTensorConversionBase<SparseTensorConversionPass> {
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
- SparseTensorTypeConverter converter;
+ SparseTensorTypeToPtrConverter converter;
ConversionTarget target(*ctx);
// Everything in the sparse dialect must go!
target.addIllegalDialect<SparseTensorDialect>();
}
};
+/// Pass driver that converts sparse tensor types and primitives to actual
+/// compiler-visible buffers and IR, using SparseTensorTypeToBufferConverter
+/// for the type rewriting.
+struct SparseTensorCodegenPass
+ : public SparseTensorCodegenBase<SparseTensorCodegenPass> {
+
+ SparseTensorCodegenPass() = default;
+ SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
+
+ void runOnOperation() override {
+ auto *ctx = &getContext();
+ RewritePatternSet patterns(ctx);
+ SparseTensorTypeToBufferConverter converter;
+ ConversionTarget target(*ctx);
+ // Everything in the sparse dialect must go!
+ target.addIllegalDialect<SparseTensorDialect>();
+ // All dynamic rules below accept new function, call, return, and various
+ // tensor and bufferization operations as legal output of the rewriting.
+ target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+ return converter.isSignatureLegal(op.getFunctionType());
+ });
+ target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
+ return converter.isSignatureLegal(op.getCalleeType());
+ });
+ target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
+ return converter.isLegal(op.getOperandTypes());
+ });
+ // Populate with rules and apply rewriting rules. Structural conversions
+ // (function signatures, calls, scf constructs) are added alongside the
+ // sparse-specific patterns so rewritten types propagate consistently.
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+ converter);
+ populateCallOpTypeConversionPattern(patterns, converter);
+ scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+ target);
+ populateSparseTensorCodegenPatterns(converter, patterns);
+ // Partial conversion: ops not covered by the target remain untouched.
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
} // namespace
+//===----------------------------------------------------------------------===//
+// Strategy flag methods.
+//===----------------------------------------------------------------------===//
+
SparseParallelizationStrategy
mlir::sparseParallelizationStrategy(int32_t flag) {
switch (flag) {
}
}
+//===----------------------------------------------------------------------===//
+// Pass creation methods.
+//===----------------------------------------------------------------------===//
+
std::unique_ptr<Pass> mlir::createSparsificationPass() {
return std::make_unique<SparsificationPass>();
}
const SparseTensorConversionOptions &options) {
return std::make_unique<SparseTensorConversionPass>(options);
}
+
+/// Factory method for the SparseTensorCodegen pass (registered through the
+/// constructor entry in Passes.td).
+std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
+ return std::make_unique<SparseTensorCodegenPass>();
+}
--- /dev/null
+// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{
+ dimLevelType = ["compressed"]
+}>
+
+// TODO: just a dummy memref rewriting to get the ball rolling....
+
+// CHECK-LABEL: func @sparse_nop(
+// CHECK-SAME: %[[A:.*]]: memref<?xf64>) -> memref<?xf64> {
+// CHECK: return %[[A]] : memref<?xf64>
+func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+ return %arg0 : tensor<?xf64, #SparseVector>
+}
dimOrdering = affine_map<(i,j,k) -> (k,i,j)>
}>
+// CHECK-LABEL: func @sparse_nop(
+// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+// CHECK: return %[[A]] : !llvm.ptr<i8>
+func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+ return %arg0 : tensor<?xf64, #SparseVector>
+}
+
// CHECK-LABEL: func @sparse_dim1d(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK: %[[C:.*]] = arith.constant 0 : index