using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Operation conversion
+//===----------------------------------------------------------------------===//
+
namespace {
-class BitwiseAndOpConversion : public ConvertToLLVMPattern {
+/// Converts SPIR-V operations that have a straightforward LLVM equivalent
+/// into LLVM dialect operations.
+template <typename SPIRVOp, typename LLVMOp>
+class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
public:
- explicit BitwiseAndOpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(spirv::BitwiseAndOp::getOperationName(), context,
- typeConverter) {}
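+  // Reuse the constructors of the base conversion pattern.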
+ using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto bitwiseAndOp = cast<spirv::BitwiseAndOp>(op);
- auto dstType = typeConverter.convertType(bitwiseAndOp.getType());
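+    // `this->` is needed to reach `typeConverter` through the dependent base
+    // class. The type converter returns a null type if the result type has no
+    // LLVM counterpart.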
+ auto dstType = this->typeConverter.convertType(operation.getType());
if (!dstType)
return failure();
- rewriter.replaceOpWithNewOp<LLVM::AndOp>(bitwiseAndOp, dstType, operands);
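+    // Replace the SPIR-V op with its LLVM dialect counterpart, reusing the
+    // already-converted operands.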
+ rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, operands);
return success();
}
};
} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+
void mlir::populateSPIRVToLLVMConversionPatterns(
MLIRContext *context, LLVMTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
- patterns.insert<BitwiseAndOpConversion>(context, typeConverter);
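+  // Each instantiation maps one SPIR-V operation to the LLVM operation with
+  // the same semantics.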
+ patterns.insert<DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
+ DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
+ DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
+ DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
+ DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
+ DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
+ DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
+ DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
+ DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
+ DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
+ DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
+ DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
+ DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
+ DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>>(
+ context, typeConverter);
}
--- /dev/null
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.IAdd
+//===----------------------------------------------------------------------===//
+
+func @iadd_scalar(%arg0: i32, %arg1: i32) {
+ // CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} : !llvm.i32
+ %0 = spv.IAdd %arg0, %arg1 : i32
+ return
+}
+
+func @iadd_vector(%arg0: vector<4xi64>, %arg1: vector<4xi64>) {
+ // CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+ %0 = spv.IAdd %arg0, %arg1 : vector<4xi64>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ISub
+//===----------------------------------------------------------------------===//
+
+func @isub_scalar(%arg0: i8, %arg1: i8) {
+ // CHECK: %{{.*}} = llvm.sub %{{.*}}, %{{.*}} : !llvm.i8
+ %0 = spv.ISub %arg0, %arg1 : i8
+ return
+}
+
+func @isub_vector(%arg0: vector<2xi16>, %arg1: vector<2xi16>) {
+ // CHECK: %{{.*}} = llvm.sub %{{.*}}, %{{.*}} : !llvm<"<2 x i16>">
+ %0 = spv.ISub %arg0, %arg1 : vector<2xi16>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.IMul
+//===----------------------------------------------------------------------===//
+
+func @imul_scalar(%arg0: i32, %arg1: i32) {
+ // CHECK: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} : !llvm.i32
+ %0 = spv.IMul %arg0, %arg1 : i32
+ return
+}
+
+func @imul_vector(%arg0: vector<3xi32>, %arg1: vector<3xi32>) {
+ // CHECK: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} : !llvm<"<3 x i32>">
+ %0 = spv.IMul %arg0, %arg1 : vector<3xi32>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FAdd
+//===----------------------------------------------------------------------===//
+
+func @fadd_scalar(%arg0: f16, %arg1: f16) {
+ // CHECK: %{{.*}} = llvm.fadd %{{.*}}, %{{.*}} : !llvm.half
+ %0 = spv.FAdd %arg0, %arg1 : f16
+ return
+}
+
+func @fadd_vector(%arg0: vector<4xf32>, %arg1: vector<4xf32>) {
+ // CHECK: %{{.*}} = llvm.fadd %{{.*}}, %{{.*}} : !llvm<"<4 x float>">
+ %0 = spv.FAdd %arg0, %arg1 : vector<4xf32>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FSub
+//===----------------------------------------------------------------------===//
+
+func @fsub_scalar(%arg0: f32, %arg1: f32) {
+ // CHECK: %{{.*}} = llvm.fsub %{{.*}}, %{{.*}} : !llvm.float
+ %0 = spv.FSub %arg0, %arg1 : f32
+ return
+}
+
+func @fsub_vector(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
+ // CHECK: %{{.*}} = llvm.fsub %{{.*}}, %{{.*}} : !llvm<"<2 x float>">
+ %0 = spv.FSub %arg0, %arg1 : vector<2xf32>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FDiv
+//===----------------------------------------------------------------------===//
+
+func @fdiv_scalar(%arg0: f32, %arg1: f32) {
+ // CHECK: %{{.*}} = llvm.fdiv %{{.*}}, %{{.*}} : !llvm.float
+ %0 = spv.FDiv %arg0, %arg1 : f32
+ return
+}
+
+func @fdiv_vector(%arg0: vector<3xf64>, %arg1: vector<3xf64>) {
+ // CHECK: %{{.*}} = llvm.fdiv %{{.*}}, %{{.*}} : !llvm<"<3 x double>">
+ %0 = spv.FDiv %arg0, %arg1 : vector<3xf64>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FRem
+//===----------------------------------------------------------------------===//
+
+func @frem_scalar(%arg0: f32, %arg1: f32) {
+ // CHECK: %{{.*}} = llvm.frem %{{.*}}, %{{.*}} : !llvm.float
+ %0 = spv.FRem %arg0, %arg1 : f32
+ return
+}
+
+func @frem_vector(%arg0: vector<3xf64>, %arg1: vector<3xf64>) {
+ // CHECK: %{{.*}} = llvm.frem %{{.*}}, %{{.*}} : !llvm<"<3 x double>">
+ %0 = spv.FRem %arg0, %arg1 : vector<3xf64>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FNegate
+//===----------------------------------------------------------------------===//
+
+func @fneg_scalar(%arg: f64) {
+ // CHECK: %{{.*}} = llvm.fneg %{{.*}} : !llvm.double
+ %0 = spv.FNegate %arg : f64
+ return
+}
+
+func @fneg_vector(%arg: vector<2xf32>) {
+ // CHECK: %{{.*}} = llvm.fneg %{{.*}} : !llvm<"<2 x float>">
+ %0 = spv.FNegate %arg : vector<2xf32>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.UDiv
+//===----------------------------------------------------------------------===//
+
+func @udiv_scalar(%arg0: i32, %arg1: i32) {
+ // CHECK: %{{.*}} = llvm.udiv %{{.*}}, %{{.*}} : !llvm.i32
+ %0 = spv.UDiv %arg0, %arg1 : i32
+ return
+}
+
+func @udiv_vector(%arg0: vector<3xi64>, %arg1: vector<3xi64>) {
+ // CHECK: %{{.*}} = llvm.udiv %{{.*}}, %{{.*}} : !llvm<"<3 x i64>">
+ %0 = spv.UDiv %arg0, %arg1 : vector<3xi64>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.SDiv
+//===----------------------------------------------------------------------===//
+
+func @sdiv_scalar(%arg0: i16, %arg1: i16) {
+ // CHECK: %{{.*}} = llvm.sdiv %{{.*}}, %{{.*}} : !llvm.i16
+ %0 = spv.SDiv %arg0, %arg1 : i16
+ return
+}
+
+func @sdiv_vector(%arg0: vector<2xi64>, %arg1: vector<2xi64>) {
+ // CHECK: %{{.*}} = llvm.sdiv %{{.*}}, %{{.*}} : !llvm<"<2 x i64>">
+ %0 = spv.SDiv %arg0, %arg1 : vector<2xi64>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.SRem
+//===----------------------------------------------------------------------===//
+
+func @srem_scalar(%arg0: i32, %arg1: i32) {
+ // CHECK: %{{.*}} = llvm.srem %{{.*}}, %{{.*}} : !llvm.i32
+ %0 = spv.SRem %arg0, %arg1 : i32
+ return
+}
+
+func @srem_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) {
+ // CHECK: %{{.*}} = llvm.srem %{{.*}}, %{{.*}} : !llvm<"<4 x i32>">
+ %0 = spv.SRem %arg0, %arg1 : vector<4xi32>
+ return
+}
--- /dev/null
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.BitwiseAnd
+//===----------------------------------------------------------------------===//
+
+func @bitwise_and_scalar(%arg0: i32, %arg1: i32) {
+ // CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm.i32
+ %0 = spv.BitwiseAnd %arg0, %arg1 : i32
+ return
+}
+
+func @bitwise_and_vector(%arg0: vector<4xi64>, %arg1: vector<4xi64>) {
+ // CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+ %0 = spv.BitwiseAnd %arg0, %arg1 : vector<4xi64>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.BitwiseOr
+//===----------------------------------------------------------------------===//
+
+func @bitwise_or_scalar(%arg0: i64, %arg1: i64) {
+ // CHECK: %{{.*}} = llvm.or %{{.*}}, %{{.*}} : !llvm.i64
+ %0 = spv.BitwiseOr %arg0, %arg1 : i64
+ return
+}
+
+func @bitwise_or_vector(%arg0: vector<3xi8>, %arg1: vector<3xi8>) {
+ // CHECK: %{{.*}} = llvm.or %{{.*}}, %{{.*}} : !llvm<"<3 x i8>">
+ %0 = spv.BitwiseOr %arg0, %arg1 : vector<3xi8>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.BitwiseXor
+//===----------------------------------------------------------------------===//
+
+func @bitwise_xor_scalar(%arg0: i32, %arg1: i32) {
+ // CHECK: %{{.*}} = llvm.xor %{{.*}}, %{{.*}} : !llvm.i32
+ %0 = spv.BitwiseXor %arg0, %arg1 : i32
+ return
+}
+
+func @bitwise_xor_vector(%arg0: vector<2xi16>, %arg1: vector<2xi16>) {
+ // CHECK: %{{.*}} = llvm.xor %{{.*}}, %{{.*}} : !llvm<"<2 x i16>">
+ %0 = spv.BitwiseXor %arg0, %arg1 : vector<2xi16>
+ return
+}