[mlir][linalg][sparse] add linalg optimization passes "upstream"
author Aart Bik <ajcbik@google.com>
Wed, 16 Feb 2022 20:56:43 +0000 (12:56 -0800)
committer Aart Bik <ajcbik@google.com>
Thu, 17 Feb 2022 16:55:50 +0000 (08:55 -0800)
It is time to compose Linalg-related optimizations with SparseTensor-related
optimizations. As a careful first step, this change adds some general Linalg
optimizations "upstream" of the sparse compiler in the full sparse-compiler
pipeline. Some minor changes were needed to make those optimizations aware
of sparsity.
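
For illustration only (not part of this patch), a minimal C++ sketch of how a
client could now run the composed pipeline programmatically. The hypothetical
helper runSparseCompiler assumes the pipeline entry points declared in
mlir/Dialect/SparseTensor/Pipelines/Passes.h and default SparseCompilerOptions:

  #include "mlir/Dialect/SparseTensor/Pipelines/Passes.h"
  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/Pass/PassManager.h"

  // Hypothetical helper: build and run the full sparse-compiler pipeline.
  // Linalg generalization and elementwise fusion are now scheduled inside
  // buildSparseCompiler itself, so no extra Linalg passes are added here.
  static mlir::LogicalResult runSparseCompiler(mlir::ModuleOp module) {
    mlir::PassManager pm(module.getContext());
    mlir::sparse_tensor::SparseCompilerOptions options;
    mlir::sparse_tensor::buildSparseCompiler(pm, options);
    return pm.run(module);
  }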

A follow-up change will add a sparse-specific fusion rule, just to
demonstrate the power of the new composition.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D119971

15 files changed:
mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/dense_output.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_quantized_matmul.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
index 2ff2cc8..e310f58 100644 (file)
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRLinalg
   MLIRIR
   MLIRParser
   MLIRSideEffectInterfaces
+  MLIRSparseTensor
   MLIRSCF
   MLIRMath
   MLIRMemRef
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 87278dc..c148ab9 100644 (file)
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
@@ -819,9 +820,18 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
       Type resultType = genericOp->getResult(yieldVal.index()).getType();
       // The input can have a different type than the result, e.g. a dynamic
       // input dimension can be turned into a static output dimension.
-      if (returnedArg.getType() != resultType)
-        returnedArg = rewriter.create<tensor::CastOp>(genericOp.getLoc(),
-                                                      resultType, returnedArg);
+      Type returnType = returnedArg.getType();
+      if (returnType != resultType) {
+        // Distinguish between sparse conversion or dense tensor casting.
+        // TODO: unify the two ops?
+        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
+            sparse_tensor::getSparseTensorEncoding(resultType))
+          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
+              genericOp.getLoc(), resultType, returnedArg);
+        else
+          returnedArg = rewriter.create<tensor::CastOp>(
+              genericOp.getLoc(), resultType, returnedArg);
+      }
       returnedArgs.push_back(returnedArg);
     }
 
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 6897cb9..57bef39 100644 (file)
@@ -50,6 +50,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRSCFTransforms
   MLIRSCFUtils
   MLIRPass
+  MLIRSparseTensor
   MLIRStandard
   MLIRStandardOpsTransforms
   MLIRStandardToLLVM
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 570e844..7e0e857 100644 (file)
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Matchers.h"
@@ -2184,6 +2185,10 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
         if (!operandType)
           continue;
 
+        // If outs is sparse, leave it to the sparse compiler.
+        if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
+          continue;
+
         // If outs is already an `init_tensor` operation, nothing to do.
         auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
         if (definingOp)
@@ -2213,7 +2218,7 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
 } // namespace
 
 //===---------------------------------------------------------------------===//
-// Methods that add patterns descrined in this file to a pattern list.
+// Methods that add patterns described in this file to a pattern list.
 //===---------------------------------------------------------------------===//
 
 void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index 25487e4..ff6577a 100644 (file)
@@ -29,6 +29,8 @@ using namespace mlir::sparse_tensor;
 void mlir::sparse_tensor::buildSparseCompiler(
     OpPassManager &pm, const SparseCompilerOptions &options) {
   // TODO(wrengr): ensure the original `pm` is for ModuleOp
+  pm.addNestedPass<FuncOp>(createLinalgGeneralizationPass());
+  pm.addPass(createLinalgElementwiseOpFusionPass());
   pm.addPass(createSparsificationPass(options.sparsificationOptions()));
   pm.addPass(createSparseTensorConversionPass());
   pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
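
(Not part of the patch.) The composed pipeline is also reachable through its
registered textual form, which is what the simplified mlir-opt RUN lines and
the Python drivers below rely on. A hedged C++ sketch with a hypothetical
buildFromText helper, assuming mlir::sparse_tensor::registerSparseTensorPipelines()
has already been called during tool setup alongside the other registrations:

  #include "mlir/Pass/PassManager.h"
  #include "mlir/Pass/PassRegistry.h"

  // Parse the registered "sparse-compiler" pipeline by name into `pm`.
  // The option string matches the one used by the Python test drivers.
  static mlir::LogicalResult buildFromText(mlir::PassManager &pm) {
    return mlir::parsePassPipeline(
        "sparse-compiler{reassociate-fp-reductions=1 "
        "enable-index-optimizations=1}",
        pm);
  }
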
mlir/test/Integration/Dialect/SparseTensor/CPU/dense_output.mlir
index 2d8898e..a263972 100644 (file)
@@ -1,6 +1,5 @@
 // RUN: mlir-opt %s --sparse-compiler | \
 // RUN: TENSOR0="%mlir_integration_test_dir/data/test.mtx" \
-// RUN: TENSOR1="%mlir_integration_test_dir/data/zero.mtx" \
 // RUN: mlir-cpu-runner \
 // RUN:  -e entry -entry-point-result=void  \
 // RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // library.
 module {
   //
-  // A kernel that assigns elements from A to an initially zero X.
+  // A kernel that assigns elements from A to X.
   //
-  func @dense_output(%arga: tensor<?x?xf64, #SparseMatrix>,
-                     %argx: tensor<?x?xf64, #DenseMatrix>
-                    {linalg.inplaceable = true})
-       -> tensor<?x?xf64, #DenseMatrix> {
+  func @dense_output(%arga: tensor<?x?xf64, #SparseMatrix>) -> tensor<?x?xf64, #DenseMatrix> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #SparseMatrix>
+    %d1 = tensor.dim %arga, %c1 : tensor<?x?xf64, #SparseMatrix>
+    %init = sparse_tensor.init [%d0, %d1] : tensor<?x?xf64, #DenseMatrix>
     %0 = linalg.generic #trait_assign
        ins(%arga: tensor<?x?xf64, #SparseMatrix>)
-      outs(%argx: tensor<?x?xf64, #DenseMatrix>) {
+      outs(%init: tensor<?x?xf64, #DenseMatrix>) {
       ^bb(%a: f64, %x: f64):
         linalg.yield %a : f64
     } -> tensor<?x?xf64, #DenseMatrix>
@@ -70,15 +71,9 @@ module {
     %a = sparse_tensor.new %fileName
       : !Filename to tensor<?x?xf64, #SparseMatrix>
 
-    // Initialize all-dense annotated "sparse" matrix to all zeros.
-    %fileZero = call @getTensorFilename(%c1) : (index) -> (!Filename)
-    %x = sparse_tensor.new %fileZero
-      : !Filename to tensor<?x?xf64, #DenseMatrix>
-
     // Call the kernel.
-    %0 = call @dense_output(%a, %x)
-      : (tensor<?x?xf64, #SparseMatrix>,
-         tensor<?x?xf64, #DenseMatrix>) -> tensor<?x?xf64, #DenseMatrix>
+    %0 = call @dense_output(%a)
+      : (tensor<?x?xf64, #SparseMatrix>) -> tensor<?x?xf64, #DenseMatrix>
 
     //
     // Print the linearized 5x5 result for verification.
@@ -92,7 +87,7 @@ module {
 
     // Release the resources.
     sparse_tensor.release %a : tensor<?x?xf64, #SparseMatrix>
-    sparse_tensor.release %x : tensor<?x?xf64, #DenseMatrix>
+    sparse_tensor.release %0 : tensor<?x?xf64, #DenseMatrix>
 
     return
   }
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir
index a758a89..02d5cc0 100644 (file)
@@ -1,18 +1,12 @@
-// RUN: mlir-opt %s \
-// RUN:   --linalg-generalize-named-ops --linalg-fuse-elementwise-ops \
-// RUN:   --sparse-compiler | \
-// RUN: mlir-cpu-runner \
-// RUN:  -e entry -entry-point-result=void  \
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 //
 // Do the same run, but now with SIMDization as well. This should not change the outcome.
 //
-// RUN: mlir-opt %s \
-// RUN:   --linalg-generalize-named-ops --linalg-fuse-elementwise-ops \
-// RUN:   --sparse-compiler="vectorization-strategy=2 vl=2" | \
-// RUN: mlir-cpu-runner \
-// RUN:  -e entry -entry-point-result=void  \
+// RUN: mlir-opt %s --sparse-compiler="vectorization-strategy=2 vl=2" | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
index 6c207d7..f2a35ef 100644 (file)
@@ -1,11 +1,7 @@
-// RUN: mlir-opt %s \
-// RUN:   --linalg-generalize-named-ops --linalg-fuse-elementwise-ops \
-// RUN:   --sparse-compiler | \
-// RUN: mlir-cpu-runner \
-// RUN:  -e entry -entry-point-result=void  \
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
-//
 
 #CSR = #sparse_tensor.encoding<{
   dimLevelType = [ "dense", "compressed" ],
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_quantized_matmul.mlir
index c4cb95a..1db865b 100644 (file)
@@ -1,18 +1,12 @@
-// RUN: mlir-opt %s \
-// RUN:   --linalg-generalize-named-ops --linalg-fuse-elementwise-ops \
-// RUN:   --sparse-compiler | \
-// RUN: mlir-cpu-runner \
-// RUN:  -e entry -entry-point-result=void  \
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 //
 // Do the same run, but now with SIMDization as well. This should not change the outcome.
 //
-// RUN: mlir-opt %s \
-// RUN:   --linalg-generalize-named-ops --linalg-fuse-elementwise-ops \
-// RUN:   --sparse-compiler="vectorization-strategy=2 vl=2" | \
-// RUN: mlir-cpu-runner \
-// RUN:  -e entry -entry-point-result=void  \
+// RUN: mlir-opt %s --sparse-compiler="vectorization-strategy=2 vl=2" | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions.mlir
index 8f409fe..b0fde08 100644 (file)
@@ -1,16 +1,12 @@
 // RUN: mlir-opt %s --sparse-compiler | \
-// RUN: mlir-cpu-runner \
-// RUN:  -e entry -entry-point-result=void  \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 //
 // Do the same run, but now with SIMDization as well. This should not change the outcome.
 //
-// RUN: mlir-opt %s \
-// RUN:   --linalg-generalize-named-ops --linalg-fuse-elementwise-ops \
-// RUN:   --sparse-compiler="vectorization-strategy=2 vl=8" | \
-// RUN: mlir-cpu-runner \
-// RUN:  -e entry -entry-point-result=void  \
+// RUN: mlir-opt %s --sparse-compiler="vectorization-strategy=2 vl=8" | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
index 8709275..0179503 100755 (executable)
@@ -1,18 +1,12 @@
-// RUN: mlir-opt %s \
-// RUN:   --linalg-generalize-named-ops --linalg-fuse-elementwise-ops \
-// RUN:   --sparse-compiler | \
-// RUN: mlir-cpu-runner \
-// RUN:  -e entry -entry-point-result=void  \
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 //
 // Do the same run, but now with SIMDization as well. This should not change the outcome.
 //
-// RUN: mlir-opt %s \
-// RUN:   --linalg-generalize-named-ops --linalg-fuse-elementwise-ops \
-// RUN:   --sparse-compiler="vectorization-strategy=2 vl=8" | \
-// RUN: mlir-cpu-runner \
-// RUN:  -e entry -entry-point-result=void  \
+// RUN: mlir-opt %s -sparse-compiler="vectorization-strategy=2 vl=8" | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py
index a79c4b4..1b66628 100644 (file)
@@ -113,7 +113,6 @@ class SparseCompiler:
 
   def __init__(self, options: str):
     pipeline = (
-        f'builtin.func(linalg-generalize-named-ops,linalg-fuse-elementwise-ops),'
         f'sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}}')
     self.pipeline = pipeline
 
mlir/test/Integration/Dialect/SparseTensor/python/test_output.py
index f03756b..c29f618 100644 (file)
@@ -73,7 +73,6 @@ class SparseCompiler:
 
   def __init__(self):
     pipeline = (
-        f'builtin.func(linalg-generalize-named-ops,linalg-fuse-elementwise-ops),'
         f'sparse-compiler{{reassociate-fp-reductions=1 enable-index-optimizations=1}}')
     self.pipeline = pipeline
 
mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py
index f18655e..ccf1ffd 100644 (file)
@@ -171,7 +171,6 @@ class SparseCompiler:
   def __init__(self, sparsification_options: str, support_lib: str):
     self._support_lib = support_lib
     self._pipeline = (
-        f'builtin.func(linalg-generalize-named-ops,linalg-fuse-elementwise-ops),'
         f'sparse-compiler{{{sparsification_options} reassociate-fp-reductions=1 enable-index-optimizations=1}}')
     # Must be in the scope of a `with ir.Context():`
     self._passmanager = PassManager.parse(self._pipeline)
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 511c221..32d78f2 100644 (file)
@@ -6997,6 +6997,7 @@ cc_library(
         ":Parser",
         ":SCFDialect",
         ":SideEffectInterfaces",
+        ":SparseTensor",
         ":StandardOps",
         ":Support",
         ":TensorDialect",
@@ -7083,6 +7084,7 @@ cc_library(
         ":SCFDialect",
         ":SCFTransforms",
         ":SCFUtils",
+        ":SparseTensor",
         ":StandardOps",
         ":StandardOpsTransforms",
         ":Support",