// Extract raw data pointer value from a value representing a memref.
static Value *extractMemRefElementPtr(ConversionPatternRewriter &builder,
- Location loc,
- Value *convertedMemRefValue,
- Type elementTypePtr,
- bool hasStaticShape) {
- Value *buffer;
- if (hasStaticShape)
- return convertedMemRefValue;
- else
- return builder.create<LLVM::ExtractValueOp>(loc, elementTypePtr,
- convertedMemRefValue,
- builder.getIndexArrayAttr(0));
- return buffer;
+ Location loc, Value *memref,
+ Type elementTypePtr) {
+ return builder.create<LLVM::ExtractValueOp>(
+ loc, elementTypePtr, memref,
+ builder.getIndexArrayAttr(kPtrPosInMemRefDescriptor));
}
protected:
// values and dynamic sizes are passed to 'alloc' as operands. In case of
// zero-dimensional memref, assume a scalar (size 1).
SmallVector<Value *, 4> sizes;
- auto numOperands = allocOp.getNumOperands();
- sizes.reserve(numOperands);
+ sizes.reserve(type.getRank());
unsigned i = 0;
for (int64_t s : type.getShape())
sizes.push_back(s == -1 ? operands[i++]
}
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
- auto hasStaticShape = type.isPointerTy();
- Type elementPtrType = hasStaticShape ? type : type.getStructElementType(0);
- Value *bufferPtr =
- extractMemRefElementPtr(rewriter, op->getLoc(), transformed.memref(),
- elementPtrType, hasStaticShape);
+ Type elementPtrType = type.getStructElementType(kPtrPosInMemRefDescriptor);
+ Value *bufferPtr = extractMemRefElementPtr(
+ rewriter, op->getLoc(), transformed.memref(), elementPtrType);
Value *casted = rewriter.create<LLVM::BitcastOp>(
op->getLoc(), getVoidPtrType(), bufferPtr);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
ArrayRef<int64_t> strides, int64_t offset,
ConversionPatternRewriter &rewriter) const {
auto indexTy = this->getIndexType();
- Value *base = rewriter.create<LLVM::ExtractValueOp>(
- loc, elementTypePtr, memRefDescriptor,
- rewriter.getIndexArrayAttr(kPtrPosInMemRefDescriptor));
+ Value *base = this->extractMemRefElementPtr(rewriter, loc, memRefDescriptor,
+ elementTypePtr);
Value *offsetValue =
offset == MemRefType::getDynamicStrideOrOffset()
? rewriter.create<LLVM::ExtractValueOp>(