#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/SPIRV/LayoutUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
}
+/// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
+/// offset to LLVM struct. Otherwise, the conversion is not supported.
+static Optional<Type>
+convertStructTypeWithOffset(spirv::StructType type,
+ LLVMTypeConverter &converter) {
+ if (type != VulkanLayoutUtils::decorateType(type))
+ return llvm::None;
+
+ auto elementsVector = llvm::to_vector<8>(
+ llvm::map_range(type.getElementTypes(), [&](Type elementType) {
+ return converter.convertType(elementType).cast<LLVM::LLVMType>();
+ }));
+ return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector,
+ /*isPacked=*/false);
+}
+
/// Converts SPIR-V struct with no offset to packed LLVM struct.
static Type convertStructTypePacked(spirv::StructType type,
LLVMTypeConverter &converter) {
// Type conversion
//===----------------------------------------------------------------------===//
-/// Converts SPIR-V array type to LLVM array. There is no modelling of array
-/// stride at the moment.
+/// Converts SPIR-V array type to LLVM array. Natural stride (according to
+/// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
+/// when converting ops that manipulate array types.
static Optional<Type> convertArrayType(spirv::ArrayType type,
TypeConverter &converter) {
- if (type.getArrayStride() != 0)
+ unsigned stride = type.getArrayStride();
+ Type elementType = type.getElementType();
+ auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes();
+ if (stride != 0 &&
+ !(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride))
return llvm::None;
- auto elementType =
- converter.convertType(type.getElementType()).cast<LLVM::LLVMType>();
+
+ auto llvmElementType =
+ converter.convertType(elementType).cast<LLVM::LLVMType>();
unsigned numElements = type.getNumElements();
- return LLVM::LLVMType::getArrayTy(elementType, numElements);
+ return LLVM::LLVMType::getArrayTy(llvmElementType, numElements);
}
/// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
}
/// Converts SPIR-V struct to LLVM struct. There is no support of structs with
-/// member decorations or with offset.
+/// member decorations. Also, only natural offset is supported.
static Optional<Type> convertStructType(spirv::StructType type,
LLVMTypeConverter &converter) {
SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
type.getMemberDecorations(memberDecorations);
- if (type.hasOffset() || !memberDecorations.empty())
+ if (!memberDecorations.empty())
return llvm::None;
+ if (type.hasOffset())
+ return convertStructTypeWithOffset(type, converter);
return convertStructTypePacked(type, converter);
}
namespace {
+class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
+public:
+ using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(spirv::AccessChainOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto dstType = typeConverter.convertType(op.component_ptr().getType());
+ if (!dstType)
+ return failure();
+ // To use GEP we need to add a first 0 index to go through the pointer.
+ auto indices = llvm::to_vector<4>(op.indices());
+ Type indexType = op.indices().front().getType();
+ auto llvmIndexType = typeConverter.convertType(indexType);
+ if (!llvmIndexType)
+ return failure();
+ Value zero = rewriter.create<LLVM::ConstantOp>(
+ op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
+ indices.insert(indices.begin(), zero);
+ rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, op.base_ptr(),
+ indices);
+ return success();
+ }
+};
+
class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
public:
using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
if (!dstType)
return failure();
- // Limit conversion to the current invocation only for now.
+ // Limit conversion to the current invocation only or `StorageBuffer`
+ // required by SPIR-V runner.
+ // This is okay because multiple invocations are not supported yet.
auto storageClass = srcType.getStorageClass();
if (storageClass != spirv::StorageClass::Input &&
storageClass != spirv::StorageClass::Private &&
- storageClass != spirv::StorageClass::Output) {
+ storageClass != spirv::StorageClass::Output &&
+ storageClass != spirv::StorageClass::StorageBuffer) {
return failure();
}
}
};
+/// A template pattern that erases the given `SPIRVOp`.
+template <typename SPIRVOp>
+class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
+public:
+ using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(SPIRVOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
public:
using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
}
};
-class MergePattern : public SPIRVToLLVMConversion<spirv::MergeOp> {
-public:
- using SPIRVToLLVMConversion<spirv::MergeOp>::SPIRVToLLVMConversion;
-
- LogicalResult
- matchAndRewrite(spirv::MergeOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.eraseOp(op);
- return success();
- }
-};
-
/// Converts `spv.selection` with `spv.BranchConditional` in its header block.
/// All blocks within selection should be reachable for conversion to succeed.
class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
ConstantScalarAndVectorPattern,
// Control Flow ops
- BranchConversionPattern, BranchConditionalConversionPattern, LoopPattern,
- SelectionPattern, MergePattern,
+ BranchConversionPattern, BranchConditionalConversionPattern,
+ FunctionCallPattern, LoopPattern, SelectionPattern,
+ ErasePattern<spirv::MergeOp>,
+
+ // Entry points and execution mode
+ // Module generated from SPIR-V could have other "internal" functions, so
+ // having entry point and execution mode metadat can be useful. For now,
+ // simply remove them.
+ // TODO: Support EntryPoint/ExecutionMode properly.
+ ErasePattern<spirv::EntryPointOp>, ErasePattern<spirv::ExecutionModeOp>,
// Function Call op
- FunctionCallPattern,
// GLSL extended instruction set ops
DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
NotPattern<spirv::LogicalNotOp>,
// Memory ops
- AddressOfPattern, GlobalVariablePattern, LoadStorePattern<spirv::LoadOp>,
- LoadStorePattern<spirv::StoreOp>, VariablePattern,
+ AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
+ LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
+ VariablePattern,
// Miscellaneous ops
DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
//===----------------------------------------------------------------------===//
+// spv.AccessChain
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @access_chain
+func @access_chain() -> () {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+ %0 = spv.constant 1: i32
+ %1 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+ // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+ // CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], %[[ONE]], %[[ONE]]] : (!llvm.ptr<struct<packed (float, array<4 x float>)>>, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm.ptr<float>
+ %2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>, i32, i32
+ return
+}
+
+// CHECK-LABEL: @access_chain_array
+func @access_chain_array(%arg0 : i32) -> () {
+ %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+ // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+ // CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], %{{.*}}] : (!llvm.ptr<array<4 x array<4 x float>>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<array<4 x float>>
+ %1 = spv.AccessChain %0[%arg0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>, i32
+ %2 = spv.Load "Function" %1 ["Volatile"] : !spv.array<4xf32>
+ return
+}
+
+//===----------------------------------------------------------------------===//
// spv.globalVariable and spv._address_of
//===----------------------------------------------------------------------===//