[mlir][sparse] Add SparseTensorStorageExpansion Pass to expand compounded sparse...
authorPeiming Liu <peiming@google.com>
Thu, 1 Sep 2022 17:06:31 +0000 (17:06 +0000)
committerPeiming Liu <peiming@google.com>
Thu, 1 Sep 2022 22:47:31 +0000 (22:47 +0000)
This patch adds the SparseTensorStorageExpansion pass, which flattens the tuple used to store a sparse
tensor handle.

Right now, it only sets up the skeleton for the pass; more lowering rules for sparse tensor storage
operations still need to be added.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133125

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp [new file with mode: 0644]
mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir [new file with mode: 0644]

index 83eb1bb..523f1fb 100644 (file)
@@ -162,6 +162,22 @@ void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
 std::unique_ptr<Pass> createSparseTensorCodegenPass();
 
 //===----------------------------------------------------------------------===//
+// The SparseTensorStorageExpansion pass.
+//===----------------------------------------------------------------------===//
+
+/// Sparse tensor storage type converter from compound to expanded form.
+class SparseTensorStorageTupleExpander : public TypeConverter {
+public:
+  SparseTensorStorageTupleExpander();
+};
+
+/// Sets up sparse tensor storage expansion rules.
+void populateSparseTensorStorageExpansionPatterns(TypeConverter &typeConverter,
+                                                  RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createSparseTensorStorageExpansionPass();
+
+//===----------------------------------------------------------------------===//
 // Other rewriting rules and passes.
 //===----------------------------------------------------------------------===//
 
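For context, a minimal sketch of how the entry points declared above could be driven from a pass pipeline; this is illustrative only and not part of the patch (the helper function `runStorageExpansion` is made up for the example):

```cpp
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Illustrative only: run the new storage expansion pass on a module,
// typically after sparse tensor codegen has introduced the storage tuples.
static mlir::LogicalResult runStorageExpansion(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // The pass operates on ModuleOp, matching the Passes.td definition.
  pm.addPass(mlir::createSparseTensorStorageExpansionPass());
  return pm.run(module);
}
```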
index c8e7123..d765a10 100644 (file)
@@ -146,4 +146,39 @@ def SparseTensorCodegen : Pass<"sparse-tensor-codegen", "ModuleOp"> {
   ];
 }
 
+def SparseTensorStorageExpansion : Pass<"sparse-tensor-storage-expansion", "ModuleOp"> {
+  let summary = "Expand compound sparse tensor storage into individual SSA values";
+  let description = [{
+    A pass that expands sparse tensor storage (aggregated in a tuple) into
+    individual SSA values. It also lowers sparse tensor storage operations,
+    e.g., sparse_tensor.storage_get and sparse_tensor.storage_set.
+
+    Example of the conversion:
+
+    ```mlir
+    Before:
+      func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>,
+                                                 memref<?xf64>,
+                                                 f64>)
+                                        -> tuple<memref<?xf64>,
+                                                 memref<?xf64>,
+                                                 f64> {
+        return %arg0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+      }
+    After:
+      func.func @sparse_storage_set(%arg0: memref<?xf64>,
+                                    %arg1: memref<?xf64>,
+                                    %arg2: f64)
+                                    -> (memref<?xf64>, memref<?xf64>, f64) {
+        return %arg0, %arg1, %arg2 : memref<?xf64>, memref<?xf64>, f64
+      }
+    ```
+  }];
+  let constructor = "mlir::createSparseTensorStorageExpansionPass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
index 640ee67..39b633a 100644 (file)
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
   SparseTensorRewriting.cpp
+  SparseTensorStorageExpansion.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
index b30d0d2..d5e2b96 100644 (file)
@@ -24,6 +24,7 @@ namespace mlir {
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
+#define GEN_PASS_DEF_SPARSETENSORSTORAGEEXPANSION
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
 } // namespace mlir
 
@@ -185,6 +186,44 @@ struct SparseTensorCodegenPass
   }
 };
 
+struct SparseTensorStorageExpansionPass
+    : public impl::SparseTensorStorageExpansionBase<
+          SparseTensorStorageExpansionPass> {
+
+  SparseTensorStorageExpansionPass() = default;
+  SparseTensorStorageExpansionPass(
+      const SparseTensorStorageExpansionPass &pass) = default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    SparseTensorStorageTupleExpander converter;
+    ConversionTarget target(*ctx);
+    // Now, everything in the sparse dialect must go!
+    target.addIllegalDialect<SparseTensorDialect>();
+    // All dynamic rules below accept new function, call, and return operations.
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return converter.isSignatureLegal(op.getFunctionType());
+    });
+    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
+      return converter.isSignatureLegal(op.getCalleeType());
+    });
+    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
+      return converter.isLegal(op.getOperandTypes());
+    });
+    // Populate the conversion patterns and apply the rewriting.
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                   converter);
+    populateCallOpTypeConversionPattern(patterns, converter);
+    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                         target);
+    populateSparseTensorStorageExpansionPatterns(converter, patterns);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -255,3 +294,7 @@ std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
 std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
   return std::make_unique<SparseTensorCodegenPass>();
 }
+
+std::unique_ptr<Pass> mlir::createSparseTensorStorageExpansionPass() {
+  return std::make_unique<SparseTensorStorageExpansionPass>();
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
new file mode 100644 (file)
index 0000000..c1305eb
--- /dev/null
@@ -0,0 +1,96 @@
+//===- SparseTensorStorageExpansion.cpp - Sparse tensor storage expansion ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The sparse tensor storage expansion pass expands the compound storage
+// (a tuple) used for sparse tensors into flattened SSA values.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helper methods.
+//===----------------------------------------------------------------------===//
+
+/// Expands a sparse tensor storage tuple into its element types.
+static Optional<LogicalResult>
+convertSparseTensorStorageTuple(Type t, SmallVectorImpl<Type> &result) {
+  if (auto tuple = t.dyn_cast<TupleType>()) {
+    // Note that this does not handle nested tuples, which is fine for the
+    // sparse compiler since nested tuples will not be generated.
+    result.append(tuple.getTypes().begin(), tuple.getTypes().end());
+    return success();
+  }
+  return llvm::None;
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion rules.
+//===----------------------------------------------------------------------===//
+
+/// Sparse tensor storage conversion rule for returns.
+class SparseStorageReturnConverter
+    : public OpConversionPattern<func::ReturnOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value, 8> flattened;
+    for (auto operand : adaptor.getOperands()) {
+      if (auto cast =
+              dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp());
+          cast && cast->getResultTypes()[0].isa<TupleType>())
+        // An unrealized_conversion_cast is inserted by the type converter to
+        // bridge the 1:N conversion between the tuple and its element types.
+        // In this case, take the operands of the cast and replace the tuple
+        // output with the flattened value array.
+        flattened.append(cast.getOperands().begin(), cast.getOperands().end());
+      else
+        flattened.push_back(operand);
+    }
+    // Create a return with the flattened values extracted from the tuple.
+    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
+    return success();
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Sparse tensor storage expansion
+//===----------------------------------------------------------------------===//
+
+mlir::SparseTensorStorageTupleExpander::SparseTensorStorageTupleExpander() {
+  addConversion([](Type type) { return type; });
+  addConversion(convertSparseTensorStorageTuple);
+}
+
+//===----------------------------------------------------------------------===//
+// Public method for populating conversion rules.
+//===----------------------------------------------------------------------===//
+
+/// Populates the given patterns list with conversion rules required
+/// to expand compound sparse tensor storage tuples.
+void mlir::populateSparseTensorStorageExpansionPatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns) {
+  patterns.add<SparseStorageReturnConverter>(typeConverter,
+                                             patterns.getContext());
+}
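The commit message notes that more lowering rules for the storage operations still need to be added. Purely as a sketch of a possible follow-up (not part of this patch), a rule for sparse_tensor.storage_get could mirror the return converter above; the op class name `StorageGetOp` and its accessors `getStorage()`/`getIdx()` are assumed placeholders, not confirmed API, and the sketch assumes the same includes and namespaces as the file above:

```cpp
// Hypothetical follow-up rule (not in this patch). It assumes a StorageGetOp
// class with getStorage()/getIdx() accessors; the real op may differ. The
// idea: once the tuple has been expanded, storage_get simply forwards the
// flattened SSA value at the requested field index.
class SparseStorageGetConverter : public OpConversionPattern<StorageGetOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(StorageGetOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The storage operand has been materialized by the type converter as an
    // unrealized_conversion_cast over the flattened values.
    auto castOp = cast<UnrealizedConversionCastOp>(
        adaptor.getStorage().getDefiningOp());
    uint64_t idx = op.getIdx().getZExtValue();
    // Replace the storage_get with the corresponding flattened value.
    rewriter.replaceOp(op, castOp.getOperands()[idx]);
    return success();
  }
};
```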
diff --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
new file mode 100644 (file)
index 0000000..445b234
--- /dev/null
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -sparse-tensor-storage-expansion | FileCheck %s
+
+// CHECK-LABEL:  func @sparse_storage_expand(
+// CHECK-SAME:     %[[TMP_arg0:.*0]]: memref<?xf64>,
+// CHECK-SAME:     %[[TMP_arg1:.*1]]: memref<?xf64>,
+// CHECK-SAME:     %[[TMP_arg2:.*]]: f64
+// CHECK:          return %[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]]
+func.func @sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>)
+                                     -> tuple<memref<?xf64>, memref<?xf64>, f64> {
+  return %arg0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}