Lower linalg.copy to LLVM dialect in the presence of transposes.
authorNicolas Vasilache <ntv@google.com>
Sat, 24 Aug 2019 00:44:55 +0000 (17:44 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 24 Aug 2019 00:45:19 +0000 (17:45 -0700)
Add an extra RewritePattern that does not convert types to rewrite a CopyOp that has non-identity permutations into a sequence of TransposeOp followed by a CopyOp without such permutations.

This RewitePattern is made to fail in the non-permutation case so that the conversion pattern can kick in to lower to LLVM.

This is an instance of A->A->B lowering where A->A is done by a RewritePattern in case_1 and A->B is done by a ConversionPatternRewriter when not(case_1).

PiperOrigin-RevId: 265171380

mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/test/Linalg/llvm.mlir

index 9aba047..cac24ce 100644 (file)
@@ -173,8 +173,8 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> {
     enforced at the moment.
   }];
   let arguments = (ins
-    View,
-    View,
+    View:$input,
+    View:$output,
     OptionalAttr<AffineMapAttr>:$inputPermutation,
     OptionalAttr<AffineMapAttr>:$outputPermutation);
   // TODO(ntv) this should go away once the usage of OptionalAttr triggers
index d914206..0bc355a 100644 (file)
@@ -782,19 +782,6 @@ public:
     if (!f)
       return matchFailure();
 
-    if (std::is_same<LinalgOp, CopyOp>::value) {
-      auto copyOp = cast<CopyOp>(op);
-
-      // Ensure permutations are identity.
-      // TODO(ntv): insert a transpose op that captures the permutations and
-      // remove this.
-      auto inputPerm = copyOp.inputPermutation();
-      if (inputPerm.hasValue() && !inputPerm->isIdentity())
-        return matchFailure();
-      auto outputPerm = copyOp.outputPermutation();
-      if (outputPerm.hasValue() && !outputPerm->isIdentity())
-        return matchFailure();
-    }
     auto fAttr = rewriter.getSymbolRefAttr(f);
     auto named = rewriter.getNamedAttr("callee", fAttr);
     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operands,
@@ -803,11 +790,74 @@ public:
   }
 };
 
+/// Conversion pattern specialization for CopyOp. This kicks in when both input
+/// and output permutations are left unspecified or are the identity.
+template <> class LinalgOpConversion<CopyOp> : public LLVMOpLowering {
+public:
+  explicit LinalgOpConversion(MLIRContext *context,
+                              LinalgTypeConverter &lowering_)
+      : LLVMOpLowering(CopyOp::getOperationName(), context, lowering_) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto copyOp = cast<CopyOp>(op);
+    auto inputPerm = copyOp.inputPermutation();
+    if (inputPerm.hasValue() && !inputPerm->isIdentity())
+      return matchFailure();
+    auto outputPerm = copyOp.outputPermutation();
+    if (outputPerm.hasValue() && !outputPerm->isIdentity())
+      return matchFailure();
+
+    auto f = getLLVMLibraryCallDeclaration<CopyOp>(op, lowering, rewriter);
+    if (!f)
+      return matchFailure();
+
+    auto fAttr = rewriter.getSymbolRefAttr(f);
+    auto named = rewriter.getNamedAttr("callee", fAttr);
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operands,
+                                              ArrayRef<NamedAttribute>{named});
+    return matchSuccess();
+  }
+};
+
+/// A non-conversion rewrite pattern kicks in to convert CopyOp with
+/// permutations into a sequence of TransposeOp and permutation-free CopyOp.
+/// This interplays together with TransposeOpConversion and
+/// LinalgConversion<CopyOp> to create a path to the LLVM dialect.
+class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
+public:
+  using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(CopyOp op,
+                                     PatternRewriter &rewriter) const override {
+    Value *in = op.input(), *out = op.output();
+
+    // If either inputPerm or outputPerm are non-identities, insert transposes.
+    auto inputPerm = op.inputPermutation();
+    if (inputPerm.hasValue() && !inputPerm->isIdentity())
+      in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in,
+                                                AffineMapAttr::get(*inputPerm));
+    auto outputPerm = op.outputPermutation();
+    if (outputPerm.hasValue() && !outputPerm->isIdentity())
+      out = rewriter.create<linalg::TransposeOp>(
+          op.getLoc(), out, AffineMapAttr::get(*outputPerm));
+
+    // If nothing was transposed, fail and let the conversion kick in.
+    if (in == op.input() && out == op.output())
+      return matchFailure();
+
+    rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
+    return matchSuccess();
+  }
+};
+
 /// Populate the given list with patterns that convert from Linalg to LLVM.
 static void
 populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
                                        OwningRewritePatternList &patterns,
                                        MLIRContext *ctx) {
+  patterns.insert<CopyTransposeConversion>(ctx);
   patterns.insert<BufferAllocOpConversion, BufferDeallocOpConversion,
                   BufferSizeOpConversion, DimOpConversion,
                   LinalgOpConversion<CopyOp>, LinalgOpConversion<DotOp>,
index 8570d8d..d82418d 100644 (file)
@@ -214,3 +214,33 @@ func @transpose(%arg0: !linalg.view<?x?x?xf32>) {
 //       CHECK:   llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
 //       CHECK:    llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
 //       CHECK:   llvm.store {{.*}} : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">
+
+func @copy_transpose(%arg0: !linalg.view<?x?x?xf32>, %arg1: !linalg.view<?x?x?xf32>) {
+  linalg.copy(%arg0, %arg1) {inputPermutation = (i, j, k) -> (i, k, j),
+                             outputPermutation = (i, j, k) -> (k, j, i)}
+    : !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>
+  return
+}
+// CHECK-LABEL: func @copy
+// Tranpose input
+//       CHECK:   llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:   llvm.extractvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:   llvm.extractvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.insertvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:   llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:   llvm.store {{.*}} : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">
+// Transpose output
+//       CHECK:   llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:   llvm.extractvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.insertvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:   llvm.extractvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:   llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:   llvm.store {{.*}} : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">
+// Call external copy
+//       CHECK:   llvm.call @linalg_copy_viewxxxf32_viewxxxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">) -> ()