#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "spirv-to-llvm-pattern"
using namespace mlir;
}
};
+class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
+public:
+ using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(spirv::ReturnOp returnOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
+ ArrayRef<Value>());
+ return success();
+ }
+};
+
+class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
+public:
+ using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(spirv::ReturnValueOp returnValueOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
+ operands);
+ return success();
+ }
+};
+
/// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
/// puts a restriction on `Shift` and `Base` to have the same bit width,
/// `Shift` is zero or sign extended to match this specification. Cases when
return success();
}
};
+
+//===----------------------------------------------------------------------===//
+// FuncOp conversion
+//===----------------------------------------------------------------------===//
+
+class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
+public:
+ using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+
+ // Convert function signature. At the moment LLVMType converter is enough
+ // for currently supported types.
+ auto funcType = funcOp.getType();
+ TypeConverter::SignatureConversion signatureConverter(
+ funcType.getNumInputs());
+ auto llvmType = this->typeConverter.convertFunctionSignature(
+ funcOp.getType(), /*isVariadic=*/false, signatureConverter);
+
+ // Create a new `LLVMFuncOp`
+ Location loc = funcOp.getLoc();
+ StringRef name = funcOp.getName();
+ auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
+
+ // Convert SPIR-V Function Control to equivalent LLVM function attribute
+ MLIRContext *context = funcOp.getContext();
+ switch (funcOp.function_control()) {
+#define DISPATCH(functionControl, llvmAttr) \
+ case functionControl: \
+ newFuncOp.setAttr("passthrough", ArrayAttr::get({llvmAttr}, context)); \
+ break;
+
+ DISPATCH(spirv::FunctionControl::Inline,
+ StringAttr::get("alwaysinline", context));
+ DISPATCH(spirv::FunctionControl::DontInline,
+ StringAttr::get("noinline", context));
+ DISPATCH(spirv::FunctionControl::Pure,
+ StringAttr::get("readonly", context));
+ DISPATCH(spirv::FunctionControl::Const,
+ StringAttr::get("readnone", context));
+
+#undef DISPATCH
+
+ // Default: if `spirv::FunctionControl::None`, then no attributes are
+ // needed.
+ default:
+ break;
+ }
+
+ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+ rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
+ rewriter.eraseOp(funcOp);
+ return success();
+ }
+};
} // namespace
//===----------------------------------------------------------------------===//
// Shift ops
ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
- ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>>(context,
- typeConverter);
+ ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
+
+ // Return ops
+ ReturnPattern, ReturnValuePattern>(context, typeConverter);
+}
+
+void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
+ MLIRContext *context, LLVMTypeConverter &typeConverter,
+ OwningRewritePatternList &patterns) {
+ patterns.insert<FuncConversionPattern>(context, typeConverter);
}
--- /dev/null
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.Return
+//===----------------------------------------------------------------------===//
+
+func @return() {
+ // CHECK: llvm.return
+ spv.Return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ReturnValue
+//===----------------------------------------------------------------------===//
+
+func @return_value(%arg: i32) {
+ // CHECK: llvm.return %{{.*}} : !llvm.i32
+ spv.ReturnValue %arg : i32
+}
+
+//===----------------------------------------------------------------------===//
+// spv.func
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: llvm.func @none()
+spv.func @none() -> () "None" {
+ spv.Return
+}
+
+// CHECK-LABEL: llvm.func @inline() attributes {passthrough = ["alwaysinline"]}
+spv.func @inline() -> () "Inline" {
+ spv.Return
+}
+
+// CHECK-LABEL: llvm.func @dont_inline() attributes {passthrough = ["noinline"]}
+spv.func @dont_inline() -> () "DontInline" {
+ spv.Return
+}
+
+// CHECK-LABEL: llvm.func @pure() attributes {passthrough = ["readonly"]}
+spv.func @pure() -> () "Pure" {
+ spv.Return
+}
+
+// CHECK-LABEL: llvm.func @const() attributes {passthrough = ["readnone"]}
+spv.func @const() -> () "Const" {
+ spv.Return
+}
+
+// CHECK-LABEL: llvm.func @scalar_types(%arg0: !llvm.i32, %arg1: !llvm.i1, %arg2: !llvm.double, %arg3: !llvm.float)
+spv.func @scalar_types(%arg0: i32, %arg1: i1, %arg2: f64, %arg3: f32) -> () "None" {
+ spv.Return
+}
+
+// CHECK-LABEL: llvm.func @vector_types(%arg0: !llvm<"<2 x i64>">, %arg1: !llvm<"<2 x i64>">) -> !llvm<"<2 x i64>">
+spv.func @vector_types(%arg0: vector<2xi64>, %arg1: vector<2xi64>) -> vector<2xi64> "None" {
+ %0 = spv.IAdd %arg0, %arg1 : vector<2xi64>
+ spv.ReturnValue %0 : vector<2xi64>
+}
+
+
+