From e1e0ecb96e0aa3b3846484f82e4c0e8c31c50341 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Tue, 28 Jun 2022 11:58:42 -0400 Subject: [PATCH] [mlir][spirv] Support more comparisons on boolean values Reviewed By: hanchung Differential Revision: https://reviews.llvm.org/D128692 --- .../ArithmeticToSPIRV/ArithmeticToSPIRV.cpp | 48 ++++++++++++++++------ .../ArithmeticToSPIRV/arithmetic-to-spirv.mlir | 46 ++++++++++++++++++--- 2 files changed, 75 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp index 31d9023..4bf985e 100644 --- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "mlir/IR/BuiltinTypes.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "arith-to-spirv-pattern" @@ -665,23 +666,44 @@ LogicalResult TypeCastingOpPattern::matchAndRewrite( LogicalResult CmpIOpBooleanPattern::matchAndRewrite( arith::CmpIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - Type operandType = op.getLhs().getType(); - if (!isBoolScalarOrVector(operandType)) + Type srcType = op.getLhs().getType(); + if (!isBoolScalarOrVector(srcType)) + return failure(); + Type dstType = getTypeConverter()->convertType(srcType); + if (!dstType) return failure(); switch (op.getPredicate()) { -#define DISPATCH(cmpPredicate, spirvOp) \ - case cmpPredicate: { \ - rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), \ - adaptor.getRhs()); \ - return success(); \ + case arith::CmpIPredicate::eq: { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + return success(); } - - DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp); - DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp); - -#undef DISPATCH - default:; + case arith::CmpIPredicate::ne: { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } + case arith::CmpIPredicate::uge: + case arith::CmpIPredicate::ugt: + case arith::CmpIPredicate::ule: + case arith::CmpIPredicate::ult: { + // There are no direct corresponding instructions in SPIR-V for such cases. + // Extend them to 32-bit and do comparision then. + Type type = rewriter.getI32Type(); + if (auto vectorType = dstType.dyn_cast()) + type = VectorType::get(vectorType.getShape(), type); + auto extLhs = + rewriter.create(op.getLoc(), type, adaptor.getLhs()); + auto extRhs = + rewriter.create(op.getLoc(), type, adaptor.getRhs()); + + rewriter.replaceOpWithNewOp(op, op.getPredicate(), extLhs, + extRhs); + return success(); + } + default: + break; } return failure(); } diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir index 8925197..22e3dc4 100644 --- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir +++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir @@ -401,8 +401,8 @@ func.func @vec1cmpi(%arg0 : vector<1xi32>, %arg1 : vector<1xi32>) { return } -// CHECK-LABEL: @boolcmpi -func.func @boolcmpi(%arg0 : i1, %arg1 : i1) { +// CHECK-LABEL: @boolcmpi_equality +func.func @boolcmpi_equality(%arg0 : i1, %arg1 : i1) { // CHECK: spv.LogicalEqual %0 = arith.cmpi eq, %arg0, %arg1 : i1 // CHECK: spv.LogicalNotEqual @@ -410,8 +410,19 @@ func.func @boolcmpi(%arg0 : i1, %arg1 : i1) { return } -// CHECK-LABEL: @vec1boolcmpi -func.func @vec1boolcmpi(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) { +// CHECK-LABEL: @boolcmpi_unsigned +func.func @boolcmpi_unsigned(%arg0 : i1, %arg1 : i1) { + // CHECK-COUNT-2: spv.Select + // CHECK: spv.UGreaterThanEqual + %0 = arith.cmpi uge, %arg0, %arg1 : i1 + // CHECK-COUNT-2: spv.Select + // CHECK: spv.ULessThan + %1 = arith.cmpi ult, %arg0, %arg1 : i1 + return +} + +// CHECK-LABEL: @vec1boolcmpi_equality +func.func @vec1boolcmpi_equality(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) { // CHECK: spv.LogicalEqual %0 = arith.cmpi eq, %arg0, %arg1 : vector<1xi1> // CHECK: spv.LogicalNotEqual @@ -419,8 +430,19 @@ func.func @vec1boolcmpi(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) { return } -// CHECK-LABEL: @vecboolcmpi -func.func @vecboolcmpi(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) { +// CHECK-LABEL: @vec1boolcmpi_unsigned +func.func @vec1boolcmpi_unsigned(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) { + // CHECK-COUNT-2: spv.Select + // CHECK: spv.UGreaterThanEqual + %0 = arith.cmpi uge, %arg0, %arg1 : vector<1xi1> + // CHECK-COUNT-2: spv.Select + // CHECK: spv.ULessThan + %1 = arith.cmpi ult, %arg0, %arg1 : vector<1xi1> + return +} + +// CHECK-LABEL: @vecboolcmpi_equality +func.func @vecboolcmpi_equality(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) { // CHECK: spv.LogicalEqual %0 = arith.cmpi eq, %arg0, %arg1 : vector<4xi1> // CHECK: spv.LogicalNotEqual @@ -428,6 +450,18 @@ func.func @vecboolcmpi(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) { return } +// CHECK-LABEL: @vecboolcmpi_unsigned +func.func @vecboolcmpi_unsigned(%arg0 : vector<3xi1>, %arg1 : vector<3xi1>) { + // CHECK-COUNT-2: spv.Select + // CHECK: spv.UGreaterThanEqual + %0 = arith.cmpi uge, %arg0, %arg1 : vector<3xi1> + // CHECK-COUNT-2: spv.Select + // CHECK: spv.ULessThan + %1 = arith.cmpi ult, %arg0, %arg1 : vector<3xi1> + return +} + + } // end module // ----- -- 2.7.4