using namespace mlir;
//===----------------------------------------------------------------------===//
+// Utility functions for operation conversion
+//===----------------------------------------------------------------------===//
+
+/// Performs the index computation to get to the element pointed to by
+/// `indices` using the layout map of `origBaseType`.
+
+// TODO(ravishankarm) : This method assumes that `origBaseType` is a MemRefType
+// with an AffineMap that has static strides. Handle dynamic strides as well.
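+//
+// For example, for memref<12x4xf32> with the default layout (strides [4, 1],
+// offset 0), indices (%i, %j) are linearized to 4 * %i + %j, and the result
+// is `spv.AccessChain %basePtr[0, 4 * %i + %j]`; the leading 0 indexes into
+// the struct that wraps the array.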
+spirv::AccessChainOp getElementPtr(OpBuilder &builder,
+                                   SPIRVTypeConverter &typeConverter,
+                                   Location loc, MemRefType origBaseType,
+                                   Value *basePtr, ArrayRef<Value *> indices) {
+  // Get the strides and offset of the MemRefType and verify they are static.
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  if (failed(getStridesAndOffset(origBaseType, strides, offset)) ||
+      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
+    return nullptr;
+  }
+
+  auto indexType = typeConverter.getIndexType(builder.getContext());
+
+  Value *ptrLoc = nullptr;
+  assert(indices.size() == strides.size());
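+  // Linearize the indices: accumulate index * stride for each dimension.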
+  for (auto index : enumerate(indices)) {
+    Value *strideVal = builder.create<spirv::ConstantOp>(
+        loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
+    Value *update =
+        builder.create<spirv::IMulOp>(loc, strideVal, index.value());
+    ptrLoc =
+        (ptrLoc ? builder.create<spirv::IAddOp>(loc, ptrLoc, update).getResult()
+                : update);
+  }
+  SmallVector<Value *, 2> linearizedIndices;
+  // Add a '0' at the start to index into the struct.
+  linearizedIndices.push_back(builder.create<spirv::ConstantOp>(
+      loc, indexType, IntegerAttr::get(indexType, 0)));
+  linearizedIndices.push_back(ptrLoc);
+  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
+}
+
+//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
/// operation. Since IndexType is not used within SPIR-V dialect, this needs
/// special handling to make sure the result type and the type of the value
/// attribute are consistent.
+// TODO(ravishankarm) : This should be moved into DRR.
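+// E.g., `%c0 = constant 0 : index` becomes `%0 = spv.constant 0 : i32`.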
class ConstantIndexOpConversion final : public SPIRVOpLowering<ConstantOp> {
public:
  using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
/// the type of the return value of the replacement operation differs from
/// that of the replaced operation. This is not handled in tablegen-based
/// pattern specification.
+// TODO(ravishankarm) : This should be moved into DRR.
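+// E.g., with StdOp = AddIOp and SPIRVOp = spirv::IAddOp, `addi %a, %b : i32`
+// becomes `spv.IAdd %a, %b : i32`.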
template <typename StdOp, typename SPIRVOp>
class IntegerOpConversion final : public SPIRVOpLowering<StdOp> {
public:
  }
};
-// If 'basePtr' is the result of lowering a value of MemRefType, and 'indices'
-// are the indices used to index into the original value (for load/store),
-// perform the equivalent address calculation in SPIR-V.
-spirv::AccessChainOp getElementPtr(OpBuilder &builder, Location loc,
-                                   Value *basePtr, ArrayRef<Value *> indices,
-                                   SPIRVTypeConverter &typeConverter) {
-  // MemRefType is converted to a
-  // spirv::StructType<spirv::ArrayType<spirv:ArrayType...>>>
-  auto ptrType = basePtr->getType().cast<spirv::PointerType>();
-  (void)ptrType;
-  auto structType = ptrType.getPointeeType().cast<spirv::StructType>();
-  (void)structType;
-  assert(structType.getNumElements() == 1);
-  auto indexType = typeConverter.getIndexType(builder.getContext());
-
-  // Need to add a '0' at the beginning of the index list for accessing into
-  // the struct that wraps the nested array types.
-  Value *zero = spirv::ConstantOp::getZero(indexType, loc, &builder);
-  SmallVector<Value *, 4> accessIndices;
-  accessIndices.reserve(1 + indices.size());
-  accessIndices.push_back(zero);
-  accessIndices.append(indices.begin(), indices.end());
-  return builder.create<spirv::AccessChainOp>(loc, basePtr, accessIndices);
-}
-
/// Convert load -> spv.LoadOp. The operands of the replaced operation are of
/// IndexType while that of the replacement operation are of type i32. This is
/// not supported in tablegen based pattern specification.
-// TODO(ravishankarm) : These could potentially be templated on the operation
-// being converted, since the same logic should work for linalg.load.
+// TODO(ravishankarm) : This should be moved into DRR.
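+// E.g., `%0 = load %arg0[%i, %j] : memref<12x4xf32>` becomes a
+// spv.AccessChain (computed by `getElementPtr`) followed by
+// `spv.Load "StorageBuffer"`.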
class LoadOpConversion final : public SPIRVOpLowering<LoadOp> {
public:
  using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
  matchAndRewrite(LoadOp loadOp, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    LoadOpOperandAdaptor loadOperands(operands);
-    auto basePtr = loadOperands.memref();
-    auto loadPtr = getElementPtr(rewriter, loadOp.getLoc(), basePtr,
-                                 loadOperands.indices(), typeConverter);
+    auto loadPtr = getElementPtr(rewriter, typeConverter, loadOp.getLoc(),
+                                 loadOp.memref()->getType().cast<MemRefType>(),
+                                 loadOperands.memref(), loadOperands.indices());
    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr,
                                               /*memory_access =*/nullptr,
                                               /*alignment =*/nullptr);
};
/// Convert return -> spv.Return.
+// TODO(ravishankarm) : This should be moved into DRR.
class ReturnToSPIRVConversion final : public SPIRVOpLowering<ReturnOp> {
public:
  using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;
};
/// Convert select -> spv.Select
+// TODO(ravishankarm) : This should be moved into DRR.
class SelectOpConversion final : public SPIRVOpLowering<SelectOp> {
public:
  using SPIRVOpLowering<SelectOp>::SPIRVOpLowering;
/// Convert store -> spv.StoreOp. The operands of the replaced operation are
/// of IndexType while that of the replacement operation are of type i32. This
/// is not supported in tablegen based pattern specification.
-// TODO(ravishankarm) : These could potentially be templated on the operation
-// being converted, since the same logic should work for linalg.store.
+// TODO(ravishankarm) : This should be moved into DRR.
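+// E.g., `store %v, %arg2[%i, %j] : memref<12x4xf32>` becomes a
+// spv.AccessChain (computed by `getElementPtr`) followed by
+// `spv.Store "StorageBuffer"`.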
class StoreOpConversion final : public SPIRVOpLowering<StoreOp> {
public:
  using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
  matchAndRewrite(StoreOp storeOp, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    StoreOpOperandAdaptor storeOperands(operands);
-    auto value = storeOperands.value();
-    auto basePtr = storeOperands.memref();
-    auto storePtr = getElementPtr(rewriter, storeOp.getLoc(), basePtr,
-                                  storeOperands.indices(), typeConverter);
-    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr, value,
+    auto storePtr =
+        getElementPtr(rewriter, typeConverter, storeOp.getLoc(),
+                      storeOp.memref()->getType().cast<MemRefType>(),
+                      storeOperands.memref(), storeOperands.indices());
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
+                                                storeOperands.value(),
                                                /*memory_access =*/nullptr,
                                                /*alignment =*/nullptr);
    return matchSuccess();
    return integerType.getWidth() / 8;
  } else if (auto floatType = t.dyn_cast<FloatType>()) {
    return floatType.getWidth() / 8;
+  } else if (auto memRefType = t.dyn_cast<MemRefType>()) {
+    // TODO: Layout should also be controlled by the ABI attributes. For now,
+    // use the layout of the MemRef.
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    if (!memRefType.hasStaticShape() ||
+        failed(getStridesAndOffset(memRefType, strides, offset))) {
+      return llvm::None;
+    }
+    // The total size of the memref object in memory is the
+    // max(stride * dimension-size) computed over all dimensions, plus the
+    // offset, times the size of the element.
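+    // E.g., memref<12x4xf32> has strides [4, 1] and offset 0, so its size is
+    // (0 + max(12 * 4, 4 * 1)) * 4 = 192 bytes.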
+    auto elementSize = getTypeNumBytes(memRefType.getElementType());
+    if (!elementSize) {
+      return llvm::None;
+    }
+    auto dims = memRefType.getShape();
+    if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
+        offset == MemRefType::getDynamicStrideOrOffset() ||
+        llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
+      return llvm::None;
+    }
+    int64_t memrefSize = -1;
+    for (auto shape : enumerate(dims)) {
+      memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
+    }
+    return (offset + memrefSize) * elementSize.getValue();
  }
  // TODO: Add size computation for other types.
  return llvm::None;
    if (!elementSize) {
      return Type();
    }
-
-    if (!memRefType.hasStaticShape()) {
-      // TODO(ravishankarm) : Handle dynamic shapes.
-      return Type();
-    }
-
-    // Get the strides and offset.
-    int64_t offset;
-    SmallVector<int64_t, 4> strides;
-    if (failed(getStridesAndOffset(memRefType, strides, offset)) ||
-        offset == MemRefType::getDynamicStrideOrOffset() ||
-        llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
-      // TODO(ravishankarm) : Handle dynamic strides and offsets.
-      return Type();
-    }
-
-    // Convert to a multi-dimensional spv.array if size is known.
-    auto shape = memRefType.getShape();
-    assert(shape.size() == strides.size());
-    Type arrayType = elementType;
-    // TODO(antiagainst): Introduce layout as part of the shader ABI to have
-    // better separate of concerns.
-    for (int i = shape.size(); i > 0; --i) {
-      arrayType = spirv::ArrayType::get(
-          arrayType, shape[i - 1], strides[i - 1] * elementSize.getValue());
+    // TODO(ravishankarm) : Handle dynamic shapes.
+    if (memRefType.hasStaticShape()) {
+      auto arraySize = getTypeNumBytes(memRefType);
+      if (!arraySize) {
+        return Type();
+      }
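+      // E.g., memref<12x4xf32> (192 bytes of 4-byte f32 elements) becomes
+      // !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer>.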
+      auto arrayType = spirv::ArrayType::get(
+          elementType, arraySize.getValue() / elementSize.getValue(),
+          elementSize.getValue());
+      auto structType = spirv::StructType::get(arrayType, 0);
+      // For now, initialize the storage class to StorageBuffer. This will be
+      // updated later based on what is passed in w.r.t. the ABI attributes.
+      return spirv::PointerType::get(structType,
+                                     spirv::StorageClass::StorageBuffer);
    }
-
-    // For the offset, need to wrap the array in a struct.
-    auto structType =
-        spirv::StructType::get(arrayType, offset * elementSize.getValue());
-    // For now initialize the storage class to StorageBuffer. This will be
-    // updated later based on whats passed in w.r.t to the ABI attributes.
-    return spirv::PointerType::get(structType,
-                                   spirv::StorageClass::StorageBuffer);
  }
  return Type();
// CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
// CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr<vector<3xi32>, Input>
// CHECK-LABEL: func @load_store_kernel
- // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
- // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
- // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+ // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+ // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+ // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
// CHECK-SAME: [[ARG3:%.*]]: i32 {spirv.interface_var_abi = {binding = 3 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
// CHECK-SAME: [[ARG4:%.*]]: i32 {spirv.interface_var_abi = {binding = 4 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
// CHECK-SAME: [[ARG5:%.*]]: i32 {spirv.interface_var_abi = {binding = 5 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
%12 = addi %arg3, %0 : index
// CHECK: [[INDEX2:%.*]] = spv.IAdd [[ARG4]], [[LOCALINVOCATIONIDX]]
%13 = addi %arg4, %3 : index
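+ // The access chain below linearizes (%12, %13) to 4 * %12 + %13, using the
+ // static strides [4, 1] of memref<12x4xf32>.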
+ // CHECK: [[STRIDE1_1:%.*]] = spv.constant 4 : i32
+ // CHECK: [[OFFSET1_1:%.*]] = spv.IMul [[STRIDE1_1]], [[INDEX1]] : i32
+ // CHECK: [[STRIDE1_2:%.*]] = spv.constant 1 : i32
+ // CHECK: [[UPDATE1_2:%.*]] = spv.IMul [[STRIDE1_2]], [[INDEX2]] : i32
+ // CHECK: [[OFFSET1_2:%.*]] = spv.IAdd [[OFFSET1_1]], [[UPDATE1_2]] : i32
// CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32
- // CHECK: [[PTR1:%.*]] = spv.AccessChain [[ARG0]]{{\[}}[[ZERO1]], [[INDEX1]], [[INDEX2]]{{\]}}
+ // CHECK: [[PTR1:%.*]] = spv.AccessChain [[ARG0]]{{\[}}[[ZERO1]], [[OFFSET1_2]]{{\]}}
// CHECK-NEXT: [[VAL1:%.*]] = spv.Load "StorageBuffer" [[PTR1]]
%14 = load %arg0[%12, %13] : memref<12x4xf32>
- // CHECK: [[ZERO2:%.*]] = spv.constant 0 : i32
- // CHECK: [[PTR2:%.*]] = spv.AccessChain [[ARG1]]{{\[}}[[ZERO2]], [[INDEX1]], [[INDEX2]]{{\]}}
+ // CHECK: [[PTR2:%.*]] = spv.AccessChain [[ARG1]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
// CHECK-NEXT: [[VAL2:%.*]] = spv.Load "StorageBuffer" [[PTR2]]
%15 = load %arg1[%12, %13] : memref<12x4xf32>
// CHECK: [[VAL3:%.*]] = spv.FAdd [[VAL1]], [[VAL2]]
%16 = addf %14, %15 : f32
- // CHECK: [[ZERO3:%.*]] = spv.constant 0 : i32
- // CHECK: [[PTR3:%.*]] = spv.AccessChain [[ARG2]]{{\[}}[[ZERO3]], [[INDEX1]], [[INDEX2]]{{\]}}
+ // CHECK: [[PTR3:%.*]] = spv.AccessChain [[ARG2]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
// CHECK-NEXT: spv.Store "StorageBuffer" [[PTR3]], [[VAL3]]
store %16, %arg2[%12, %13] : memref<12x4xf32>
return