[mlir][sparse] start a sparse codegen conversion pass
authorAart Bik <ajcbik@google.com>
Fri, 26 Aug 2022 20:49:07 +0000 (13:49 -0700)
committerAart Bik <ajcbik@google.com>
Mon, 29 Aug 2022 16:39:33 +0000 (09:39 -0700)
This new pass provides an alternative to the current conversion pass
that converts sparse tensor types and sparse primitives to opaque pointers
and calls into a runtime support library. This pass will map sparse tensor
types to actual data structures and primitives to actual code. In the long
run, this new pass will remove our dependence on the support library, avoid
the need to link in fully templated and expanded code, and provide much better
opportunities for optimization on the generated code.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D132766

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp [new file with mode: 0644]
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir [new file with mode: 0644]
mlir/test/Dialect/SparseTensor/conversion.mlir

index 71afddc..2d4bdb3 100644 (file)
@@ -11,8 +11,6 @@
 // In general, this file takes the approach of keeping "mechanism" (the
 // actual steps of applying a transformation) completely separate from
 // "policy" (heuristics for when and where to apply transformations).
-// The only exception is in `SparseToSparseConversionStrategy`; for which,
-// see further discussion there.
 //
 //===----------------------------------------------------------------------===//
 
 
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
 namespace bufferization {
 struct OneShotBufferizationOptions;
 } // namespace bufferization
 
-// Forward.
-class TypeConverter;
-
 //===----------------------------------------------------------------------===//
 // The Sparsification pass.
 //===----------------------------------------------------------------------===//
@@ -95,6 +91,12 @@ createSparsificationPass(const SparsificationOptions &options);
 // The SparseTensorConversion pass.
 //===----------------------------------------------------------------------===//
 
+/// Sparse tensor type converter into an opaque pointer.
+class SparseTensorTypeToPtrConverter : public TypeConverter {
+public:
+  SparseTensorTypeToPtrConverter();
+};
+
 /// Defines a strategy for implementing sparse-to-sparse conversion.
 /// `kAuto` leaves it up to the compiler to automatically determine
 /// the method used.  `kViaCOO` converts the source tensor to COO and
@@ -139,6 +141,22 @@ std::unique_ptr<Pass>
 createSparseTensorConversionPass(const SparseTensorConversionOptions &options);
 
 //===----------------------------------------------------------------------===//
+// The SparseTensorCodegen pass.
+//===----------------------------------------------------------------------===//
+
+/// Sparse tensor type converter into an actual buffer.
+class SparseTensorTypeToBufferConverter : public TypeConverter {
+public:
+  SparseTensorTypeToBufferConverter();
+};
+
+/// Sets up sparse tensor conversion rules.
+void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
+                                         RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createSparseTensorCodegenPass();
+
+//===----------------------------------------------------------------------===//
 // Other rewriting rules and passes.
 //===----------------------------------------------------------------------===//
 
index 6e36259..4ca224b 100644 (file)
@@ -77,16 +77,16 @@ def Sparsification : Pass<"sparsification", "ModuleOp"> {
 }
 
 def SparseTensorConversion : Pass<"sparse-tensor-conversion", "ModuleOp"> {
-  let summary = "Apply conversion rules to sparse tensor primitives and types";
+  let summary = "Convert sparse tensors and primitives to library calls";
   let description = [{
-    A pass that converts sparse tensor primitives to calls into a runtime
-    support library. All sparse tensor types are converted into opaque
-    pointers to the underlying sparse storage schemes.
+    A pass that converts sparse tensor primitives into calls into a runtime
+    support library. Sparse tensor types are converted into opaque pointers
+    to the underlying sparse storage schemes.
 
-    Note that this is a current implementation choice to keep the conversion
-    relatively simple. In principle, these primitives could also be
-    converted to actual elaborate IR code that implements the primitives
-    on the selected sparse tensor storage schemes.
+    The use of opaque pointers together with runtime support library keeps
+    the conversion relatively simple, but at the expense of IR opacity,
+    which obscures opportunities for subsequent optimization of the IR.
+    An alternative is provided by the SparseTensorCodegen pass.
 
     Example of the conversion:
 
@@ -122,4 +122,28 @@ def SparseTensorConversion : Pass<"sparse-tensor-conversion", "ModuleOp"> {
   ];
 }
 
+def SparseTensorCodegen : Pass<"sparse-tensor-codegen", "ModuleOp"> {
+  let summary = "Convert sparse tensors and primitives to actual code";
+  let description = [{
+    A pass that converts sparse tensor types and primitives to actual
+    compiler visible buffers and compiler IR that implements these
+    primitives on the selected sparse tensor storage schemes.
+
+    This pass provides an alternative to the SparseTensorConversion pass,
+    eliminating the dependence on a runtime support library, and providing
+    much more opportunities for subsequent compiler optimization of the
+    generated code.
+
+    Example of the conversion:
+
+    ```mlir
+    TBD
+    ```
+  }];
+  let constructor = "mlir::createSparseTensorCodegenPass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
index 9d99d2f..640ee67 100644 (file)
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   CodegenUtils.cpp
   DenseBufferizationPass.cpp
   Sparsification.cpp
+  SparseTensorCodegen.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
   SparseTensorRewriting.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
new file mode 100644 (file)
index 0000000..8666926
--- /dev/null
@@ -0,0 +1,82 @@
+//===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// A pass that converts sparse tensor types and primitives to actual compiler
+// visible buffers and actual compiler IR that implements these primitives on
+// the selected sparse tensor storage schemes. This pass provides an alternative
+// to the SparseTensorConversion pass, eliminating the dependence on a runtime
+// support library, and providing much more opportunities for subsequent
+// compiler optimization of the generated code.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helper methods.
+//===----------------------------------------------------------------------===//
+
+/// Maps each sparse tensor type to the appropriate buffer.
+static Optional<Type> convertSparseTensorTypes(Type type) {
+  if (getSparseTensorEncoding(type) != nullptr) {
+    // TODO: this is just a dummy rule to get the ball rolling....
+    RankedTensorType rTp = type.cast<RankedTensorType>();
+    return MemRefType::get({ShapedType::kDynamicSize}, rTp.getElementType());
+  }
+  return llvm::None;
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion rules.
+//===----------------------------------------------------------------------===//
+
+/// Sparse conversion rule for returns.
+class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
+    return success();
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Sparse tensor type conversion into an actual buffer.
+//===----------------------------------------------------------------------===//
+
+mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
+  addConversion([](Type type) { return type; });
+  addConversion(convertSparseTensorTypes);
+}
+
+//===----------------------------------------------------------------------===//
+// Public method for populating conversion rules.
+//===----------------------------------------------------------------------===//
+
+/// Populates the given patterns list with conversion rules required for
+/// the sparsification of linear algebra operations.
+void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
+                                               RewritePatternSet &patterns) {
+  patterns.add<SparseReturnConverter>(typeConverter, patterns.getContext());
+}
index 7161015..5bd10e8 100644 (file)
@@ -6,11 +6,13 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Convert sparse tensor primitives to calls into a runtime support library.
-// Note that this is a current implementation choice to keep the conversion
-// simple. In principle, these primitives could also be converted to actual
-// elaborate IR code that implements the primitives on the selected sparse
-// tensor storage schemes.
+// A pass that converts sparse tensor primitives into calls into a runtime
+// support library. Sparse tensor types are converted into opaque pointers
+// to the underlying sparse storage schemes. The use of opaque pointers
+// together with runtime support library keeps the conversion relatively
+// simple, but at the expense of IR opacity, which obscures opportunities
+// for subsequent optimization of the IR. An alternative is provided by
+// the SparseTensorCodegen pass.
 //
 //===----------------------------------------------------------------------===//
 
@@ -48,6 +50,13 @@ static Type getOpaquePointerType(OpBuilder &builder) {
   return LLVM::LLVMPointerType::get(builder.getI8Type());
 }
 
+/// Maps each sparse tensor type to an opaque pointer.
+static Optional<Type> convertSparseTensorTypes(Type type) {
+  if (getSparseTensorEncoding(type) != nullptr)
+    return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
+  return llvm::None;
+}
+
 /// Returns a function reference (first hit also inserts into module). Sets
 /// the "_emit_c_interface" on the function declaration when requested,
 /// so that LLVM lowering generates a wrapper function that takes care
@@ -1345,6 +1354,7 @@ public:
     return success();
   }
 };
+
 /// Sparse conversion rule for the output operator.
 class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
 public:
@@ -1388,6 +1398,15 @@ public:
 } // namespace
 
 //===----------------------------------------------------------------------===//
+// Sparse tensor type conversion into opaque pointer.
+//===----------------------------------------------------------------------===//
+
+mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
+  addConversion([](Type type) { return type; });
+  addConversion(convertSparseTensorTypes);
+}
+
+//===----------------------------------------------------------------------===//
 // Public method for populating conversion rules.
 //===----------------------------------------------------------------------===//
 
index 2014781..643cff9 100644 (file)
@@ -67,20 +67,6 @@ struct SparsificationPass : public SparsificationBase<SparsificationPass> {
   }
 };
 
-class SparseTensorTypeConverter : public TypeConverter {
-public:
-  SparseTensorTypeConverter() {
-    addConversion([](Type type) { return type; });
-    addConversion(convertSparseTensorTypes);
-  }
-  // Maps each sparse tensor type to an opaque pointer.
-  static Optional<Type> convertSparseTensorTypes(Type type) {
-    if (getSparseTensorEncoding(type) != nullptr)
-      return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
-    return llvm::None;
-  }
-};
-
 struct SparseTensorConversionPass
     : public SparseTensorConversionBase<SparseTensorConversionPass> {
 
@@ -93,7 +79,7 @@ struct SparseTensorConversionPass
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    SparseTensorTypeConverter converter;
+    SparseTensorTypeToPtrConverter converter;
     ConversionTarget target(*ctx);
     // Everything in the sparse dialect must go!
     target.addIllegalDialect<SparseTensorDialect>();
@@ -158,8 +144,49 @@ struct SparseTensorConversionPass
   }
 };
 
+struct SparseTensorCodegenPass
+    : public SparseTensorCodegenBase<SparseTensorCodegenPass> {
+
+  SparseTensorCodegenPass() = default;
+  SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    SparseTensorTypeToBufferConverter converter;
+    ConversionTarget target(*ctx);
+    // Everything in the sparse dialect must go!
+    target.addIllegalDialect<SparseTensorDialect>();
+    // All dynamic rules below accept new function, call, return, and various
+    // tensor and bufferization operations as legal output of the rewriting.
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return converter.isSignatureLegal(op.getFunctionType());
+    });
+    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
+      return converter.isSignatureLegal(op.getCalleeType());
+    });
+    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
+      return converter.isLegal(op.getOperandTypes());
+    });
+    // Populate with rules and apply rewriting rules.
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                   converter);
+    populateCallOpTypeConversionPattern(patterns, converter);
+    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                         target);
+    populateSparseTensorCodegenPatterns(converter, patterns);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// Strategy flag methods.
+//===----------------------------------------------------------------------===//
+
 SparseParallelizationStrategy
 mlir::sparseParallelizationStrategy(int32_t flag) {
   switch (flag) {
@@ -199,6 +226,10 @@ mlir::sparseToSparseConversionStrategy(int32_t flag) {
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Pass creation methods.
+//===----------------------------------------------------------------------===//
+
 std::unique_ptr<Pass> mlir::createSparsificationPass() {
   return std::make_unique<SparsificationPass>();
 }
@@ -216,3 +247,7 @@ std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
     const SparseTensorConversionOptions &options) {
   return std::make_unique<SparseTensorConversionPass>(options);
 }
+
+std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
+  return std::make_unique<SparseTensorCodegenPass>();
+}
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
new file mode 100644 (file)
index 0000000..a3cecaf
--- /dev/null
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s --sparse-tensor-codegen  --canonicalize --cse | FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"]
+}>
+
+// TODO: just a dummy memref rewriting to get the ball rolling....
+
+// CHECK-LABEL: func @sparse_nop(
+//  CHECK-SAME: %[[A:.*]]: memref<?xf64>) -> memref<?xf64> {
+//       CHECK: return %[[A]] : memref<?xf64>
+func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+  return %arg0 : tensor<?xf64, #SparseVector>
+}
index 336d815..4de4021 100644 (file)
   dimOrdering = affine_map<(i,j,k) -> (k,i,j)>
 }>
 
+// CHECK-LABEL: func @sparse_nop(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+//       CHECK: return %[[A]] : !llvm.ptr<i8>
+func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+  return %arg0 : tensor<?xf64, #SparseVector>
+}
+
 // CHECK-LABEL: func @sparse_dim1d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 //       CHECK: %[[C:.*]] = arith.constant 0 : index