ArithmeticOp<mnemonic, traits>,
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;
+// Base class for memref allocating ops: alloca and alloc.
+//
+// %0 = alloclike(%m)[%s] : memref<8x?xf32,
+//          affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>>
+//
+class AllocLikeOp<string mnemonic, list<OpTrait> traits = []> :
+ Std_Op<mnemonic, traits> {
+
+ let arguments = (ins Variadic<Index>:$value,
+ Confined<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$alignment);
+ let results = (outs AnyMemRef);
+
+ let builders = [OpBuilder<
+ "Builder *builder, OperationState &result, MemRefType memrefType", [{
+ result.types.push_back(memrefType);
+ }]>,
+ OpBuilder<
+ "Builder *builder, OperationState &result, MemRefType memrefType, " #
+ "ValueRange operands, IntegerAttr alignment = IntegerAttr()", [{
+ result.addOperands(operands);
+ result.types.push_back(memrefType);
+ if (alignment)
+ result.addAttribute(getAlignmentAttrName(), alignment);
+ }]>];
+
+ let extraClassDeclaration = [{
+ static StringRef getAlignmentAttrName() { return "alignment"; }
+
+ MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+
+ /// Returns the number of symbolic operands (the ones in square brackets),
+ /// which bind to the symbols of the memref's layout map.
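+ /// For example, %0 = alloc(%d)[%s] : memref<8x?xf32, ...> has two operands
+ /// in total: one dynamic-size operand (%d) and one symbolic operand (%s),
+ /// so this returns 1.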
+ unsigned getNumSymbolicOperands() {
+ return getNumOperands() - getType().getNumDynamicDims();
+ }
+
+ /// Returns the symbolic operands (the ones in square brackets), which bind
+ /// to the symbols of the memref's layout map.
+ operand_range getSymbolicOperands() {
+ return {operand_begin() + getType().getNumDynamicDims(), operand_end()};
+ }
+
+ /// Returns the dynamic sizes for this alloc operation if specified.
+ operand_range getDynamicSizes() { return getOperands(); }
+ }];
+
+ let parser = [{ return ::parseAllocLikeOp(parser, result); }];
+
+ let hasCanonicalizer = 1;
+}
+
//===----------------------------------------------------------------------===//
// AbsFOp
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//
-def AllocOp : Std_Op<"alloc"> {
+def AllocOp : AllocLikeOp<"alloc"> {
let summary = "memory allocation operation";
let description = [{
The `alloc` operation allocates a region of memory, as specified by its
memref type.
Example:
```mlir
- %0 = alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1>
+ %0 = alloc() : memref<8x64xf32, 1>
```
The optional list of dimension operands are bound to the dynamic dimensions
specified in its memref type. In the example below, the SSA value '%d' is
bound to the second dimension of the memref (which is dynamic).
```mlir
- %0 = alloc(%d) : memref<8x?xf32, (d0, d1) -> (d0, d1), 1>
+ %0 = alloc(%d) : memref<8x?xf32, 1>
```
The optional list of symbol operands are bound to the symbols of the
memref's affine map. In the example below, the SSA value '%s' is bound to
the symbol 's0' in the affine map specified in the alloc's memref type.
```mlir
- %0 = alloc()[%s] : memref<8x64xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1>
+ %0 = alloc()[%s] : memref<8x64xf32,
+ affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1>
```
This operation returns a single SSA value of memref type, which can be used
by subsequent load and store operations. An optional alignment attribute, if
specified, guarantees alignment at least to that boundary:
```mlir
%0 = alloc()[%s] {alignment = 8} :
- memref<8x64xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1>
+ memref<8x64xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1>
```
}];
+}
- let arguments = (ins Variadic<Index>:$value,
- Confined<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$alignment);
- let results = (outs AnyMemRef);
+//===----------------------------------------------------------------------===//
+// AllocaOp
+//===----------------------------------------------------------------------===//
- let builders = [OpBuilder<
- "Builder *builder, OperationState &result, MemRefType memrefType", [{
- result.types.push_back(memrefType);
- }]>,
- OpBuilder<
- "Builder *builder, OperationState &result, MemRefType memrefType, " #
- "ArrayRef<Value> operands, IntegerAttr alignment = IntegerAttr()", [{
- result.addOperands(operands);
- result.types.push_back(memrefType);
- if (alignment)
- result.addAttribute(getAlignmentAttrName(), alignment);
- }]>];
+def AllocaOp : AllocLikeOp<"alloca"> {
+ let summary = "stack memory allocation operation";
+ let description = [{
+ The `alloca` operation allocates memory on the stack, to be automatically
+ released when the stack frame is discarded. The amount of memory allocated
+ is specified by its memref and additional operands. For example:
- let extraClassDeclaration = [{
- static StringRef getAlignmentAttrName() { return "alignment"; }
+ ```mlir
+ %0 = alloca() : memref<8x64xf32>
+ ```
- MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+ The optional list of dimension operands are bound to the dynamic dimensions
+ specified in its memref type. In the example below, the SSA value '%d' is
+ bound to the second dimension of the memref (which is dynamic).
- /// Returns the number of symbolic operands (the ones in square brackets),
- /// which bind to the symbols of the memref's layout map.
- unsigned getNumSymbolicOperands() {
- return getNumOperands() - getType().getNumDynamicDims();
- }
+ ```mlir
+ %0 = alloca(%d) : memref<8x?xf32>
+ ```
- /// Returns the symbolic operands (the ones in square brackets), which bind
- /// to the symbols of the memref's layout map.
- operand_range getSymbolicOperands() {
- return {operand_begin() + getType().getNumDynamicDims(), operand_end()};
- }
+ The optional list of symbol operands are bound to the symbols of the
+ memref's affine map. In the example below, the SSA value '%s' is bound to
+ the symbol 's0' in the affine map specified in the alloca's memref type.
- /// Returns the dynamic sizes for this alloc operation if specified.
- operand_range getDynamicSizes() { return getOperands(); }
- }];
+ ```mlir
+ %0 = alloca()[%s] : memref<8x64xf32,
+ affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>>
+ ```
- let hasCanonicalizer = 1;
+ This operation returns a single SSA value of memref type, which can be used
+ by subsequent load and store operations. An optional alignment attribute, if
+ specified, guarantees alignment at least to that boundary. If not specified,
+ an alignment on any convenient boundary compatible with the type will be
+ chosen.
+ }];
}
//===----------------------------------------------------------------------===//
//
//===----------------------------------------------------------------------===//
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/ADT/TypeSwitch.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
[](AffineMap map) { return map.isIdentity(); });
}
-// An `alloc` is converted into a definition of a memref descriptor value and
-// a call to `malloc` to allocate the underlying data buffer. The memref
-// descriptor is of the LLVM structure type where:
-// 1. the first element is a pointer to the allocated (typed) data buffer,
-// 2. the second element is a pointer to the (typed) payload, aligned to the
-// specified alignment,
-// 3. the remaining elements serve to store all the sizes and strides of the
-// memref using LLVM-converted `index` type.
-//
-// Alignment is obtained by allocating `alignment - 1` more bytes than requested
-// and shifting the aligned pointer relative to the allocated memory. If
-// alignment is unspecified, the two pointers are equal.
-struct AllocOpLowering : public ConvertOpToLLVMPattern<AllocOp> {
- using ConvertOpToLLVMPattern<AllocOp>::ConvertOpToLLVMPattern;
+/// Lowering for AllocOp and AllocaOp.
+template <typename AllocLikeOp>
+struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
+ using ConvertOpToLLVMPattern<AllocLikeOp>::ConvertOpToLLVMPattern;
+ using Base = AllocLikeOpLowering<AllocLikeOp>;
+ using ConvertOpToLLVMPattern<AllocLikeOp>::createIndexConstant;
+ using ConvertOpToLLVMPattern<AllocLikeOp>::getIndexType;
+ using ConvertOpToLLVMPattern<AllocLikeOp>::typeConverter;
+ using ConvertOpToLLVMPattern<AllocLikeOp>::getVoidPtrType;
- explicit AllocOpLowering(LLVMTypeConverter &converter, bool useAlloca = false)
- : ConvertOpToLLVMPattern<AllocOp>(converter), useAlloca(useAlloca) {}
+ explicit AllocLikeOpLowering(LLVMTypeConverter &converter)
+ : ConvertOpToLLVMPattern<AllocLikeOp>(converter) {}
LogicalResult match(Operation *op) const override {
- MemRefType type = cast<AllocOp>(op).getType();
- if (isSupportedMemRefType(type))
+ MemRefType memRefType = cast<AllocLikeOp>(op).getType();
+ if (isSupportedMemRefType(memRefType))
return success();
int64_t offset;
SmallVector<int64_t, 4> strides;
- auto successStrides = getStridesAndOffset(type, strides, offset);
+ auto successStrides = getStridesAndOffset(memRefType, strides, offset);
if (failed(successStrides))
return failure();
return success();
}
- void rewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
- auto allocOp = cast<AllocOp>(op);
- MemRefType type = allocOp.getType();
-
- // Get actual sizes of the memref as values: static sizes are constant
- // values and dynamic sizes are passed to 'alloc' as operands. In case of
- // zero-dimensional memref, assume a scalar (size 1).
- SmallVector<Value, 4> sizes;
- sizes.reserve(type.getRank());
- unsigned i = 0;
- for (int64_t s : type.getShape())
- sizes.push_back(s == -1 ? operands[i++]
- : createIndexConstant(rewriter, loc, s));
- if (sizes.empty())
- sizes.push_back(createIndexConstant(rewriter, loc, 1));
-
- // Compute the total number of memref elements.
- Value cumulativeSize = sizes.front();
- for (unsigned i = 1, e = sizes.size(); i < e; ++i)
- cumulativeSize = rewriter.create<LLVM::MulOp>(
- loc, getIndexType(), ArrayRef<Value>{cumulativeSize, sizes[i]});
-
- // Compute the size of an individual element. This emits the MLIR equivalent
- // of the following sizeof(...) implementation in LLVM IR:
- // %0 = getelementptr %elementType* null, %indexType 1
- // %1 = ptrtoint %elementType* %0 to %indexType
- // which is a common pattern of getting the size of a type in bytes.
- auto elementType = type.getElementType();
- auto convertedPtrType = typeConverter.convertType(elementType)
- .cast<LLVM::LLVMType>()
- .getPointerTo();
- auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
- auto one = createIndexConstant(rewriter, loc, 1);
- auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType,
- ArrayRef<Value>{nullPtr, one});
- auto elementSize =
- rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
- cumulativeSize = rewriter.create<LLVM::MulOp>(
- loc, getIndexType(), ArrayRef<Value>{cumulativeSize, elementSize});
-
- // Allocate the underlying buffer and store a pointer to it in the MemRef
- // descriptor.
- Value allocated = nullptr;
- int alignment = 0;
- Value alignmentValue = nullptr;
- if (auto alignAttr = allocOp.alignment())
- alignment = alignAttr.getValue().getSExtValue();
-
- if (useAlloca) {
- allocated = rewriter.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
- cumulativeSize, alignment);
- } else {
- // Insert the `malloc` declaration if it is not already present.
- auto module = op->getParentOfType<ModuleOp>();
- auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
- if (!mallocFunc) {
- OpBuilder moduleBuilder(
- op->getParentOfType<ModuleOp>().getBodyRegion());
- mallocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
- rewriter.getUnknownLoc(), "malloc",
- LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(),
- /*isVarArg=*/false));
- }
- if (alignment != 0) {
- alignmentValue = createIndexConstant(rewriter, loc, alignment);
- cumulativeSize = rewriter.create<LLVM::SubOp>(
- loc,
- rewriter.create<LLVM::AddOp>(loc, cumulativeSize, alignmentValue),
- one);
- }
- allocated = rewriter
- .create<LLVM::CallOp>(
- loc, getVoidPtrType(),
- rewriter.getSymbolRefAttr(mallocFunc), cumulativeSize)
- .getResult(0);
- }
-
- auto structElementType = typeConverter.convertType(elementType);
- auto elementPtrType = structElementType.cast<LLVM::LLVMType>().getPointerTo(
- type.getMemorySpace());
- Value bitcastAllocated = rewriter.create<LLVM::BitcastOp>(
- loc, elementPtrType, ArrayRef<Value>(allocated));
-
- int64_t offset;
- SmallVector<int64_t, 4> strides;
- auto successStrides = getStridesAndOffset(type, strides, offset);
- assert(succeeded(successStrides) && "unexpected non-strided memref");
- (void)successStrides;
- assert(offset != MemRefType::getDynamicStrideOrOffset() &&
- "unexpected dynamic offset");
-
- // 0-D memref corner case: they have size 1 ...
- assert(((type.getRank() == 0 && strides.empty() && sizes.size() == 1) ||
- (strides.size() == sizes.size())) &&
- "unexpected number of strides");
-
- // Create the MemRef descriptor.
- auto structType = typeConverter.convertType(type);
+ /// Creates and populates the memref descriptor struct given all its fields.
+ /// This method also performs any post-allocation alignment needed for heap
+ /// allocations when `accessAlignment` is non-null. This is used with
+ /// allocators that do not support alignment.
+ MemRefDescriptor createMemRefDescriptor(
+ Location loc, ConversionPatternRewriter &rewriter, MemRefType memRefType,
+ Value allocatedTypePtr, Value allocatedBytePtr, Value accessAlignment,
+ uint64_t offset, ArrayRef<int64_t> strides, ArrayRef<Value> sizes) const {
+ auto elementPtrType = getElementPtrType(memRefType);
+ auto structType = typeConverter.convertType(memRefType);
auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);
+
// Field 1: Allocated pointer, used for malloc/free.
- memRefDescriptor.setAllocatedPtr(rewriter, loc, bitcastAllocated);
+ memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedTypePtr);
// Field 2: Actual aligned pointer to payload.
- Value bitcastAligned = bitcastAllocated;
- if (!useAlloca && alignment != 0) {
- assert(alignmentValue);
+ Value alignedBytePtr = allocatedTypePtr;
+ if (accessAlignment) {
// offset = (align - (ptr % align)) % align
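+      // For example, with align = 8 and ptr = 19: 19 % 8 = 3, 8 - 3 = 5,
+      // 5 % 8 = 5, so the payload pointer is advanced by 5 bytes to 24,
+      // which is 8-byte aligned. If ptr is already aligned, the offset is 0.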
Value intVal = rewriter.create<LLVM::PtrToIntOp>(
- loc, this->getIndexType(), allocated);
+ loc, this->getIndexType(), allocatedBytePtr);
Value ptrModAlign =
- rewriter.create<LLVM::URemOp>(loc, intVal, alignmentValue);
+ rewriter.create<LLVM::URemOp>(loc, intVal, accessAlignment);
Value subbed =
- rewriter.create<LLVM::SubOp>(loc, alignmentValue, ptrModAlign);
- Value offset = rewriter.create<LLVM::URemOp>(loc, subbed, alignmentValue);
- Value aligned = rewriter.create<LLVM::GEPOp>(loc, allocated.getType(),
- allocated, offset);
- bitcastAligned = rewriter.create<LLVM::BitcastOp>(
+ rewriter.create<LLVM::SubOp>(loc, accessAlignment, ptrModAlign);
+ Value offset =
+ rewriter.create<LLVM::URemOp>(loc, subbed, accessAlignment);
+ Value aligned = rewriter.create<LLVM::GEPOp>(
+ loc, allocatedBytePtr.getType(), allocatedBytePtr, offset);
+ alignedBytePtr = rewriter.create<LLVM::BitcastOp>(
loc, elementPtrType, ArrayRef<Value>(aligned));
}
- memRefDescriptor.setAlignedPtr(rewriter, loc, bitcastAligned);
+ memRefDescriptor.setAlignedPtr(rewriter, loc, alignedBytePtr);
// Field 3: Offset in aligned pointer.
memRefDescriptor.setOffset(rewriter, loc,
createIndexConstant(rewriter, loc, offset));
- if (type.getRank() == 0)
+ if (memRefType.getRank() == 0)
// No size/stride descriptor in memref, return the descriptor value.
- return rewriter.replaceOp(op, {memRefDescriptor});
+ return memRefDescriptor;
- // Fields 4 and 5: Sizes and strides of the strided MemRef.
+ // Fields 4 and 5: sizes and strides of the strided MemRef.
// Store all sizes in the descriptor. Only dynamic sizes are passed in as
// operands to AllocOp.
Value runningStride = nullptr;
memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value());
memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]);
}
+ return memRefDescriptor;
+ }
+
+ /// Determines sizes to be used in the memref descriptor.
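+ /// For example, for memref<8x?xf32> with a dynamic size operand of 42,
+ /// `sizes` becomes {8, 42} and `cumulativeSize` evaluates to
+ /// 8 * 42 * sizeof(f32) = 1344 bytes.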
+ void getSizes(Location loc, MemRefType memRefType, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter,
+ SmallVectorImpl<Value> &sizes, Value &cumulativeSize,
+ Value &one) const {
+ sizes.reserve(memRefType.getRank());
+ unsigned i = 0;
+ for (int64_t s : memRefType.getShape())
+ sizes.push_back(s == -1 ? operands[i++]
+ : createIndexConstant(rewriter, loc, s));
+ if (sizes.empty())
+ sizes.push_back(createIndexConstant(rewriter, loc, 1));
+
+ // Compute the total number of memref elements.
+ cumulativeSize = sizes.front();
+ for (unsigned i = 1, e = sizes.size(); i < e; ++i)
+ cumulativeSize = rewriter.create<LLVM::MulOp>(
+ loc, getIndexType(), ArrayRef<Value>{cumulativeSize, sizes[i]});
+
+ // Compute the size of an individual element. This emits the MLIR equivalent
+ // of the following sizeof(...) implementation in LLVM IR:
+ // %0 = getelementptr %elementType* null, %indexType 1
+ // %1 = ptrtoint %elementType* %0 to %indexType
+ // which is a common pattern of getting the size of a type in bytes.
+ auto elementType = memRefType.getElementType();
+ auto convertedPtrType = typeConverter.convertType(elementType)
+ .template cast<LLVM::LLVMType>()
+ .getPointerTo();
+ auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
+ one = createIndexConstant(rewriter, loc, 1);
+ auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType,
+ ArrayRef<Value>{nullPtr, one});
+ auto elementSize =
+ rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
+ cumulativeSize = rewriter.create<LLVM::MulOp>(
+ loc, getIndexType(), ArrayRef<Value>{cumulativeSize, elementSize});
+ }
+
+ /// Returns the type of a pointer to an element of the memref.
+ Type getElementPtrType(MemRefType memRefType) const {
+ auto elementType = memRefType.getElementType();
+ auto structElementType = typeConverter.convertType(elementType);
+ return structElementType.template cast<LLVM::LLVMType>().getPointerTo(
+ memRefType.getMemorySpace());
+ }
+
+ /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
+ /// is set to null for stack allocations. `accessAlignment` is set if
+ /// alignment is needed post-allocation (e.g., in conjunction with malloc).
+ /// TODO(bondhugula): next revision will support std lib func aligned_alloc.
+ Value allocateBuffer(Location loc, Value cumulativeSize, Operation *op,
+ MemRefType memRefType, Value one, Value &accessAlignment,
+ Value &allocatedBytePtr,
+ ConversionPatternRewriter &rewriter) const {
+ auto elementPtrType = getElementPtrType(memRefType);
+
+    // Requested allocation alignment, if any.
+ Optional<APInt> allocationAlignment = cast<AllocLikeOp>(op).alignment();
+
+ // With alloca, one gets a pointer to the element type right away.
+ bool onStack = isa<AllocaOp>(op);
+ if (onStack) {
+ allocatedBytePtr = nullptr;
+ accessAlignment = nullptr;
+ return rewriter.create<LLVM::AllocaOp>(
+ loc, elementPtrType, cumulativeSize,
+ allocationAlignment ? allocationAlignment.getValue().getSExtValue()
+ : 0);
+ }
+
+ // Use malloc. Insert the malloc declaration if it is not already present.
+ auto allocFuncName = "malloc";
+ AllocOp allocOp = cast<AllocOp>(op);
+ auto module = allocOp.getParentOfType<ModuleOp>();
+ auto allocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(allocFuncName);
+ if (!allocFunc) {
+ OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
+ SmallVector<LLVM::LLVMType, 2> callArgTypes = {getIndexType()};
+ allocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
+ rewriter.getUnknownLoc(), allocFuncName,
+ LLVM::LLVMType::getFunctionTy(getVoidPtrType(), callArgTypes,
+ /*isVarArg=*/false));
+ }
+
+ // Allocate the underlying buffer and store a pointer to it in the MemRef
+ // descriptor.
+ SmallVector<Value, 2> callArgs;
+ // Adjust the allocation size to consider alignment.
+ if (allocOp.alignment()) {
+ accessAlignment = createIndexConstant(
+ rewriter, loc, allocOp.alignment().getValue().getSExtValue());
+ cumulativeSize = rewriter.create<LLVM::SubOp>(
+ loc,
+ rewriter.create<LLVM::AddOp>(loc, cumulativeSize, accessAlignment),
+ one);
+ }
+ callArgs.push_back(cumulativeSize);
+ auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFunc);
+ allocatedBytePtr = rewriter
+ .create<LLVM::CallOp>(loc, getVoidPtrType(),
+ allocFuncSymbol, callArgs)
+ .getResult(0);
+ // For heap allocations, the allocated pointer is a cast of the byte pointer
+ // to the type pointer.
+ return rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
+ allocatedBytePtr);
+ }
+
+ // An `alloc` is converted into a definition of a memref descriptor value and
+ // a call to `malloc` to allocate the underlying data buffer. The memref
+ // descriptor is of the LLVM structure type where:
+ // 1. the first element is a pointer to the allocated (typed) data buffer,
+ // 2. the second element is a pointer to the (typed) payload, aligned to the
+ // specified alignment,
+ // 3. the remaining elements serve to store all the sizes and strides of the
+ // memref using LLVM-converted `index` type.
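+ // For instance, a 2-D memref of f32 lowered with a 64-bit `index` type uses
+ // the LLVM struct { float*, float*, i64, [2 x i64], [2 x i64] }
+ // (illustrative layout: allocated ptr, aligned ptr, offset, sizes, strides).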
+ //
+ // Alignment is performed by allocating `alignment - 1` more bytes than
+ // requested and shifting the aligned pointer relative to the allocated
+ // memory. If alignment is unspecified, the two pointers are equal.
+
+ // An `alloca` is converted into a definition of a memref descriptor value and
+ // an llvm.alloca to allocate the underlying data buffer.
+ void rewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ MemRefType memRefType = cast<AllocLikeOp>(op).getType();
+ auto loc = op->getLoc();
+
+ // Get actual sizes of the memref as values: static sizes are constant
+ // values and dynamic sizes are passed to the op as operands. In case of
+ // zero-dimensional memref, assume a scalar (size 1).
+ SmallVector<Value, 4> sizes;
+ Value cumulativeSize, one;
+ getSizes(loc, memRefType, operands, rewriter, sizes, cumulativeSize, one);
+
+ // Allocate the underlying buffer.
+ // Value holding the alignment that has to be performed post allocation
+ // (in conjunction with allocators that do not support alignment, eg.
+ // malloc); nullptr if no such adjustment needs to be performed.
+ Value accessAlignment;
+ // Byte pointer to the allocated buffer.
+ Value allocatedBytePtr;
+ Value allocatedTypePtr =
+ allocateBuffer(loc, cumulativeSize, op, memRefType, one,
+ accessAlignment, allocatedBytePtr, rewriter);
+
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto successStrides = getStridesAndOffset(memRefType, strides, offset);
+ (void)successStrides;
+ assert(succeeded(successStrides) && "unexpected non-strided memref");
+ assert(offset != MemRefType::getDynamicStrideOrOffset() &&
+ "unexpected dynamic offset");
+
+ // 0-D memref corner case: they have size 1.
+ assert(
+ ((memRefType.getRank() == 0 && strides.empty() && sizes.size() == 1) ||
+ (strides.size() == sizes.size())) &&
+ "unexpected number of strides");
+
+ // Create the MemRef descriptor.
+ auto memRefDescriptor = createMemRefDescriptor(
+ loc, rewriter, memRefType, allocatedTypePtr, allocatedBytePtr,
+ accessAlignment, offset, strides, sizes);
// Return the final value of the descriptor.
rewriter.replaceOp(op, {memRefDescriptor});
}
+};
- bool useAlloca;
+struct AllocOpLowering : public AllocLikeOpLowering<AllocOp> {
+ using Base::Base;
+};
+struct AllocaOpLowering : public AllocLikeOpLowering<AllocaOp> {
+ using Base::Base;
};
// A CallOp automatically promotes MemRefType to a sequence of alloca/store and
struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
using ConvertOpToLLVMPattern<DeallocOp>::ConvertOpToLLVMPattern;
- explicit DeallocOpLowering(LLVMTypeConverter &converter,
- bool useAlloca = false)
- : ConvertOpToLLVMPattern<DeallocOp>(converter), useAlloca(useAlloca) {}
+ explicit DeallocOpLowering(LLVMTypeConverter &converter)
+ : ConvertOpToLLVMPattern<DeallocOp>(converter) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (useAlloca)
- return rewriter.eraseOp(op), success();
-
assert(operands.size() == 1 && "dealloc takes one operand");
OperandAdaptor<DeallocOp> transformed(operands);
op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
return success();
}
-
- bool useAlloca;
};
// A `rsqrt` is converted into `1 / sqrt`.
AbsFOpLowering,
AddFOpLowering,
AddIOpLowering,
+ AllocaOpLowering,
AndOpLowering,
AtomicCmpXchgOpLowering,
AtomicRMWOpLowering,
}
void mlir::populateStdToLLVMMemoryConversionPatters(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
- bool useAlloca) {
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
// clang-format off
patterns.insert<
AssumeAlignmentOpLowering,
ViewOpLowering>(converter);
patterns.insert<
AllocOpLowering,
- DeallocOpLowering>(converter, useAlloca);
+ DeallocOpLowering>(converter);
// clang-format on
}
void mlir::populateStdToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
- bool useAlloca, bool emitCWrappers) {
+ bool emitCWrappers) {
populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns,
emitCWrappers);
populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
- populateStdToLLVMMemoryConversionPatters(converter, patterns, useAlloca);
+ populateStdToLLVMMemoryConversionPatters(converter, patterns);
}
static void populateStdToLLVMBarePtrFuncOpConversionPattern(
}
void mlir::populateStdToLLVMBarePtrConversionPatterns(
- LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
- bool useAlloca) {
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
populateStdToLLVMBarePtrFuncOpConversionPattern(converter, patterns);
populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
- populateStdToLLVMMemoryConversionPatters(converter, patterns, useAlloca);
+ populateStdToLLVMMemoryConversionPatters(converter, patterns);
}
// Create an LLVM IR structure type if there is more than one result.
#include "mlir/Conversion/Passes.h.inc"
/// Creates an LLVM lowering pass.
- LLVMLoweringPass(bool useAlloca, bool useBarePtrCallConv, bool emitCWrappers,
+ LLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers,
unsigned indexBitwidth) {
- this->useAlloca = useAlloca;
this->useBarePtrCallConv = useBarePtrCallConv;
this->emitCWrappers = emitCWrappers;
this->indexBitwidth = indexBitwidth;
OwningRewritePatternList patterns;
if (useBarePtrCallConv)
- populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns,
- useAlloca);
+ populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns);
else
- populateStdToLLVMConversionPatterns(typeConverter, patterns, useAlloca,
+ populateStdToLLVMConversionPatterns(typeConverter, patterns,
emitCWrappers);
LLVMConversionTarget target(getContext());
}
std::unique_ptr<OpPassBase<ModuleOp>>
-mlir::createLowerToLLVMPass(bool useAlloca, bool useBarePtrCallConv,
- bool emitCWrappers, unsigned indexBitwidth) {
- return std::make_unique<LLVMLoweringPass>(useAlloca, useBarePtrCallConv,
- emitCWrappers, indexBitwidth);
+mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) {
+ return std::make_unique<LLVMLoweringPass>(
+ options.useBarePtrCallConv, options.emitCWrappers, options.indexBitwidth);
}
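
// Illustrative call-site impact (hypothetical snippet, not part of this patch):
// where callers previously passed individual flags, e.g.
//   createLowerToLLVMPass(/*useAlloca=*/false, /*useBarePtrCallConv=*/false,
//                         /*emitCWrappers=*/true, /*indexBitwidth=*/64);
// they now populate a LowerToLLVMOptions struct with the fields used above:
//   LowerToLLVMOptions options;
//   options.emitCWrappers = true;
//   options.indexBitwidth = 64;
//   createLowerToLLVMPass(options);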