Type descriptorType);
/// Builds IR extracting the rank from the descriptor
- Value rank(OpBuilder &builder, Location loc);
+ Value rank(OpBuilder &builder, Location loc) const;
/// Builds IR setting the rank in the descriptor
void setRank(OpBuilder &builder, Location loc, Value value);
/// Builds IR extracting ranked memref descriptor ptr
- Value memRefDescPtr(OpBuilder &builder, Location loc);
+ Value memRefDescPtr(OpBuilder &builder, Location loc) const;
/// Builds IR setting ranked memref descriptor ptr
void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value);
static unsigned getNumUnpackedValues() { return 2; }
/// Builds IR computing the sizes in bytes (suitable for opaque allocation)
- /// and appends the corresponding values into `sizes`.
+ /// and appends the corresponding values into `sizes`. `addressSpaces`,
+ /// which must have the same length as `values`, is needed to handle layouts
+ /// where sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
static void computeSizes(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
ArrayRef<UnrankedMemRefDescriptor> values,
+ ArrayRef<unsigned> addressSpaces,
SmallVectorImpl<Value> &sizes);
/// TODO: The following accessors don't take alignment rules between elements
protected:
/// Builds IR to extract a value from the struct at position pos
- Value extractPtr(OpBuilder &builder, Location loc, unsigned pos);
+ Value extractPtr(OpBuilder &builder, Location loc, unsigned pos) const;
/// Builds IR to set a value in the struct at position pos
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr);
};
Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
return UnrankedMemRefDescriptor(descriptor);
}
-Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
+Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) const {
return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
}
void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
}
Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
- Location loc) {
+ Location loc) const {
return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
}
void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
void UnrankedMemRefDescriptor::computeSizes(
OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
- ArrayRef<UnrankedMemRefDescriptor> values, SmallVectorImpl<Value> &sizes) {
+ ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces,
+ SmallVectorImpl<Value> &sizes) {
if (values.empty())
return;
-
+ assert(values.size() == addressSpaces.size() &&
+ "must provide address space for each descriptor");
// Cache the index type.
Type indexType = typeConverter.getIndexType();
// Initialize shared constants.
Value one = createIndexAttrConstant(builder, loc, indexType, 1);
Value two = createIndexAttrConstant(builder, loc, indexType, 2);
- Value pointerSize = createIndexAttrConstant(
- builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8));
Value indexSize =
createIndexAttrConstant(builder, loc, indexType,
ceilDiv(typeConverter.getIndexTypeBitwidth(), 8));
sizes.reserve(sizes.size() + values.size());
- for (UnrankedMemRefDescriptor desc : values) {
+ for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) {
// Emit IR computing the memory necessary to store the descriptor. This
// assumes the descriptor to be
// { type*, type*, index, index[rank], index[rank] }
// 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
// TODO: consider including the actual size (including eventual padding due
// to data layout) into the unranked descriptor.
+ Value pointerSize = createIndexAttrConstant(
+ builder, loc, indexType,
+ ceilDiv(typeConverter.getPointerBitwidth(addressSpace), 8));
Value doublePointerSize =
builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
"expected as many original types as operands");
// Find operands of unranked memref type and store them.
- SmallVector<UnrankedMemRefDescriptor, 4> unrankedMemrefs;
- for (unsigned i = 0, e = operands.size(); i < e; ++i)
- if (origTypes[i].isa<UnrankedMemRefType>())
+ SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
+ SmallVector<unsigned> unrankedAddressSpaces;
+ for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+ if (auto memRefType = origTypes[i].dyn_cast<UnrankedMemRefType>()) {
unrankedMemrefs.emplace_back(operands[i]);
+ FailureOr<unsigned> addressSpace =
+ getTypeConverter()->getMemRefAddressSpace(memRefType);
+ if (failed(addressSpace))
+ return failure();
+ unrankedAddressSpaces.emplace_back(*addressSpace);
+ }
+ }
if (unrankedMemrefs.empty())
return success();
// Compute allocation sizes.
- SmallVector<Value, 4> sizes;
+ SmallVector<Value> sizes;
UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
- unrankedMemrefs, sizes);
+ unrankedMemrefs, unrankedAddressSpaces,
+ sizes);
// Get frequently used types.
MLIRContext *context = builder.getContext();
}
Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
- unsigned pos) {
+ unsigned pos) const {
return builder.create<LLVM::ExtractValueOp>(loc, value, pos);
}
targetDesc.setRank(rewriter, loc, resultRank);
SmallVector<Value, 4> sizes;
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
- targetDesc, sizes);
+ targetDesc, addressSpace, sizes);
Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
sizes.front());
if (failed(warpMatrixInfo))
return failure();
- Attribute memorySpace =
- op.getSource().getType().cast<MemRefType>().getMemorySpace();
bool isLdMatrixCompatible =
isSharedMemory(op.getSource().getType().cast<MemRefType>()) &&
nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : index)
// CHECK: %[[TWO:.*]] = llvm.mlir.constant(2 : index)
// These sizes may depend on the data layout, not matching specific values.
- // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
// CHECK: %[[IDX_SIZE:.*]] = llvm.mlir.constant
+ // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
// CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
// CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm.struct<(i64, ptr)>
// CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
// CHECK: %[[DESC_2:.*]] = llvm.insertvalue %[[ALLOCA]], %[[DESC_1]][1]
%0 = memref.cast %arg0: memref<4x3xf32> to memref<*xf32>
-
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : index)
// CHECK: %[[TWO:.*]] = llvm.mlir.constant(2 : index)
// These sizes may depend on the data layout, not matching specific values.
- // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
// CHECK: %[[IDX_SIZE:.*]] = llvm.mlir.constant
+ // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
// CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
// CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
// CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
// Compute size in bytes to allocate result ranked descriptor
// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
// CHECK: [[INDEX_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
+// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
// CHECK: [[DOUBLE_PTR_SIZE:%.*]] = llvm.mul [[C2]], [[PTR_SIZE]] : i64
// CHECK: [[DESC_ALLOC_SIZE:%.*]] = llvm.add [[DOUBLE_PTR_SIZE]], %{{.*}}
// CHECK: [[UNDERLYING_DESC:%.*]] = llvm.alloca [[DESC_ALLOC_SIZE]] x i8
// Compute size in bytes to allocate result ranked descriptor
// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
// CHECK: [[INDEX_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
+// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
// CHECK: [[DOUBLE_PTR_SIZE:%.*]] = llvm.mul [[C2]], [[PTR_SIZE]] : i64
// CHECK: [[DESC_ALLOC_SIZE:%.*]] = llvm.add [[DOUBLE_PTR_SIZE]], %{{.*}}
// CHECK: [[UNDERLYING_DESC:%.*]] = llvm.alloca [[DESC_ALLOC_SIZE]] x i8