LLVM::LLVMType unwrap(Type type);
};
+/// Helper class to produce LLVM dialect operations extracting or inserting
+/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
+/// The Value may be null, in which case none of the operations are valid.
+class MemRefDescriptor {
+public:
+ /// Construct a helper for the given descriptor value.
+ explicit MemRefDescriptor(Value *descriptor);
+ /// Builds IR creating an `undef` value of the descriptor type.
+ static MemRefDescriptor undef(OpBuilder &builder, Location loc,
+ Type descriptorType);
+ /// Builds IR extracting the allocated pointer from the descriptor.
+ Value *allocatedPtr(OpBuilder &builder, Location loc);
+ /// Builds IR inserting the allocated pointer into the descriptor.
+ void setAllocatedPtr(OpBuilder &builder, Location loc, Value *ptr);
+
+ /// Builds IR extracting the aligned pointer from the descriptor.
+ Value *alignedPtr(OpBuilder &builder, Location loc);
+
+ /// Builds IR inserting the aligned pointer into the descriptor.
+ void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr);
+
+ /// Builds IR extracting the offset from the descriptor.
+ Value *offset(OpBuilder &builder, Location loc);
+
+ /// Builds IR inserting the offset into the descriptor.
+ void setOffset(OpBuilder &builder, Location loc, Value *offset);
+
+ /// Builds IR extracting the pos-th size from the descriptor.
+ Value *size(OpBuilder &builder, Location loc, unsigned pos);
+
+ /// Builds IR inserting the pos-th size into the descriptor
+ void setSize(OpBuilder &builder, Location loc, unsigned pos, Value *size);
+
+ /// Builds IR extracting the pos-th size from the descriptor.
+ Value *stride(OpBuilder &builder, Location loc, unsigned pos);
+
+ /// Builds IR inserting the pos-th stride into the descriptor
+ void setStride(OpBuilder &builder, Location loc, unsigned pos, Value *stride);
+
+ /*implicit*/ operator Value *() { return value; }
+
+private:
+ Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos);
+ void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr);
+
+ // Cached descriptor type.
+ Type structType;
+
+ // Cached index type.
+ Type indexType;
+
+ // Actual descriptor.
+ Value *value;
+};
+
/// Base class for operation conversions targeting the LLVM IR dialect. Provides
/// conversion patterns with an access to the containing LLVMLowering for the
/// purpose of type conversions.
PatternBenefit benefit)
: ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {}
-namespace {
-/// Helper class to produce LLVM dialect operations extracting or inserting
-/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
-/// The Value may be null, in which case none of the operations are valid.
-class MemRefDescriptor {
-public:
- /// Construct a helper for the given descriptor value.
- explicit MemRefDescriptor(Value *descriptor) : value(descriptor) {
- if (value) {
- structType = value->getType().cast<LLVM::LLVMType>();
- indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
- LLVMTypeConverter::kOffsetPosInMemRefDescriptor);
- }
- }
-
- /// Builds IR creating an `undef` value of the descriptor type.
- static MemRefDescriptor undef(OpBuilder &builder, Location loc,
- Type descriptorType) {
- Value *descriptor = builder.create<LLVM::UndefOp>(
- loc, descriptorType.cast<LLVM::LLVMType>());
- return MemRefDescriptor(descriptor);
- }
-
- /// Builds IR extracting the allocated pointer from the descriptor.
- Value *allocatedPtr(OpBuilder &builder, Location loc) {
- return extractPtr(builder, loc,
- LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
- }
-
- /// Builds IR inserting the allocated pointer into the descriptor.
- void setAllocatedPtr(OpBuilder &builder, Location loc, Value *ptr) {
- setPtr(builder, loc, LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor,
- ptr);
- }
-
- /// Builds IR extracting the aligned pointer from the descriptor.
- Value *alignedPtr(OpBuilder &builder, Location loc) {
- return extractPtr(builder, loc,
- LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
+/*============================================================================*/
+/* MemRefDescriptor implementation */
+/*============================================================================*/
+
+/// Construct a helper for the given descriptor value.
+MemRefDescriptor::MemRefDescriptor(Value *descriptor) : value(descriptor) {
+ if (value) {
+ structType = value->getType().cast<LLVM::LLVMType>();
+ indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
+ LLVMTypeConverter::kOffsetPosInMemRefDescriptor);
}
+}
- /// Builds IR inserting the aligned pointer into the descriptor.
- void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr) {
- setPtr(builder, loc, LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor,
- ptr);
- }
+/// Builds IR creating an `undef` value of the descriptor type.
+MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
+ Type descriptorType) {
+ Value *descriptor =
+ builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
+ return MemRefDescriptor(descriptor);
+}
- /// Builds IR extracting the offset from the descriptor.
- Value *offset(OpBuilder &builder, Location loc) {
- return builder.create<LLVM::ExtractValueOp>(
- loc, indexType, value,
- builder.getI64ArrayAttr(
- LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
- }
+/// Builds IR extracting the allocated pointer from the descriptor.
+Value *MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
+ return extractPtr(builder, loc,
+ LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
+}
- /// Builds IR inserting the offset into the descriptor.
- void setOffset(OpBuilder &builder, Location loc, Value *offset) {
- value = builder.create<LLVM::InsertValueOp>(
- loc, structType, value, offset,
- builder.getI64ArrayAttr(
- LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
- }
+/// Builds IR inserting the allocated pointer into the descriptor.
+void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
+ Value *ptr) {
+ setPtr(builder, loc, LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor,
+ ptr);
+}
- /// Builds IR extracting the pos-th size from the descriptor.
- Value *size(OpBuilder &builder, Location loc, unsigned pos) {
- return builder.create<LLVM::ExtractValueOp>(
- loc, indexType, value,
- builder.getI64ArrayAttr(
- {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
- }
+/// Builds IR extracting the aligned pointer from the descriptor.
+Value *MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
+ return extractPtr(builder, loc,
+ LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
+}
- /// Builds IR inserting the pos-th size into the descriptor
- void setSize(OpBuilder &builder, Location loc, unsigned pos, Value *size) {
- value = builder.create<LLVM::InsertValueOp>(
- loc, structType, value, size,
- builder.getI64ArrayAttr(
- {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
- }
+/// Builds IR inserting the aligned pointer into the descriptor.
+void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
+ Value *ptr) {
+ setPtr(builder, loc, LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor,
+ ptr);
+}
- /// Builds IR extracting the pos-th size from the descriptor.
- Value *stride(OpBuilder &builder, Location loc, unsigned pos) {
- return builder.create<LLVM::ExtractValueOp>(
- loc, indexType, value,
- builder.getI64ArrayAttr(
- {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
- }
+/// Builds IR extracting the offset from the descriptor.
+Value *MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
+ return builder.create<LLVM::ExtractValueOp>(
+ loc, indexType, value,
+ builder.getI64ArrayAttr(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+}
- /// Builds IR inserting the pos-th stride into the descriptor
- void setStride(OpBuilder &builder, Location loc, unsigned pos,
- Value *stride) {
- value = builder.create<LLVM::InsertValueOp>(
- loc, structType, value, stride,
- builder.getI64ArrayAttr(
- {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
- }
+/// Builds IR inserting the offset into the descriptor.
+void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
+ Value *offset) {
+ value = builder.create<LLVM::InsertValueOp>(
+ loc, structType, value, offset,
+ builder.getI64ArrayAttr(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+}
- /*implicit*/ operator Value *() { return value; }
+/// Builds IR extracting the pos-th size from the descriptor.
+Value *MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
+ return builder.create<LLVM::ExtractValueOp>(
+ loc, indexType, value,
+ builder.getI64ArrayAttr(
+ {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
+}
-private:
- Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos) {
- Type type = structType.getStructElementType(pos);
- return builder.create<LLVM::ExtractValueOp>(loc, type, value,
- builder.getI64ArrayAttr(pos));
- }
+/// Builds IR inserting the pos-th size into the descriptor
+void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
+ Value *size) {
+ value = builder.create<LLVM::InsertValueOp>(
+ loc, structType, value, size,
+ builder.getI64ArrayAttr(
+ {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
+}
- void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr) {
- value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
- builder.getI64ArrayAttr(pos));
- }
+/// Builds IR extracting the pos-th size from the descriptor.
+Value *MemRefDescriptor::stride(OpBuilder &builder, Location loc,
+ unsigned pos) {
+ return builder.create<LLVM::ExtractValueOp>(
+ loc, indexType, value,
+ builder.getI64ArrayAttr(
+ {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
+}
- // Cached descriptor type.
- LLVM::LLVMType structType;
+/// Builds IR inserting the pos-th stride into the descriptor
+void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
+ Value *stride) {
+ value = builder.create<LLVM::InsertValueOp>(
+ loc, structType, value, stride,
+ builder.getI64ArrayAttr(
+ {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
+}
- // Cached index type.
- LLVM::LLVMType indexType;
+Value *MemRefDescriptor::extractPtr(OpBuilder &builder, Location loc,
+ unsigned pos) {
+ Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos);
+ return builder.create<LLVM::ExtractValueOp>(loc, type, value,
+ builder.getI64ArrayAttr(pos));
+}
- // Actual descriptor.
- Value *value;
-};
+void MemRefDescriptor::setPtr(OpBuilder &builder, Location loc, unsigned pos,
+ Value *ptr) {
+ value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
+ builder.getI64ArrayAttr(pos));
+}
+namespace {
// Base class for Standard to LLVM IR op conversions. Matches the Op type
// provided as template argument. Carries a reference to the LLVM dialect in
// case it is necessary for rewriters.
}
namespace {
-/// Factor out the common information for all view conversions:
-/// 1. common types in (standard and LLVM dialects)
-/// 2. `pos` method
-/// 3. view descriptor construction `desc`.
+/// EDSC-compatible wrapper for MemRefDescriptor.
class BaseViewConversionHelper {
public:
- BaseViewConversionHelper(Location loc, MemRefType memRefType,
- ConversionPatternRewriter &rewriter,
- LLVMTypeConverter &lowering)
- : zeroDMemRef(memRefType.getRank() == 0),
- elementTy(getPtrToElementType(memRefType, lowering)),
- int64Ty(
- lowering.convertType(rewriter.getIntegerType(64)).cast<LLVMType>()),
- desc(nullptr), rewriter(rewriter) {
- assert(isStrided(memRefType) && "expected strided memref type");
- viewDescriptorTy = lowering.convertType(memRefType).cast<LLVMType>();
- desc = rewriter.create<LLVM::UndefOp>(loc, viewDescriptorTy);
- }
-
- ArrayAttr pos(ArrayRef<int64_t> values) const {
- return rewriter.getI64ArrayAttr(values);
- };
-
- bool zeroDMemRef;
- LLVMType elementTy, int64Ty, viewDescriptorTy;
- Value *desc;
- ConversionPatternRewriter &rewriter;
+ BaseViewConversionHelper(Type type)
+ : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}
+
+ BaseViewConversionHelper(Value *v) : d(v) {}
+
+ /// Wrappers around MemRefDescriptor that use EDSC builder and location.
+ Value *allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
+ void setAllocatedPtr(Value *v) { d.setAllocatedPtr(rewriter(), loc(), v); }
+ Value *alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
+ void setAlignedPtr(Value *v) { d.setAlignedPtr(rewriter(), loc(), v); }
+ Value *offset() { return d.offset(rewriter(), loc()); }
+ void setOffset(Value *v) { d.setOffset(rewriter(), loc(), v); }
+ Value *size(unsigned i) { return d.size(rewriter(), loc(), i); }
+ void setSize(unsigned i, Value *v) { d.setSize(rewriter(), loc(), i, v); }
+ Value *stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
+ void setStride(unsigned i, Value *v) { d.setStride(rewriter(), loc(), i, v); }
+
+ operator Value *() { return d; }
+
+private:
+ OpBuilder &rewriter() { return ScopedContext::getBuilder(); }
+ Location loc() { return ScopedContext::getLocation(); }
+
+ MemRefDescriptor d;
};
} // namespace
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
+ edsc::ScopedContext context(rewriter, op->getLoc());
SliceOpOperandAdaptor adaptor(operands);
- Value *baseDesc = adaptor.view();
+ BaseViewConversionHelper baseDesc(adaptor.view());
auto sliceOp = cast<SliceOp>(op);
auto memRefType = sliceOp.getBaseViewType();
+ auto int64Ty = lowering.convertType(rewriter.getIntegerType(64))
+ .cast<LLVM::LLVMType>();
- BaseViewConversionHelper helper(op->getLoc(), sliceOp.getViewType(),
- rewriter, lowering);
- LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty;
- Value *desc = helper.desc;
-
- edsc::ScopedContext context(rewriter, op->getLoc());
+ BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType()));
// TODO(ntv): extract sizes and emit asserts.
SmallVector<Value *, 4> strides(memRefType.getRank());
for (int i = 0, e = memRefType.getRank(); i < e; ++i)
- strides[i] = extractvalue(
- int64Ty, baseDesc,
- helper.pos({LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
+ strides[i] = baseDesc.stride(i);
+
+ auto pos = [&rewriter](ArrayRef<int64_t> values) {
+ return rewriter.getI64ArrayAttr(values);
+ };
// Compute base offset.
- Value *baseOffset = extractvalue(
- int64Ty, baseDesc,
- helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+ Value *baseOffset = baseDesc.offset();
for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
Value *indexing = adaptor.indexings()[i];
Value *min = indexing;
if (sliceOp.indexing(i)->getType().isa<RangeType>())
- min = extractvalue(int64Ty, indexing, helper.pos(0));
+ min = extractvalue(int64Ty, indexing, pos(0));
baseOffset = add(baseOffset, mul(min, strides[i]));
}
// Insert the base and aligned pointers.
- auto ptrPos =
- helper.pos(LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
- desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
- ptrPos = helper.pos(LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
- desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
+ desc.setAllocatedPtr(baseDesc.allocatedPtr());
+ desc.setAlignedPtr(baseDesc.alignedPtr());
// Insert base offset.
- desc = insertvalue(
- desc, baseOffset,
- helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+ desc.setOffset(baseOffset);
// Corner case, no sizes or strides: early return the descriptor.
- if (helper.zeroDMemRef)
- return rewriter.replaceOp(op, desc), matchSuccess();
+ if (sliceOp.getViewType().getRank() == 0)
+ return rewriter.replaceOp(op, {desc}), matchSuccess();
Value *zero =
constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
if (indexing->getType().isa<RangeType>()) {
int rank = en.index();
Value *rangeDescriptor = adaptor.indexings()[rank];
- Value *min = extractvalue(int64Ty, rangeDescriptor, helper.pos(0));
- Value *max = extractvalue(int64Ty, rangeDescriptor, helper.pos(1));
- Value *step = extractvalue(int64Ty, rangeDescriptor, helper.pos(2));
- Value *baseSize = extractvalue(
- int64Ty, baseDesc,
- helper.pos({LLVMTypeConverter::kSizePosInMemRefDescriptor, rank}));
+ Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
+ Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
+ Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
+ Value *baseSize = baseDesc.size(rank);
+
// Bound upper by base view upper bound.
max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
baseSize);
size =
llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
Value *stride = mul(strides[rank], step);
- desc = insertvalue(
- desc, size,
- helper.pos(
- {LLVMTypeConverter::kSizePosInMemRefDescriptor, numNewDims}));
- desc = insertvalue(
- desc, stride,
- helper.pos(
- {LLVMTypeConverter::kStridePosInMemRefDescriptor, numNewDims}));
+ desc.setSize(numNewDims, size);
+ desc.setStride(numNewDims, stride);
++numNewDims;
}
}
- rewriter.replaceOp(op, desc);
+ rewriter.replaceOp(op, {desc});
return matchSuccess();
}
};
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
// Initialize the common boilerplate and alloca at the top of the FuncOp.
+ edsc::ScopedContext context(rewriter, op->getLoc());
TransposeOpOperandAdaptor adaptor(operands);
- Value *baseDesc = adaptor.view();
+ BaseViewConversionHelper baseDesc(adaptor.view());
auto transposeOp = cast<TransposeOp>(op);
// No permutation, early exit.
if (transposeOp.permutation().isIdentity())
- return rewriter.replaceOp(op, baseDesc), matchSuccess();
+ return rewriter.replaceOp(op, {baseDesc}), matchSuccess();
- BaseViewConversionHelper helper(op->getLoc(), transposeOp.getViewType(),
- rewriter, lowering);
- LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty;
- Value *desc = helper.desc;
+ BaseViewConversionHelper desc(
+ lowering.convertType(transposeOp.getViewType()));
- edsc::ScopedContext context(rewriter, op->getLoc());
// Copy the base and aligned pointers from the old descriptor to the new
// one.
- ArrayAttr ptrPos =
- helper.pos(LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
- desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
- ptrPos = helper.pos(LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
- desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
+ desc.setAllocatedPtr(baseDesc.allocatedPtr());
+ desc.setAlignedPtr(baseDesc.alignedPtr());
// Copy the offset pointer from the old descriptor to the new one.
- ArrayAttr offPos =
- helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor);
- desc = insertvalue(desc, extractvalue(int64Ty, baseDesc, offPos), offPos);
+ desc.setOffset(baseDesc.offset());
// Iterate over the dimensions and apply size/stride permutation.
for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
int sourcePos = en.index();
int targetPos = en.value().cast<AffineDimExpr>().getPosition();
- Value *size = extractvalue(
- int64Ty, baseDesc,
- helper.pos(
- {LLVMTypeConverter::kSizePosInMemRefDescriptor, sourcePos}));
- desc =
- insertvalue(desc, size,
- helper.pos({LLVMTypeConverter::kSizePosInMemRefDescriptor,
- targetPos}));
- Value *stride = extractvalue(
- int64Ty, baseDesc,
- helper.pos(
- {LLVMTypeConverter::kStridePosInMemRefDescriptor, sourcePos}));
- desc = insertvalue(
- desc, stride,
- helper.pos(
- {LLVMTypeConverter::kStridePosInMemRefDescriptor, targetPos}));
+ desc.setSize(targetPos, baseDesc.size(sourcePos));
+ desc.setStride(targetPos, baseDesc.stride(sourcePos));
}
- rewriter.replaceOp(op, desc);
+ rewriter.replaceOp(op, {desc});
return matchSuccess();
}
};