From 46084c03f42f598f766028e63325dffd9d66b3d7 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 23 Aug 2019 17:44:55 -0700 Subject: [PATCH] Lower linalg.copy to LLVM dialect in the presence of transposes. Add an extra RewritePattern that does not convert types to rewrite a CopyOp that has non-identity permutations into a sequence of TransposeOp followed by a CopyOp without such permutations. This RewitePattern is made to fail in the non-permutation case so that the conversion pattern can kick in to lower to LLVM. This is an instance of A->A->B lowering where A->A is done by a RewritePattern in case_1 and A->B is done by a ConversionPatternRewriter when not(case_1). PiperOrigin-RevId: 265171380 --- .../mlir/Dialect/Linalg/IR/LinalgLibraryOps.td | 4 +- .../Linalg/Transforms/LowerToLLVMDialect.cpp | 76 ++++++++++++++++++---- mlir/test/Linalg/llvm.mlir | 30 +++++++++ 3 files changed, 95 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td index 9aba047..cac24ce 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td @@ -173,8 +173,8 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> { enforced at the moment. }]; let arguments = (ins - View, - View, + View:$input, + View:$output, OptionalAttr:$inputPermutation, OptionalAttr:$outputPermutation); // TODO(ntv) this should go away once the usage of OptionalAttr triggers diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp index d914206..0bc355a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -782,19 +782,6 @@ public: if (!f) return matchFailure(); - if (std::is_same::value) { - auto copyOp = cast(op); - - // Ensure permutations are identity. - // TODO(ntv): insert a transpose op that captures the permutations and - // remove this. - auto inputPerm = copyOp.inputPermutation(); - if (inputPerm.hasValue() && !inputPerm->isIdentity()) - return matchFailure(); - auto outputPerm = copyOp.outputPermutation(); - if (outputPerm.hasValue() && !outputPerm->isIdentity()) - return matchFailure(); - } auto fAttr = rewriter.getSymbolRefAttr(f); auto named = rewriter.getNamedAttr("callee", fAttr); rewriter.replaceOpWithNewOp(op, operands, @@ -803,11 +790,74 @@ public: } }; +/// Conversion pattern specialization for CopyOp. This kicks in when both input +/// and output permutations are left unspecified or are the identity. +template <> class LinalgOpConversion : public LLVMOpLowering { +public: + explicit LinalgOpConversion(MLIRContext *context, + LinalgTypeConverter &lowering_) + : LLVMOpLowering(CopyOp::getOperationName(), context, lowering_) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto copyOp = cast(op); + auto inputPerm = copyOp.inputPermutation(); + if (inputPerm.hasValue() && !inputPerm->isIdentity()) + return matchFailure(); + auto outputPerm = copyOp.outputPermutation(); + if (outputPerm.hasValue() && !outputPerm->isIdentity()) + return matchFailure(); + + auto f = getLLVMLibraryCallDeclaration(op, lowering, rewriter); + if (!f) + return matchFailure(); + + auto fAttr = rewriter.getSymbolRefAttr(f); + auto named = rewriter.getNamedAttr("callee", fAttr); + rewriter.replaceOpWithNewOp(op, operands, + ArrayRef{named}); + return matchSuccess(); + } +}; + +/// A non-conversion rewrite pattern kicks in to convert CopyOp with +/// permutations into a sequence of TransposeOp and permutation-free CopyOp. +/// This interplays together with TransposeOpConversion and +/// LinalgConversion to create a path to the LLVM dialect. +class CopyTransposeConversion : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(CopyOp op, + PatternRewriter &rewriter) const override { + Value *in = op.input(), *out = op.output(); + + // If either inputPerm or outputPerm are non-identities, insert transposes. + auto inputPerm = op.inputPermutation(); + if (inputPerm.hasValue() && !inputPerm->isIdentity()) + in = rewriter.create(op.getLoc(), in, + AffineMapAttr::get(*inputPerm)); + auto outputPerm = op.outputPermutation(); + if (outputPerm.hasValue() && !outputPerm->isIdentity()) + out = rewriter.create( + op.getLoc(), out, AffineMapAttr::get(*outputPerm)); + + // If nothing was transposed, fail and let the conversion kick in. + if (in == op.input() && out == op.output()) + return matchFailure(); + + rewriter.replaceOpWithNewOp(op, in, out); + return matchSuccess(); + } +}; + /// Populate the given list with patterns that convert from Linalg to LLVM. static void populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter, OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns.insert(ctx); patterns.insert, LinalgOpConversion, diff --git a/mlir/test/Linalg/llvm.mlir b/mlir/test/Linalg/llvm.mlir index 8570d8d..d82418d 100644 --- a/mlir/test/Linalg/llvm.mlir +++ b/mlir/test/Linalg/llvm.mlir @@ -214,3 +214,33 @@ func @transpose(%arg0: !linalg.view) { // CHECK: llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> // CHECK: llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> // CHECK: llvm.store {{.*}} : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*"> + +func @copy_transpose(%arg0: !linalg.view, %arg1: !linalg.view) { + linalg.copy(%arg0, %arg1) {inputPermutation = (i, j, k) -> (i, k, j), + outputPermutation = (i, j, k) -> (k, j, i)} + : !linalg.view, !linalg.view + return +} +// CHECK-LABEL: func @copy +// Tranpose input +// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.extractvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.extractvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.insertvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.store {{.*}} : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*"> +// Transpose output +// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.extractvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.insertvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.extractvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.store {{.*}} : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*"> +// Call external copy +// CHECK: llvm.call @linalg_copy_viewxxxf32_viewxxxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">) -> () -- 2.7.4