--- /dev/null
+//===- Generalization.cpp - linalg named ops to generic ops --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the Linalg generalization pass. It converts named
+// Linalg ops to linalg.generic ops.
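+//
+// For example, linalg.matmul on buffers becomes a linalg.generic op that
+// carries matmul's indexing maps and iterator types and computes the scalar
+// multiply-accumulate in its region. Roughly (an illustrative sketch; see the
+// accompanying test for the exact output):
+//
+//   #mapA = affine_map<(d0, d1, d2) -> (d0, d2)>
+//   #mapB = affine_map<(d0, d1, d2) -> (d2, d1)>
+//   #mapC = affine_map<(d0, d1, d2) -> (d0, d1)>
+//   linalg.generic {indexing_maps = [#mapA, #mapB, #mapC],
+//                   iterator_types = ["parallel", "parallel", "reduction"]}
+//       ins(%A, %B : memref<16x8xf32>, memref<8x32xf32>)
+//       outs(%C : memref<16x32xf32>) {
+//     ^bb0(%a: f32, %b: f32, %c: f32):
+//       %0 = mulf %a, %b : f32
+//       %1 = addf %c, %0 : f32
+//       linalg.yield %1 : f32
+//   }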
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-generalization"
+
+using namespace mlir;
+
+// Creates a linalg.generic op from the given `namedOp`. Returns a null op if
+// the given `namedOp` does not have a region builder.
+static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp,
+ OpBuilder &builder) {
+ auto regionBuilder = namedOp.getRegionBuilder();
+ if (!regionBuilder) {
+ LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
+ return nullptr;
+ }
+
+ SmallVector<AffineMap, 4> indexingMaps = namedOp.getIndexingMaps();
+ auto iterators = llvm::to_vector<4>(
+ namedOp.iterator_types().getAsValueRange<StringAttr>());
+ auto resultTypes = namedOp.getOutputTensorTypes();
+ SmallVector<Type, 4> types(resultTypes.begin(), resultTypes.end());
+
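+ // Build a generic op that reuses the named op's operands, indexing maps, and
+ // iterator types; the body is filled in by the named op's own region builder,
+ // which relies on an active EDSC ScopedContext.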
+ return builder.create<linalg::GenericOp>(
+ namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputBuffers(),
+ namedOp.getInitTensors(), indexingMaps, iterators,
+ [&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
+ edsc::ScopedContext scope(bodyBuilder, loc);
+ regionBuilder(*bodyBuilder.getBlock());
+ });
+}
+
+namespace {
+
+/// Base class for all linalg generalization patterns. A subclass must provide
+/// the following method:
+/// linalg::GenericOp createGenericOp(RootOp, PatternRewriter &)
+/// for creating the generic op.
+// TODO: remove this pattern after migrating all manually-written named ops
+// into auto-generated ones.
+template <typename ConcretePattern, typename RootOp>
+struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
+ LinalgGeneralizationPattern(MLIRContext *context, linalg::LinalgMarker marker,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {}
+
+ LogicalResult matchAndRewrite(RootOp rootOp,
+ PatternRewriter &rewriter) const override {
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp.getOperation());
+ if (!linalgOp)
+ return failure();
+ if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+ return failure();
+
+ auto *pattern = static_cast<const ConcretePattern *>(this);
+ linalg::GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter);
+ if (!genericOp)
+ return failure();
+
+ rewriter.replaceOp(rootOp, genericOp.getResults());
+ marker.replaceLinalgMarker(rewriter, genericOp.getOperation());
+ return success();
+ }
+
+private:
+ linalg::LinalgMarker marker;
+};
+
+struct GeneralizeConvOp
+ : public LinalgGeneralizationPattern<GeneralizeConvOp, linalg::ConvOp> {
+ using LinalgGeneralizationPattern::LinalgGeneralizationPattern;
+
+ linalg::GenericOp createGenericOp(linalg::ConvOp, OpBuilder &builder) const;
+};
+
+/// Catch-all pattern for converting all named ops with a region builder into
+/// linalg.generic.
+struct LinalgNamedOpGeneralizationPattern : RewritePattern {
+ LinalgNamedOpGeneralizationPattern(MLIRContext *context,
+ linalg::LinalgMarker marker,
+ PatternBenefit benefit = 1)
+ : RewritePattern(benefit, MatchAnyOpTypeTag()),
+ marker(std::move(marker)) {}
+
+ LogicalResult matchAndRewrite(Operation *rootOp,
+ PatternRewriter &rewriter) const override {
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp);
+ if (!linalgOp)
+ return failure();
+ if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+ return failure();
+
+ // Nothing to do for linalg.generic and linalg.indexed_generic.
+ if (isa<linalg::GenericOp, linalg::IndexedGenericOp>(rootOp))
+ return failure();
+
+ linalg::GenericOp genericOp =
+ createGenericOpFromNamedOp(linalgOp, rewriter);
+ if (!genericOp)
+ return failure();
+
+ rewriter.replaceOp(rootOp, genericOp.getResults());
+ marker.replaceLinalgMarker(rewriter, genericOp.getOperation());
+ return success();
+ }
+
+private:
+ linalg::LinalgMarker marker;
+};
+
+struct LinalgGeneralizationPass
+ : public LinalgGeneralizationBase<LinalgGeneralizationPass> {
+ void runOnFunction() override;
+};
+
+} // namespace
+
+void LinalgGeneralizationPass::runOnFunction() {
+ FuncOp func = getFunction();
+ OwningRewritePatternList patterns;
+ linalg::populateLinalgConvGeneralizationPatterns(&getContext(), patterns);
+ linalg::populateLinalgNamedOpsGeneralizationPatterns(&getContext(), patterns);
+ applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
+}
+
+linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp,
+ OpBuilder &builder) const {
+ SmallVector<AffineMap, 4> indexingMaps = convOp.getIndexingMaps();
+ auto iterators =
+ llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>());
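+ // linalg.conv is one of the manually-written named ops without a region
+ // builder, so construct the body explicitly as a scalar multiply-accumulate:
+ // output += filter * input.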
+ return builder.create<linalg::GenericOp>(
+ convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(),
+ convOp.getInputBuffers(), convOp.getOutputBuffers(),
+ /*initTensors=*/ValueRange(), indexingMaps, iterators,
+ [](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) {
+ Value mul =
+ bodyBuilder.create<MulFOp>(bodyLoc, bodyArgs[0], bodyArgs[1]);
+ Value add = bodyBuilder.create<AddFOp>(bodyLoc, mul, bodyArgs[2]);
+ bodyBuilder.create<linalg::YieldOp>(bodyLoc, add);
+ });
+}
+
+void mlir::linalg::populateLinalgConvGeneralizationPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns,
+ linalg::LinalgMarker marker) {
+ patterns.insert<GeneralizeConvOp>(context, marker);
+}
+
+void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns,
+ linalg::LinalgMarker marker) {
+ patterns.insert<LinalgNamedOpGeneralizationPattern>(context, marker);
+}
+
+std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
+ return std::make_unique<LinalgGeneralizationPass>();
+}
--- /dev/null
+// RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s
+
+func @generalize_conv(%input : memref<1x225x225x3xf32>, %filter: memref<3x3x3x32xf32>, %output: memref<1x112x112x32xf32>) {
+ linalg.conv(%filter, %input, %output) {dilations = [2, 3], strides = [4, 5]} : memref<3x3x3x32xf32>, memref<1x225x225x3xf32>, memref<1x112x112x32xf32>
+ return
+}
+
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 * 4 + d3 * 2, d2 * 5 + d4 * 3, d5)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)>
+
+// CHECK: func @generalize_conv
+// CHECK-SAME: %[[INPUT:.+]]: memref<1x225x225x3xf32>
+// CHECK-SAME: %[[FILTER:.+]]: memref<3x3x3x32xf32>
+// CHECK-SAME: %[[OUTPUT:.+]]: memref<1x112x112x32xf32>
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[FILTER_MAP]], #[[INPUT_MAP]], #[[OUTPUT_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "window", "window", "reduction", "parallel"]
+// CHECK-SAME: ins(%[[FILTER]], %[[INPUT]]
+// CHECK-SAME: outs(%[[OUTPUT]]
+
+// CHECK: ^{{.*}}(%[[FILTER_ARG:.+]]: f32, %[[INPUT_ARG:.+]]: f32, %[[OUTPUT_ARG:.+]]: f32)
+// CHECK: %[[MUL:.+]] = mulf %[[FILTER_ARG]], %[[INPUT_ARG]]
+// CHECK: %[[ADD:.+]] = addf %[[MUL]], %[[OUTPUT_ARG]]
+// CHECK: linalg.yield %[[ADD]]
+
+// -----
+
+func @generalize_matmul_buffer(%A : memref<16x8xf32>, %B: memref<8x32xf32>, %C: memref<16x32xf32>) {
+ linalg.matmul ins(%A, %B: memref<16x8xf32>, memref<8x32xf32>) outs(%C: memref<16x32xf32>)
+ return
+}
+
+// CHECK: #[[A_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[B_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[C_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: func @generalize_matmul_buffer
+// CHECK-SAME: %[[A:.+]]: memref<16x8xf32>
+// CHECK-SAME: %[[B:.+]]: memref<8x32xf32>
+// CHECK-SAME: %[[C:.+]]: memref<16x32xf32>
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[A_MAP]], #[[B_MAP]], #[[C_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK: %[[MUL:.+]] = mulf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+
+// -----
+
+func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+ %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>) init(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+ return %0: tensor<16x32xf32>
+}
+
+// CHECK: func @generalize_matmul_tensor
+
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<16x8xf32>, tensor<8x32xf32>)
+// CHECK-SAME: init(%{{.+}} : tensor<16x32xf32>)
+
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>