enforced at the moment.
}];
let arguments = (ins
- View,
- View,
+ View:$input,
+ View:$output,
OptionalAttr<AffineMapAttr>:$inputPermutation,
OptionalAttr<AffineMapAttr>:$outputPermutation);
// TODO(ntv) this should go away once the usage of OptionalAttr triggers
if (!f)
return matchFailure();
- if (std::is_same<LinalgOp, CopyOp>::value) {
- auto copyOp = cast<CopyOp>(op);
-
- // Ensure permutations are identity.
- // TODO(ntv): insert a transpose op that captures the permutations and
- // remove this.
- auto inputPerm = copyOp.inputPermutation();
- if (inputPerm.hasValue() && !inputPerm->isIdentity())
- return matchFailure();
- auto outputPerm = copyOp.outputPermutation();
- if (outputPerm.hasValue() && !outputPerm->isIdentity())
- return matchFailure();
- }
auto fAttr = rewriter.getSymbolRefAttr(f);
auto named = rewriter.getNamedAttr("callee", fAttr);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operands,
}
};
+/// Conversion pattern specialization for CopyOp. This kicks in when both input
+/// and output permutations are left unspecified or are the identity.
+template <> class LinalgOpConversion<CopyOp> : public LLVMOpLowering {
+public:
+ explicit LinalgOpConversion(MLIRContext *context,
+ LinalgTypeConverter &lowering_)
+ : LLVMOpLowering(CopyOp::getOperationName(), context, lowering_) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto copyOp = cast<CopyOp>(op);
+ auto inputPerm = copyOp.inputPermutation();
+ if (inputPerm.hasValue() && !inputPerm->isIdentity())
+ return matchFailure();
+ auto outputPerm = copyOp.outputPermutation();
+ if (outputPerm.hasValue() && !outputPerm->isIdentity())
+ return matchFailure();
+
+ auto f = getLLVMLibraryCallDeclaration<CopyOp>(op, lowering, rewriter);
+ if (!f)
+ return matchFailure();
+
+ auto fAttr = rewriter.getSymbolRefAttr(f);
+ auto named = rewriter.getNamedAttr("callee", fAttr);
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operands,
+ ArrayRef<NamedAttribute>{named});
+ return matchSuccess();
+ }
+};
+
+/// A non-conversion rewrite pattern kicks in to convert CopyOp with
+/// permutations into a sequence of TransposeOp and permutation-free CopyOp.
+/// This interplays together with TransposeOpConversion and
+/// LinalgConversion<CopyOp> to create a path to the LLVM dialect.
+class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
+public:
+ using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(CopyOp op,
+ PatternRewriter &rewriter) const override {
+ Value *in = op.input(), *out = op.output();
+
+ // If either inputPerm or outputPerm are non-identities, insert transposes.
+ auto inputPerm = op.inputPermutation();
+ if (inputPerm.hasValue() && !inputPerm->isIdentity())
+ in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in,
+ AffineMapAttr::get(*inputPerm));
+ auto outputPerm = op.outputPermutation();
+ if (outputPerm.hasValue() && !outputPerm->isIdentity())
+ out = rewriter.create<linalg::TransposeOp>(
+ op.getLoc(), out, AffineMapAttr::get(*outputPerm));
+
+ // If nothing was transposed, fail and let the conversion kick in.
+ if (in == op.input() && out == op.output())
+ return matchFailure();
+
+ rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
+ return matchSuccess();
+ }
+};
+
/// Populate the given list with patterns that convert from Linalg to LLVM.
static void
populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
OwningRewritePatternList &patterns,
MLIRContext *ctx) {
+ patterns.insert<CopyTransposeConversion>(ctx);
patterns.insert<BufferAllocOpConversion, BufferDeallocOpConversion,
BufferSizeOpConversion, DimOpConversion,
LinalgOpConversion<CopyOp>, LinalgOpConversion<DotOp>,
// CHECK: llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.store {{.*}} : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">
+
+func @copy_transpose(%arg0: !linalg.view<?x?x?xf32>, %arg1: !linalg.view<?x?x?xf32>) {
+ linalg.copy(%arg0, %arg1) {inputPermutation = (i, j, k) -> (i, k, j),
+ outputPermutation = (i, j, k) -> (k, j, i)}
+ : !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>
+ return
+}
+// CHECK-LABEL: func @copy
+// Tranpose input
+// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.store {{.*}} : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">
+// Transpose output
+// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.store {{.*}} : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">
+// Call external copy
+// CHECK: llvm.call @linalg_copy_viewxxxf32_viewxxxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">) -> ()