From 28e28e5d651b19a3f2a22c4fe4209a9d3d8c2689 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 23 Aug 2019 17:28:51 -0700 Subject: [PATCH] Lower linalg.transpose to LLVM dialect Add a conversion pattern that transforms a linalg.transpose op into: 1. A function entry `alloca` operation to allocate a ViewDescriptor. 2. A load of the ViewDescriptor from the pointer allocated in 1. 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size and stride. Size and stride are permutations of the original values. 4. A store of the resulting ViewDescriptor to the alloca'ed pointer. The linalg.transpose op is replaced by the alloca'ed pointer. PiperOrigin-RevId: 265169112 --- mlir/include/mlir/IR/Builders.h | 14 ++- .../Linalg/Transforms/LowerToLLVMDialect.cpp | 115 ++++++++++++++++++++- mlir/test/Linalg/llvm.mlir | 16 +++ 3 files changed, 141 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 3697f5d..a58d511 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -232,6 +232,18 @@ public: Block::iterator point; }; + /// RAII guard to reset the insertion point of the builder when destroyed. + class InsertionGuard { + public: + InsertionGuard(OpBuilder &builder) + : builder(builder), ip(builder.saveInsertionPoint()) {} + ~InsertionGuard() { builder.restoreInsertionPoint(ip); } + + private: + OpBuilder &builder; + OpBuilder::InsertPoint ip; + }; + /// Reset the insertion point to no location. Creating an operation without a /// set insertion point is an error, but this can still be useful when the /// current insertion point a builder refers to is being removed. @@ -299,7 +311,7 @@ public: /// Create an operation of specific op type at the current insertion point. template - OpTy create(Location location, Args&&... args) { + OpTy create(Location location, Args &&... args) { OperationState state(location, OpTy::getOperationName()); OpTy::build(this, &state, std::forward(args)...); auto *op = createOperation(state); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp index e4ce0ca..d914206 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -25,6 +25,8 @@ #include "mlir/Dialect/Linalg/Utils/Intrinsics.h" #include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Intrinsics.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" @@ -173,6 +175,48 @@ static ArrayAttr positionAttr(Builder &builder, ArrayRef position) { return builder.getArrayAttr(attrs); } +namespace { +/// Factor out the common information for all view conversions: +/// 1. common types in (standard and LLVM dialects) +/// 2. `pos` method +/// 3. op of the FuncOp alloca'ed value and descriptor. +class BaseViewConversionHelper { +public: + BaseViewConversionHelper(Operation *op, ViewType viewType, + ConversionPatternRewriter &rewriter, + LLVMTypeConverter &lowering) + : indexType(rewriter.getIndexType()), viewType(viewType), + elementTy(getPtrToElementType(viewType, lowering)), + int64Ty( + lowering.convertType(rewriter.getIntegerType(64)).cast()), + viewDescriptorPtrTy( + convertLinalgType(viewType, lowering).cast()), + rewriter(rewriter) { + + OpBuilder::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart( + &op->getParentOfType().getBlocks().front()); + + edsc::ScopedContext context(rewriter, op->getLoc()); + one = constant(int64Ty, IntegerAttr::get(indexType, 1)); + // Alloca with proper alignment. + allocatedDesc = llvm_alloca(viewDescriptorPtrTy, one, /*alignment=*/8); + // Load the alloca'ed descriptor. + desc = llvm_load(allocatedDesc); + } + + ArrayAttr pos(ArrayRef values) const { + return positionAttr(rewriter, values); + }; + + IndexType indexType; + ViewType viewType; + LLVMType elementTy, int64Ty, viewDescriptorPtrTy; + ConversionPatternRewriter &rewriter; + Value *one, *allocatedDesc, *desc; +}; +} // namespace + // BufferAllocOp creates a new `!linalg.buffer` value. class BufferAllocOpConversion : public LLVMOpLowering { public: @@ -222,8 +266,7 @@ public: mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize))); Value *one = nullptr, *align = nullptr; if (allocOp.alignment().hasValue()) { - one = constant(int64Ty, - rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + one = constant(int64Ty, IntegerAttr::get(indexType, 1)); align = constant(int64Ty, rewriter.getIntegerAttr( rewriter.getIndexType(), @@ -530,6 +573,71 @@ class StoreOpConversion : public LoadStoreOpConversion { } }; +/// Conversion pattern that transforms a linalg.transpose op into: +/// 1. A function entry `alloca` operation to allocate a ViewDescriptor. +/// 2. A load of the ViewDescriptor from the pointer allocated in 1. +/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size +/// and stride. Size and stride are permutations of the original values. +/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer. +/// The linalg.transpose op is replaced by the alloca'ed pointer. +class TransposeOpConversion : public LLVMOpLowering { +public: + explicit TransposeOpConversion(MLIRContext *context, + LLVMTypeConverter &lowering_) + : LLVMOpLowering(TransposeOp::getOperationName(), context, lowering_) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + // Initialize the common boilerplate and alloca at the top of the FuncOp. + TransposeOpOperandAdaptor adaptor(operands); + auto tranposeOp = cast(op); + BaseViewConversionHelper helper(op, tranposeOp.getViewType(), rewriter, + lowering); + IndexType indexType = helper.indexType; + ViewType viewType = helper.viewType; + LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty, + viewDescriptorPtrTy = helper.viewDescriptorPtrTy; + Value *allocatedDesc = helper.allocatedDesc, *desc = helper.desc; + + edsc::ScopedContext context(rewriter, op->getLoc()); + // Load the descriptor of the view constructed by the helper. + Value *baseDesc = llvm_load(adaptor.view()); + + // Copy the base pointer from the old descriptor to the new one. + ArrayAttr ptrPos = helper.pos(kPtrPosInView); + desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos); + + // Copy the offset pointer from the old descriptor to the new one. + ArrayAttr offPos = helper.pos(kOffsetPosInView); + desc = insertvalue(desc, extractvalue(int64Ty, baseDesc, offPos), offPos); + + if (tranposeOp.permutation().isIdentity()) { + // No permutation, just store back in alloca'ed region. + llvm_store(desc, allocatedDesc); + return rewriter.replaceOp(op, allocatedDesc), matchSuccess(); + } + + // Iterate over the dimensions and apply size/stride permutation. + for (auto en : llvm::enumerate(tranposeOp.permutation().getResults())) { + int sourcePos = en.index(); + int targetPos = en.value().cast().getPosition(); + Value *size = extractvalue(int64Ty, baseDesc, + helper.pos({kSizePosInView, sourcePos})); + desc = insertvalue(desc, size, helper.pos({kSizePosInView, targetPos})); + Value *stride = extractvalue(int64Ty, baseDesc, + helper.pos({kStridePosInView, sourcePos})); + desc = + insertvalue(desc, stride, helper.pos({kStridePosInView, targetPos})); + } + + // Store back in alloca'ed region. + llvm_store(desc, allocatedDesc); + rewriter.replaceOp(op, allocatedDesc); + return matchSuccess(); + } +}; + /// Conversion pattern that transforms a linalg.view op into: /// 1. A function entry `alloca` operation to allocate a ViewDescriptor. /// 2. A load of the ViewDescriptor from the pointer allocated in 1. @@ -705,7 +813,8 @@ populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter, LinalgOpConversion, LinalgOpConversion, LinalgOpConversion, LinalgOpConversion, LoadOpConversion, RangeOpConversion, SliceOpConversion, - StoreOpConversion, ViewOpConversion>(ctx, converter); + StoreOpConversion, TransposeOpConversion, ViewOpConversion>( + ctx, converter); } namespace { diff --git a/mlir/test/Linalg/llvm.mlir b/mlir/test/Linalg/llvm.mlir index 5246103..8570d8d 100644 --- a/mlir/test/Linalg/llvm.mlir +++ b/mlir/test/Linalg/llvm.mlir @@ -198,3 +198,19 @@ func @copy(%arg0: !linalg.view, %arg1: !linalg.view) { } // CHECK-LABEL: func @copy // CHECK: llvm.call @linalg_copy_viewxxxf32_viewxxxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">) -> () + +func @transpose(%arg0: !linalg.view) { + %0 = linalg.transpose %arg0 (i, j, k) -> (k, i, j) : !linalg.view + return +} +// CHECK-LABEL: func @transpose +// CHECK: llvm.alloca {{.*}} x !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> {alignment = 8 : i64} : (!llvm.i64) -> !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*"> +// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.extractvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.insertvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.extractvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.insertvalue {{.*}}[2, 0] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.extractvalue {{.*}}[2, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.insertvalue {{.*}}[2, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.store {{.*}} : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*"> -- 2.7.4