From ee5c2256ef31fefc92ad59f78b0649b145dc0eb0 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 14 Nov 2019 00:48:41 -0800 Subject: [PATCH] Concentrate memref descriptor manipulation logic in one place Memref descriptor is becoming increasingly complex. Memrefs are manipulated by multiple standard instructions, each of which has a non-trivial lowering to the LLVM dialect. This leads to verbose code that manipulates the descriptors exposing the internals of insert/extractelement opreations. Implement a wrapper class that contains a memref descriptor and provides semantically named methods that build the primitive IR operations instead. PiperOrigin-RevId: 280371225 --- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 339 +++++++++++---------- 1 file changed, 174 insertions(+), 165 deletions(-) diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 791a237..0641a6b 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -235,6 +235,125 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, : ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {} namespace { +/// Helper class to produce LLVM dialect operations extracting or inserting +/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor. +/// The Value may be null, in which case none of the operations are valid. +class MemRefDescriptor { +public: + /// Construct a helper for the given descriptor value. + explicit MemRefDescriptor(Value *descriptor) : value(descriptor) { + if (value) { + structType = value->getType().cast(); + indexType = value->getType().cast().getStructElementType( + LLVMTypeConverter::kOffsetPosInMemRefDescriptor); + } + } + + /// Builds IR creating an `undef` value of the descriptor type. + static MemRefDescriptor undef(OpBuilder &builder, Location loc, + Type descriptorType) { + Value *descriptor = builder.create( + loc, descriptorType.cast()); + return MemRefDescriptor(descriptor); + } + + /// Builds IR extracting the allocated pointer from the descriptor. + Value *allocatedPtr(OpBuilder &builder, Location loc) { + return extractPtr(builder, loc, + LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor); + } + + /// Builds IR inserting the allocated pointer into the descriptor. + void setAllocatedPtr(OpBuilder &builder, Location loc, Value *ptr) { + setPtr(builder, loc, LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor, + ptr); + } + + /// Builds IR extracting the aligned pointer from the descriptor. + Value *alignedPtr(OpBuilder &builder, Location loc) { + return extractPtr(builder, loc, + LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor); + } + + /// Builds IR inserting the aligned pointer into the descriptor. + void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr) { + setPtr(builder, loc, LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor, + ptr); + } + + /// Builds IR extracting the offset from the descriptor. + Value *offset(OpBuilder &builder, Location loc) { + return builder.create( + loc, indexType, value, + builder.getI64ArrayAttr( + LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); + } + + /// Builds IR inserting the offset into the descriptor. + void setOffset(OpBuilder &builder, Location loc, Value *offset) { + value = builder.create( + loc, structType, value, offset, + builder.getI64ArrayAttr( + LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); + } + + /// Builds IR extracting the pos-th size from the descriptor. + Value *size(OpBuilder &builder, Location loc, unsigned pos) { + return builder.create( + loc, indexType, value, + builder.getI64ArrayAttr( + {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos})); + } + + /// Builds IR inserting the pos-th size into the descriptor + void setSize(OpBuilder &builder, Location loc, unsigned pos, Value *size) { + value = builder.create( + loc, structType, value, size, + builder.getI64ArrayAttr( + {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos})); + } + + /// Builds IR extracting the pos-th size from the descriptor. + Value *stride(OpBuilder &builder, Location loc, unsigned pos) { + return builder.create( + loc, indexType, value, + builder.getI64ArrayAttr( + {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos})); + } + + /// Builds IR inserting the pos-th stride into the descriptor + void setStride(OpBuilder &builder, Location loc, unsigned pos, + Value *stride) { + value = builder.create( + loc, structType, value, stride, + builder.getI64ArrayAttr( + {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos})); + } + + /*implicit*/ operator Value *() { return value; } + +private: + Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos) { + Type type = structType.getStructElementType(pos); + return builder.create(loc, type, value, + builder.getI64ArrayAttr(pos)); + } + + void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr) { + value = builder.create(loc, structType, value, ptr, + builder.getI64ArrayAttr(pos)); + } + + // Cached descriptor type. + LLVM::LLVMType structType; + + // Cached index type. + LLVM::LLVMType indexType; + + // Actual descriptor. + Value *value; +}; + // Base class for Standard to LLVM IR op conversions. Matches the Op type // provided as template argument. Carries a reference to the LLVM dialect in // case it is necessary for rewriters. @@ -278,29 +397,6 @@ public: return builder.create(loc, getIndexType(), attr); } - // Extract allocated data pointer value from a value representing a memref. - static Value * - extractAllocatedMemRefElementPtr(ConversionPatternRewriter &builder, - Location loc, Value *memref, - Type elementTypePtr) { - return builder.create( - loc, elementTypePtr, memref, - builder.getI64ArrayAttr( - LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor)); - } - - // Extract properly aligned data pointer value from a value representing a - // memref. - static Value * - extractAlignedMemRefElementPtr(ConversionPatternRewriter &builder, - Location loc, Value *memref, - Type elementTypePtr) { - return builder.create( - loc, elementTypePtr, memref, - builder.getI64ArrayAttr( - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor)); - } - protected: LLVM::LLVMDialect &dialect; }; @@ -786,14 +882,10 @@ struct AllocOpLowering : public LLVMLegalizationPattern { // Create the MemRef descriptor. auto structType = lowering.convertType(type); - Value *memRefDescriptor = - rewriter.create(loc, structType, ArrayRef{}); - + auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); // Field 1: Allocated pointer, used for malloc/free. - memRefDescriptor = rewriter.create( - loc, structType, memRefDescriptor, bitcastAllocated, - rewriter.getI64ArrayAttr( - LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor)); + memRefDescriptor.setAllocatedPtr(rewriter, loc, bitcastAllocated); + // Field 2: Actual aligned pointer to payload. Value *bitcastAligned = bitcastAllocated; if (align) { @@ -808,20 +900,15 @@ struct AllocOpLowering : public LLVMLegalizationPattern { bitcastAligned = rewriter.create( loc, elementPtrType, ArrayRef(aligned)); } - memRefDescriptor = rewriter.create( - loc, structType, memRefDescriptor, bitcastAligned, - rewriter.getI64ArrayAttr( - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor)); + memRefDescriptor.setAlignedPtr(rewriter, loc, bitcastAligned); + // Field 3: Offset in aligned pointer. - memRefDescriptor = rewriter.create( - loc, structType, memRefDescriptor, - createIndexConstant(rewriter, loc, offset), - rewriter.getI64ArrayAttr( - LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); + memRefDescriptor.setOffset(rewriter, loc, + createIndexConstant(rewriter, loc, offset)); if (type.getRank() == 0) // No size/stride descriptor in memref, return the descriptor value. - return rewriter.replaceOp(op, memRefDescriptor); + return rewriter.replaceOp(op, {memRefDescriptor}); // Fields 4 and 5: Sizes and strides of the strided MemRef. // Store all sizes in the descriptor. Only dynamic sizes are passed in as @@ -846,18 +933,12 @@ struct AllocOpLowering : public LLVMLegalizationPattern { // Fill size and stride descriptors in memref. for (auto indexedSize : llvm::enumerate(sizes)) { int64_t index = indexedSize.index(); - memRefDescriptor = rewriter.create( - loc, structType, memRefDescriptor, indexedSize.value(), - rewriter.getI64ArrayAttr( - {LLVMTypeConverter::kSizePosInMemRefDescriptor, index})); - memRefDescriptor = rewriter.create( - loc, structType, memRefDescriptor, strideValues[index], - rewriter.getI64ArrayAttr( - {LLVMTypeConverter::kStridePosInMemRefDescriptor, index})); + memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value()); + memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]); } // Return the final value of the descriptor. - rewriter.replaceOp(op, memRefDescriptor); + rewriter.replaceOp(op, {memRefDescriptor}); } }; @@ -947,13 +1028,10 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { /*isVarArg=*/false)); } - auto type = transformed.memref()->getType().cast(); - Type elementPtrType = type.getStructElementType( - LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor); - Value *bufferPtr = extractAllocatedMemRefElementPtr( - rewriter, op->getLoc(), transformed.memref(), elementPtrType); + MemRefDescriptor memref(transformed.memref()); Value *casted = rewriter.create( - op->getLoc(), getVoidPtrType(), bufferPtr); + op->getLoc(), getVoidPtrType(), + memref.allocatedPtr(rewriter, op->getLoc())); rewriter.replaceOpWithNewOp( op, ArrayRef(), rewriter.getSymbolRefAttr(freeFunc), casted); return matchSuccess(); @@ -1003,10 +1081,8 @@ struct DimOpLowering : public LLVMLegalizationPattern { int64_t index = dimOp.getIndex(); // Extract dynamic size from the memref descriptor. if (ShapedType::isDynamic(shape[index])) - rewriter.replaceOpWithNewOp( - op, getIndexType(), transformed.memrefOrTensor(), - rewriter.getI64ArrayAttr( - {LLVMTypeConverter::kSizePosInMemRefDescriptor, index})); + rewriter.replaceOp(op, {MemRefDescriptor(transformed.memrefOrTensor()) + .size(rewriter, op->getLoc(), index)}); else // Use constant for static size. rewriter.replaceOp( @@ -1058,34 +1134,21 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { // This is a strided getElementPtr variant that linearizes subscripts as: // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. Value *getStridedElementPtr(Location loc, Type elementTypePtr, - Value *memRefDescriptor, - ArrayRef indices, + Value *descriptor, ArrayRef indices, ArrayRef strides, int64_t offset, ConversionPatternRewriter &rewriter) const { - auto indexTy = this->getIndexType(); - Value *base = this->extractAlignedMemRefElementPtr( - rewriter, loc, memRefDescriptor, elementTypePtr); - Value *offsetValue = - offset == MemRefType::getDynamicStrideOrOffset() - ? rewriter.create( - loc, indexTy, memRefDescriptor, - rewriter.getI64ArrayAttr( - LLVMTypeConverter::kOffsetPosInMemRefDescriptor)) - : this->createIndexConstant(rewriter, loc, offset); + MemRefDescriptor memRefDescriptor(descriptor); + + Value *base = memRefDescriptor.alignedPtr(rewriter, loc); + Value *offsetValue = offset == MemRefType::getDynamicStrideOrOffset() + ? memRefDescriptor.offset(rewriter, loc) + : this->createIndexConstant(rewriter, loc, offset); + for (int i = 0, e = indices.size(); i < e; ++i) { - Value *stride; - if (strides[i] != MemRefType::getDynamicStrideOrOffset()) { - // Use static stride. - auto attr = - rewriter.getIntegerAttr(rewriter.getIndexType(), strides[i]); - stride = rewriter.create(loc, indexTy, attr); - } else { - // Use dynamic stride. - stride = rewriter.create( - loc, indexTy, memRefDescriptor, - rewriter.getI64ArrayAttr( - {LLVMTypeConverter::kStridePosInMemRefDescriptor, i})); - } + Value *stride = + strides[i] == MemRefType::getDynamicStrideOrOffset() + ? memRefDescriptor.stride(rewriter, loc, i) + : this->createIndexConstant(rewriter, loc, strides[i]); Value *additionalOffset = rewriter.create(loc, indices[i], stride); offsetValue = @@ -1452,74 +1515,45 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { return matchFailure(); // Create the descriptor. - Value *desc = rewriter.create(loc, targetDescTy); + MemRefDescriptor sourceMemRef(adaptor.source()); + auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); // Copy the buffer pointer from the old descriptor to the new one. - Value *sourceDescriptor = adaptor.source(); - Value *extracted = rewriter.create( - loc, sourceElementTy.getPointerTo(), sourceDescriptor, - rewriter.getI64ArrayAttr( - LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor)); + Value *extracted = sourceMemRef.allocatedPtr(rewriter, loc); Value *bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(), extracted); - desc = rewriter.create( - loc, desc, bitcastPtr, - rewriter.getI64ArrayAttr( - LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor)); - extracted = rewriter.create( - loc, sourceElementTy.getPointerTo(), sourceDescriptor, - rewriter.getI64ArrayAttr( - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor)); + targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); + + extracted = sourceMemRef.alignedPtr(rewriter, loc); bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(), extracted); - desc = rewriter.create( - loc, desc, bitcastPtr, - rewriter.getI64ArrayAttr( - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor)); + targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); // Extract strides needed to compute offset. SmallVector strideValues; strideValues.reserve(viewMemRefType.getRank()); - for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) { - strideValues.push_back(rewriter.create( - loc, getIndexType(), sourceDescriptor, - rewriter.getI64ArrayAttr( - {LLVMTypeConverter::kStridePosInMemRefDescriptor, i}))); - } + for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) + strideValues.push_back(sourceMemRef.stride(rewriter, loc, i)); // Offset. - Value *baseOffset = rewriter.create( - loc, getIndexType(), sourceDescriptor, - rewriter.getI64ArrayAttr( - LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); + Value *baseOffset = sourceMemRef.offset(rewriter, loc); for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) { Value *min = adaptor.offsets()[i]; baseOffset = rewriter.create( loc, baseOffset, rewriter.create(loc, min, strideValues[i])); } - desc = rewriter.create( - loc, desc, baseOffset, - rewriter.getI64ArrayAttr( - LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); + targetMemRef.setOffset(rewriter, loc, baseOffset); // Update sizes and strides. for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { - // Update size. - desc = rewriter.create( - loc, desc, adaptor.sizes()[i], - rewriter.getI64ArrayAttr( - {LLVMTypeConverter::kSizePosInMemRefDescriptor, i})); - // Update stride. - desc = rewriter.create( - loc, desc, - rewriter.create(loc, adaptor.strides()[i], - strideValues[i]), - rewriter.getI64ArrayAttr( - {LLVMTypeConverter::kStridePosInMemRefDescriptor, i})); + targetMemRef.setSize(rewriter, loc, i, adaptor.sizes()[i]); + targetMemRef.setStride(rewriter, loc, i, + rewriter.create( + loc, adaptor.strides()[i], strideValues[i])); } - rewriter.replaceOp(op, desc); + rewriter.replaceOp(op, {targetMemRef}); return matchSuccess(); } }; @@ -1571,10 +1605,6 @@ struct ViewOpLowering : public LLVMLegalizationPattern { auto loc = op->getLoc(); auto viewOp = cast(op); ViewOpOperandAdaptor adaptor(operands); - auto sourceMemRefType = viewOp.source()->getType().cast(); - auto sourceElementTy = - lowering.convertType(sourceMemRefType.getElementType()) - .dyn_cast(); auto viewMemRefType = viewOp.getType(); auto targetElementTy = lowering.convertType(viewMemRefType.getElementType()) @@ -1593,32 +1623,20 @@ struct ViewOpLowering : public LLVMLegalizationPattern { matchFailure(); // Create the descriptor. - Value *desc = rewriter.create(loc, targetDescTy); + MemRefDescriptor sourceMemRef(adaptor.source()); + auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); // Field 1: Copy the allocated pointer, used for malloc/free. - Value *sourceDescriptor = adaptor.source(); - Value *extracted = rewriter.create( - loc, sourceElementTy.getPointerTo(), sourceDescriptor, - rewriter.getI64ArrayAttr( - LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor)); + Value *extracted = sourceMemRef.allocatedPtr(rewriter, loc); Value *bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(), extracted); - desc = rewriter.create( - loc, desc, bitcastPtr, - rewriter.getI64ArrayAttr( - LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor)); + targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); // Field 2: Copy the actual aligned pointer to payload. - extracted = rewriter.create( - loc, sourceElementTy.getPointerTo(), sourceDescriptor, - rewriter.getI64ArrayAttr( - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor)); + extracted = sourceMemRef.alignedPtr(rewriter, loc); bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(), extracted); - desc = rewriter.create( - loc, desc, bitcastPtr, - rewriter.getI64ArrayAttr( - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor)); + targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); // Field 3: Copy the offset in aligned pointer. unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes()); @@ -1630,14 +1648,11 @@ struct ViewOpLowering : public LLVMLegalizationPattern { ? createIndexConstant(rewriter, loc, offset) // TODO(ntv): better adaptor. : sizeAndOffsetOperands.back(); - desc = rewriter.create( - loc, desc, baseOffset, - rewriter.getI64ArrayAttr( - LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); + targetMemRef.setOffset(rewriter, loc, baseOffset); // Early exit for 0-D corner case. if (viewMemRefType.getRank() == 0) - return rewriter.replaceOp(op, desc), matchSuccess(); + return rewriter.replaceOp(op, {targetMemRef}), matchSuccess(); // Fields 4 and 5: Update sizes and strides. if (strides.back() != 1) @@ -1648,20 +1663,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern { // Update size. Value *size = getSize(rewriter, loc, viewMemRefType.getShape(), sizeAndOffsetOperands, i); - desc = rewriter.create( - loc, desc, size, - rewriter.getI64ArrayAttr( - {LLVMTypeConverter::kSizePosInMemRefDescriptor, i})); + targetMemRef.setSize(rewriter, loc, i, size); // Update stride. stride = getStride(rewriter, loc, strides, nextSize, stride, i); - desc = rewriter.create( - loc, desc, stride, - rewriter.getI64ArrayAttr( - {LLVMTypeConverter::kStridePosInMemRefDescriptor, i})); + targetMemRef.setStride(rewriter, loc, i, stride); nextSize = size; } - rewriter.replaceOp(op, desc); + rewriter.replaceOp(op, {targetMemRef}); return matchSuccess(); } }; -- 2.7.4