[mlir][linalg] Start a named ops to generic ops pass
authorLei Zhang <antiagainst@google.com>
Thu, 19 Nov 2020 13:56:06 +0000 (08:56 -0500)
committerLei Zhang <antiagainst@google.com>
Thu, 19 Nov 2020 14:21:06 +0000 (09:21 -0500)
This commit starts a new pass and patterns for converting Linalg
named ops to generic ops. This enables us to leverage the flexbility
from generic ops during transformations. Right now only linalg.conv
is supported; others will be added when useful.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D91357

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp [new file with mode: 0644]
mlir/test/Dialect/Linalg/generalize-named-ops.mlir [new file with mode: 0644]
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp

index 2c200fe..6ac1e56 100644 (file)
@@ -159,6 +159,10 @@ def CopyOp : LinalgStructured_Op<"copy", [
 
     Value getSource() { return input();}
     Value getTarget() { return output(); }
+
+    static std::function<void(Block &)> getRegionBuilder() {
+      return nullptr;
+    }
   }];
   let verifier = [{ return ::verify(*this); }];
 
@@ -188,6 +192,10 @@ def FillOp : LinalgStructured_Op<"fill", [
       return Builder(getContext()).getAffineMapArrayAttr({
           extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
     }
+
+    static std::function<void(Block &)> getRegionBuilder() {
+      return nullptr;
+    }
   }];
 
   let verifier = [{ return ::verify(*this); }];
@@ -261,6 +269,10 @@ class PoolingBase_Op<string mnemonic, list<OpTrait> props>
       if (!padding().hasValue()) return 0;
       return padding().getValue().getValue<int64_t>({i, 1});
     }
+
+    static std::function<void(Block &)> getRegionBuilder() {
+      return nullptr;
+    }
   }];
 }
 
@@ -516,6 +528,10 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
       return ss.hasValue() ?
         llvm::Optional<unsigned>(ss.getValue()) : llvm::None;
     }
+
+    static std::function<void(Block &)> getRegionBuilder() {
+      return nullptr;
+    }
   }];
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parseGenericOp(parser, result); }];
index ec71674..0373bf3 100644 (file)
@@ -803,6 +803,17 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
             &res->getRegion(ridx), map);
         return res;
       }]
+    >,
+    StaticInterfaceMethod<
+      /*desc=*/[{
+        Returns the region builder for constructing the body for linalg.generic.
+        Returns a null function if this named op does not define a region
+        builder.
+      }],
+      /*retTy=*/"std::function<void(Block &)>",
+      /*methodName=*/"getRegionBuilder",
+      (ins),
+      [{ return ConcreteOp::getRegionBuilder(); }]
     >
   ];
 
index 620abd9..d041df8 100644 (file)
@@ -55,6 +55,10 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
 void populateElementwiseToLinalgConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
 
+/// Create a pass to conver named Linalg operations to Linalg generic
+/// operations.
+std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
+
 /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
 /// producer (consumer) generic operation by expanding the dimensionality of the
 /// loop in the generic op.
index 07bf93a..aabfd44 100644 (file)
@@ -112,4 +112,10 @@ def LinalgTilingToParallelLoops
   let dependentDialects = ["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"];
 }
 
+def LinalgGeneralization : FunctionPass<"linalg-generalize-named-ops"> {
+  let summary = "Convert named ops into generic ops";
+  let constructor = "mlir::createLinalgGeneralizationPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES
index 523a34e..8d531a1 100644 (file)
@@ -624,6 +624,20 @@ private:
   LinalgLoweringType loweringType;
 };
 
+/// Linalg generalization patterns
+
+/// Populates `patterns` with patterns to convert spec-generated named ops to
+/// linalg.generic ops.
+void populateLinalgNamedOpsGeneralizationPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns,
+    LinalgMarker marker = LinalgMarker());
+
+/// Populates `patterns` with patterns to convert linalg.conv ops to
+/// linalg.generic ops.
+void populateLinalgConvGeneralizationPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns,
+    LinalgMarker marker = LinalgMarker());
+
 //===----------------------------------------------------------------------===//
 // Op-specific patterns.
 //===----------------------------------------------------------------------===//
index 11a4889..6de4ce6 100644 (file)
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   ElementwiseToLinalg.cpp
   Fusion.cpp
   FusionOnTensors.cpp
+  Generalization.cpp
   Hoisting.cpp
   Interchange.cpp
   Loops.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
new file mode 100644 (file)
index 0000000..3496a77
--- /dev/null
@@ -0,0 +1,180 @@
+//===- Generalization.cpp - linalg named ops to generic ops  --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the Linalg generalization pass. It converts named
+// Linalg ops to linalg.generic ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-generalization"
+
+using namespace mlir;
+
+// Creates a linalg.generic op from the given `namedOp`. Returns a null op if
+// the given `namedOp` does not have a region builder.
+static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp,
+                                                    OpBuilder &builder) {
+  auto regionBuilder = namedOp.getRegionBuilder();
+  if (!regionBuilder) {
+    LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
+    return nullptr;
+  }
+
+  SmallVector<AffineMap, 4> indexingMaps = namedOp.getIndexingMaps();
+  auto iterators = llvm::to_vector<4>(
+      namedOp.iterator_types().getAsValueRange<StringAttr>());
+  auto resultTypes = namedOp.getOutputTensorTypes();
+  SmallVector<Type, 4> types(resultTypes.begin(), resultTypes.end());
+
+  return builder.create<linalg::GenericOp>(
+      namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputBuffers(),
+      namedOp.getInitTensors(), indexingMaps, iterators,
+      [&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
+        edsc::ScopedContext scope(bodyBuilder, loc);
+        regionBuilder(*bodyBuilder.getBlock());
+      });
+}
+
+namespace {
+
+/// Base class for all linalg generalization patterns. A subclass must provide
+/// the following method:
+///   linalg::GenericOp createGenericOp(RootOp, PatternRewriter &)
+/// for creating the generic op.
+// TODO: remove this pattern after migrating all manually-written named ops
+// into auto-generated ones.
+template <typename ConcretePattern, typename RootOp>
+struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
+  LinalgGeneralizationPattern(MLIRContext *context, linalg::LinalgMarker marker,
+                              PatternBenefit benefit = 1)
+      : OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {}
+
+  LogicalResult matchAndRewrite(RootOp rootOp,
+                                PatternRewriter &rewriter) const override {
+    auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp.getOperation());
+    if (!linalgOp)
+      return failure();
+    if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+      return failure();
+
+    auto *pattern = static_cast<const ConcretePattern *>(this);
+    linalg::GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter);
+    if (!genericOp)
+      return failure();
+
+    rewriter.replaceOp(rootOp, genericOp.getResults());
+    marker.replaceLinalgMarker(rewriter, genericOp.getOperation());
+    return success();
+  }
+
+private:
+  linalg::LinalgMarker marker;
+};
+
+struct GeneralizeConvOp
+    : public LinalgGeneralizationPattern<GeneralizeConvOp, linalg::ConvOp> {
+  using LinalgGeneralizationPattern::LinalgGeneralizationPattern;
+
+  linalg::GenericOp createGenericOp(linalg::ConvOp, OpBuilder &rewriter) const;
+};
+
+/// Catch-all pattern for converting all named ops with a region builder into
+/// linalg.generic.
+struct LinalgNamedOpGeneralizationPattern : RewritePattern {
+  LinalgNamedOpGeneralizationPattern(MLIRContext *context,
+                                     linalg::LinalgMarker marker,
+                                     PatternBenefit benefit = 1)
+      : RewritePattern(benefit, MatchAnyOpTypeTag()),
+        marker(std::move(marker)) {}
+
+  LogicalResult matchAndRewrite(Operation *rootOp,
+                                PatternRewriter &rewriter) const override {
+    auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp);
+    if (!linalgOp)
+      return failure();
+    if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+      return failure();
+
+    // No nothing to do for linalg.generic and linalg.indexed_generic.
+    if (isa<linalg::GenericOp, linalg::IndexedGenericOp>(rootOp))
+      return failure();
+
+    linalg::GenericOp genericOp =
+        createGenericOpFromNamedOp(linalgOp, rewriter);
+    if (!genericOp)
+      return failure();
+
+    rewriter.replaceOp(rootOp, genericOp.getResults());
+    marker.replaceLinalgMarker(rewriter, genericOp.getOperation());
+    return success();
+  }
+
+private:
+  linalg::LinalgMarker marker;
+};
+
+struct LinalgGeneralizationPass
+    : public LinalgGeneralizationBase<LinalgGeneralizationPass> {
+  void runOnFunction() override;
+};
+
+} // namespace
+
+void LinalgGeneralizationPass::runOnFunction() {
+  FuncOp func = getFunction();
+  OwningRewritePatternList patterns;
+  linalg::populateLinalgConvGeneralizationPatterns(&getContext(), patterns);
+  linalg::populateLinalgNamedOpsGeneralizationPatterns(&getContext(), patterns);
+  applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
+}
+
+linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp,
+                                                    OpBuilder &builder) const {
+  SmallVector<AffineMap, 4> indexingMaps = convOp.getIndexingMaps();
+  auto iterators =
+      llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>());
+  return builder.create<linalg::GenericOp>(
+      convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(),
+      convOp.getInputBuffers(), convOp.getOutputBuffers(),
+      /*initTensors=*/ValueRange(), indexingMaps, iterators,
+      [](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) {
+        Value mul =
+            bodyBuilder.create<MulFOp>(bodyLoc, bodyArgs[0], bodyArgs[1]);
+        Value add = bodyBuilder.create<AddFOp>(bodyLoc, mul, bodyArgs[2]);
+        bodyBuilder.create<linalg::YieldOp>(bodyLoc, add);
+      });
+}
+
+void mlir::linalg::populateLinalgConvGeneralizationPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns,
+    linalg::LinalgMarker marker) {
+  patterns.insert<GeneralizeConvOp>(context, marker);
+}
+
+void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns,
+    linalg::LinalgMarker marker) {
+  patterns.insert<LinalgNamedOpGeneralizationPattern>(context, marker);
+}
+
+std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
+  return std::make_unique<LinalgGeneralizationPass>();
+}
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
new file mode 100644 (file)
index 0000000..9660243
--- /dev/null
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s
+
+func @generalize_conv(%input : memref<1x225x225x3xf32>, %filter: memref<3x3x3x32xf32>, %output: memref<1x112x112x32xf32>) {
+  linalg.conv(%filter, %input, %output) {dilations = [2, 3], strides = [4, 5]} : memref<3x3x3x32xf32>, memref<1x225x225x3xf32>, memref<1x112x112x32xf32>
+  return
+}
+
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+// CHECK:  #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 * 4 + d3 * 2, d2 * 5 + d4 * 3, d5)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)>
+
+// CHECK: func @generalize_conv
+// CHECK-SAME:  %[[INPUT:.+]]: memref<1x225x225x3xf32>
+// CHECK-SAME: %[[FILTER:.+]]: memref<3x3x3x32xf32>
+// CHECK-SAME: %[[OUTPUT:.+]]: memref<1x112x112x32xf32>
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[FILTER_MAP]], #[[INPUT_MAP]], #[[OUTPUT_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "window", "window", "reduction", "parallel"]
+// CHECK-SAME:  ins(%[[FILTER]], %[[INPUT]]
+// CHECK-SAME: outs(%[[OUTPUT]]
+
+// CHECK: ^{{.*}}(%[[FILTER_ARG:.+]]: f32, %[[INPUT_ARG:.+]]: f32, %[[OUTPUT_ARG:.+]]: f32)
+// CHECK:   %[[MUL:.+]] = mulf %[[FILTER_ARG]], %[[INPUT_ARG]]
+// CHECK:   %[[ADD:.+]] = addf %[[MUL]], %[[OUTPUT_ARG]]
+// CHECK:   linalg.yield %[[ADD]]
+
+// -----
+
+func @generalize_matmul_buffer(%A : memref<16x8xf32>, %B: memref<8x32xf32>, %C: memref<16x32xf32>) {
+  linalg.matmul ins(%A, %B: memref<16x8xf32>, memref<8x32xf32>) outs(%C: memref<16x32xf32>)
+  return
+}
+
+
+// CHECK: #[[A_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[B_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[C_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: func @generalize_matmul_buffer
+// CHECK-SAME: %[[A:.+]]: memref<16x8xf32>
+// CHECK-SAME: %[[B:.+]]: memref<8x32xf32>
+// CHECK-SAME: %[[C:.+]]: memref<16x32xf32>
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[A_MAP]], #[[B_MAP]], #[[C_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK:   %[[MUL:.+]] = mulf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK:   linalg.yield %[[ADD]] : f32
+
+// -----
+
+func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>) init(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK: func @generalize_matmul_tensor
+
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<16x8xf32>, tensor<8x32xf32>)
+// CHECK-SAME: init(%{{.+}} : tensor<16x32xf32>)
+
+// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
index e7e5ef8..45dc115 100644 (file)
@@ -1522,6 +1522,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
         ArrayAttr iterator_types();
         ArrayAttr indexing_maps();
         static void regionBuilder(Block &block);
+        static std::function<void(Block &)> getRegionBuilder() {{ return regionBuilder; }
 
         // Generic methods.
         static unsigned getNumRegionArgs() {{ return {4}; }