From: Nicolas Vasilache
Date: Tue, 12 Nov 2019 15:22:51 +0000 (-0800)
Subject: Add LLVM lowering of std.subview
X-Git-Tag: llvmorg-11-init~1466^2~361
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=51de3f688ea99d55dd1ed69706d2e055f231e4fb;p=platform%2Fupstream%2Fllvm.git

Add LLVM lowering of std.subview

A follow-up CL will replace usages of linalg.subview with std.subview.

PiperOrigin-RevId: 279961981
---
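For orientation before the diff itself: the new pattern rewrites a
std.subview op into LLVM-dialect descriptor manipulation. A minimal sketch
follows; the SSA names are hypothetical and the output is abridged, with the
exact op sequence pinned down by the FileCheck test at the end of this patch.

  // Hypothetical input: a rank-2 subview with dynamic offsets, sizes and
  // strides.
  %v = subview %src[%i, %j][%m, %n][%p, %q] :
    memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> to
    memref<?x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)>

  // Abridged output: a fresh descriptor whose fields are filled in one by
  // one (the pointer copies are elided; %i_i64 stands for the converted
  // index operand %i, %srcoff/%stride0 for values extracted from the source
  // descriptor).
  %desc = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  %inc0 = llvm.mul %i_i64, %stride0 : !llvm.i64
  %off0 = llvm.add %srcoff, %inc0 : !llvm.i64
  // ... accumulate dimension 1 into the offset, insert it into field [2],
  // then insert the sizes into [3, i] and step * stride products into [4, i].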
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 84eba82..791a237 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -1412,6 +1412,118 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
   }
 };
 
+/// Conversion pattern that transforms a subview op into:
+/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
+/// 2. Updates to the descriptor to introduce the data ptr, offset, size
+///    and stride.
+/// The subview op is replaced by the descriptor.
+struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
+  using LLVMLegalizationPattern<SubViewOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto viewOp = cast<SubViewOp>(op);
+    SubViewOpOperandAdaptor adaptor(operands);
+    auto sourceMemRefType = viewOp.source()->getType().cast<MemRefType>();
+    auto sourceElementTy =
+        lowering.convertType(sourceMemRefType.getElementType())
+            .dyn_cast_or_null<LLVM::LLVMType>();
+
+    auto viewMemRefType = viewOp.getType();
+    auto targetElementTy =
+        lowering.convertType(viewMemRefType.getElementType())
+            .dyn_cast<LLVM::LLVMType>();
+    auto targetDescTy = lowering.convertType(viewMemRefType)
+                            .dyn_cast_or_null<LLVM::LLVMType>();
+    if (!sourceElementTy || !targetDescTy)
+      return matchFailure();
+
+    // Early exit for the 0-D corner case and for ops whose number of
+    // offset/size/stride operands does not match the source rank.
+    unsigned rank = sourceMemRefType.getRank();
+    if (viewMemRefType.getRank() == 0 || rank != adaptor.offsets().size() ||
+        rank != adaptor.sizes().size() || rank != adaptor.strides().size())
+      return matchFailure();
+
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
+    if (failed(successStrides))
+      return matchFailure();
+
+    // Create the descriptor.
+    Value *desc = rewriter.create<LLVM::UndefOp>(loc, targetDescTy);
+
+    // Copy the buffer pointer from the old descriptor to the new one.
+    Value *sourceDescriptor = adaptor.source();
+    Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
+        loc, sourceElementTy.getPointerTo(), sourceDescriptor,
+        rewriter.getI64ArrayAttr(
+            LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+    Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
+        loc, targetElementTy.getPointerTo(), extracted);
+    desc = rewriter.create<LLVM::InsertValueOp>(
+        loc, desc, bitcastPtr,
+        rewriter.getI64ArrayAttr(
+            LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor));
+    extracted = rewriter.create<LLVM::ExtractValueOp>(
+        loc, sourceElementTy.getPointerTo(), sourceDescriptor,
+        rewriter.getI64ArrayAttr(
+            LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
+        loc, targetElementTy.getPointerTo(), extracted);
+    desc = rewriter.create<LLVM::InsertValueOp>(
+        loc, desc, bitcastPtr,
+        rewriter.getI64ArrayAttr(
+            LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor));
+
+    // Extract the strides needed to compute the offset.
+    SmallVector<Value *, 4> strideValues;
+    strideValues.reserve(viewMemRefType.getRank());
+    for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
+      strideValues.push_back(rewriter.create<LLVM::ExtractValueOp>(
+          loc, getIndexType(), sourceDescriptor,
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kStridePosInMemRefDescriptor, i})));
+    }
+
+    // Offset.
+    Value *baseOffset = rewriter.create<LLVM::ExtractValueOp>(
+        loc, getIndexType(), sourceDescriptor,
+        rewriter.getI64ArrayAttr(
+            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+    for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
+      Value *min = adaptor.offsets()[i];
+      baseOffset = rewriter.create<LLVM::AddOp>(
+          loc, baseOffset,
+          rewriter.create<LLVM::MulOp>(loc, min, strideValues[i]));
+    }
+    desc = rewriter.create<LLVM::InsertValueOp>(
+        loc, desc, baseOffset,
+        rewriter.getI64ArrayAttr(
+            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+
+    // Update sizes and strides.
+    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
+      // Update size.
+      desc = rewriter.create<LLVM::InsertValueOp>(
+          loc, desc, adaptor.sizes()[i],
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kSizePosInMemRefDescriptor, i}));
+      // Update stride.
+      desc = rewriter.create<LLVM::InsertValueOp>(
+          loc, desc,
+          rewriter.create<LLVM::MulOp>(loc, adaptor.strides()[i],
+                                       strideValues[i]),
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
+    }
+
+    rewriter.replaceOp(op, desc);
+    return matchSuccess();
+  }
+};
+
 /// Conversion pattern that transforms a view op into:
 /// 1. An `llvm.mlir.undef` operation to create a memref descriptor
 /// 2. Updates to the descriptor to introduce the data ptr, offset, size
@@ -1647,6 +1759,7 @@ void mlir::populateStdToLLVMConversionPatterns(
       StoreOpLowering,
       SubFOpLowering,
       SubIOpLowering,
+      SubViewOpLowering,
       TruncateIOpLowering,
       ViewOpLowering,
       XOrOpLowering,
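The arithmetic in the pattern above is standard strided-layout composition:
the new offset accumulates offset_i * srcStride_i on top of the source
offset, each new stride is step_i * srcStride_i, and the size operands are
copied verbatim. A worked example with made-up numbers, not taken from this
patch:

  // Source: memref<64x4xf32> with row-major strides [4, 1] and offset 0.
  // Subview operands: offsets [2, 3], sizes [8, 2], steps [1, 1].
  //   new offset  = 0 + 2 * 4 + 3 * 1 = 11
  //   new strides = [1 * 4, 1 * 1] = [4, 1]
  //   new sizes   = [8, 2]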
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index e0669dc2..b34b345 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -708,3 +708,29 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
   return
 }
+
+// CHECK-LABEL: func @subview(
+// CHECK: %[[MEMREFPTR:.*]]: !llvm<{{.*}}>, %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
+func @subview(%0 : memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)>, %arg0 : index, %arg1 : index, %arg2 : index) {
+  // CHECK: %[[MEMREF:.*]] = llvm.load %[[MEMREFPTR]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">
+  // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i64
+  // CHECK: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : !llvm.i64
+  // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i64
+  // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : !llvm.i64
+  // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i64
+  // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i64
+  // CHECK: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %1 = subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] :
+    memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> to
+    memref<?x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)>
+  return
+}
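For reference, the struct type the test matches against is the standard
MemRef descriptor; its field indices correspond to the
LLVMTypeConverter::k*PosInMemRefDescriptor constants used by the pattern
above:

  // !llvm<"{ float*,        // [0] kAllocatedPtrPosInMemRefDescriptor
  //          float*,        // [1] kAlignedPtrPosInMemRefDescriptor
  //          i64,           // [2] kOffsetPosInMemRefDescriptor
  //          [2 x i64],     // [3] kSizePosInMemRefDescriptor
  //          [2 x i64] }">  // [4] kStridePosInMemRefDescriptor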