LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
Type indexType);
LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
+ Type unrankedDescriptorType);
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
};
//===----------------------------------------------------------------------===//
+// Small runtime support library for memref.copy lowering during codegen.
+//===----------------------------------------------------------------------===//
+extern "C" MLIR_CRUNNERUTILS_EXPORT void
+memrefCopy(int64_t elemSize, UnrankedMemRefType<char> *src,
+ UnrankedMemRefType<char> *dst);
+
+//===----------------------------------------------------------------------===//
// Small runtime support library for vector.print lowering during codegen.
//===----------------------------------------------------------------------===//
extern "C" MLIR_CRUNNERUTILS_EXPORT void printI64(int64_t i);
}
};
+struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
+ using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::CopyOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ memref::CopyOp::Adaptor adaptor(operands);
+ auto srcType = op.source().getType().cast<BaseMemRefType>();
+ auto targetType = op.target().getType().cast<BaseMemRefType>();
+
+ // First make sure we have an unranked memref descriptor representation.
+ auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
+ auto rank = rewriter.create<LLVM::ConstantOp>(
+ loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
+ auto *typeConverter = getTypeConverter();
+ auto ptr =
+ typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
+ auto voidPtr =
+ rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
+ .getResult();
+ auto unrankedType =
+ UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
+ return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
+ unrankedType,
+ ValueRange{rank, voidPtr});
+ };
+
+ Value unrankedSource = srcType.hasRank()
+ ? makeUnranked(adaptor.source(), srcType)
+ : adaptor.source();
+ Value unrankedTarget = targetType.hasRank()
+ ? makeUnranked(adaptor.target(), targetType)
+ : adaptor.target();
+
+ // Now promote the unranked descriptors to the stack.
+ auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
+ rewriter.getIndexAttr(1));
+ auto promote = [&](Value desc) {
+ auto ptrType = LLVM::LLVMPointerType::get(desc.getType());
+ auto allocated =
+ rewriter.create<LLVM::AllocaOp>(loc, ptrType, ValueRange{one});
+ rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
+ return allocated;
+ };
+
+ auto sourcePtr = promote(unrankedSource);
+ auto targetPtr = promote(unrankedTarget);
+
+ auto elemSize = rewriter.create<LLVM::ConstantOp>(
+ loc, getIndexType(),
+ rewriter.getIndexAttr(srcType.getElementTypeBitWidth() / 8));
+ auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
+ op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
+ rewriter.create<LLVM::CallOp>(loc, copyFn,
+ ValueRange{elemSize, sourcePtr, targetPtr});
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+};
+
/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In unranked case, the fields are extracted from the underlying
/// ranked descriptor.
GetGlobalMemrefOpLowering,
LoadOpLowering,
MemRefCastOpLowering,
+ MemRefCopyOpLowering,
MemRefReinterpretCastOpLowering,
MemRefReshapeOpLowering,
RankOpLowering,
static constexpr llvm::StringRef kMalloc = "malloc";
static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
static constexpr llvm::StringRef kFree = "free";
+static constexpr llvm::StringRef kMemRefCopy = "memref_copy";
/// Generic print function lookupOrCreate helper.
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
+LLVM::LLVMFuncOp
+mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
+ Type unrankedDescriptorType) {
+ return LLVM::lookupOrCreateFn(
+ moduleOp, kMemRefCopy,
+ ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc,
LLVM::LLVMFuncOp fn,
ValueRange paramTypes,
#include <cinttypes>
#include <cstdio>
+#include <string.h>
#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
extern "C" void printComma() { fputs(", ", stdout); }
extern "C" void printNewline() { fputc('\n', stdout); }
+extern "C" MLIR_CRUNNERUTILS_EXPORT void
+memrefCopy(int64_t elemSize, UnrankedMemRefType<char> *srcArg,
+ UnrankedMemRefType<char> *dstArg) {
+ DynamicMemRefType<char> src(*srcArg);
+ DynamicMemRefType<char> dst(*dstArg);
+
+ int64_t rank = src.rank;
+ int64_t *indices = static_cast<int64_t *>(alloca(sizeof(int64_t) * rank));
+ int64_t *srcStrides = static_cast<int64_t *>(alloca(sizeof(int64_t) * rank));
+ int64_t *dstStrides = static_cast<int64_t *>(alloca(sizeof(int64_t) * rank));
+
+ char *srcPtr = src.data + src.offset * elemSize;
+ char *dstPtr = dst.data + dst.offset * elemSize;
+
+ // Initialize index and scale strides.
+ for (int rankp = 0; rankp < rank; ++rankp) {
+ indices[rankp] = 0;
+ srcStrides[rankp] = src.strides[rankp] * elemSize;
+ dstStrides[rankp] = dst.strides[rankp] * elemSize;
+ }
+
+ int64_t readIndex = 0, writeIndex = 0;
+ for (;;) {
+ // Copy over the element, byte by byte.
+ memcpy(dstPtr + writeIndex, srcPtr + readIndex, elemSize);
+ // Advance index and read position.
+ for (int64_t axis = rank - 1; axis >= 0; --axis) {
+ // Advance at current axis.
+ auto newIndex = ++indices[axis];
+ readIndex += srcStrides[axis];
+ writeIndex += dstStrides[axis];
+ // If this is a valid index, we have our next index, so continue copying.
+ if (src.sizes[axis] != newIndex)
+ break;
+ // We reached the end of this axis. If this is axis 0, we are done.
+ if (axis == 0)
+ return;
+ // Else, reset to 0 and undo the advancement of the linear index that
+ // this axis had. The continue with the axis one outer.
+ indices[axis] = 0;
+ readIndex -= src.sizes[axis] * srcStrides[axis];
+ writeIndex -= dst.sizes[axis] * dstStrides[axis];
+ }
+ }
+}
+
/// Prints GFLOPS rating.
extern "C" void print_flops(double flops) {
fprintf(stderr, "%lf GFLOPS\n", flops / 1.0E9);