From 520a5702680ea0b5059193a0d4ad52c217da7325 Mon Sep 17 00:00:00 2001 From: Hanhan Wang Date: Tue, 19 May 2020 10:55:17 -0700 Subject: [PATCH] [mlir][StandardToSPIRV] Fix signedness issue in bitwidth emulation. Summary: Previously, after applying the mask, a negative number would convert to a positive number because the sign flag was forgotten. This patch adds two more shift operations to do the sign extension. This assumes that we're using two's complement. This patch applies sign extension unconditionally when loading an unsupported integer width, and it relies on the pattern to do the casting because the signedness semantic is carried by the operator itself. Differential Revision: https://reviews.llvm.org/D79753 --- .../StandardToSPIRV/ConvertStandardToSPIRV.cpp | 55 +++++++++++++++++++++- .../StandardToSPIRV/std-ops-to-spirv.mlir | 44 +++++++++++++---- 2 files changed, 89 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index fbe02560..560bc4a 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -147,14 +147,44 @@ static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter, } /// Returns the shifted `targetBits`-bit value with the given offset. -Value shiftValue(Location loc, Value value, Value offset, Value mask, - int targetBits, OpBuilder &builder) { +static Value shiftValue(Location loc, Value value, Value offset, Value mask, + int targetBits, OpBuilder &builder) { Type targetType = builder.getIntegerType(targetBits); Value result = builder.create(loc, value, mask); return builder.create(loc, targetType, result, offset); } +/// Returns true if the operator is operating on unsigned integers. +/// TODO: Have a TreatOperandsAsUnsignedInteger trait and bake the information +/// to the ops themselves. 
+template +bool isUnsignedOp() { + return false; +} + +#define CHECK_UNSIGNED_OP(SPIRVOp) \ + template <> \ + bool isUnsignedOp() { \ + return true; \ + } + +CHECK_UNSIGNED_OP(spirv::AtomicUMaxOp); +CHECK_UNSIGNED_OP(spirv::AtomicUMinOp); +CHECK_UNSIGNED_OP(spirv::BitFieldUExtractOp); +CHECK_UNSIGNED_OP(spirv::ConvertUToFOp); +CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMaxOp); +CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMinOp); +CHECK_UNSIGNED_OP(spirv::UConvertOp); +CHECK_UNSIGNED_OP(spirv::UDivOp); +CHECK_UNSIGNED_OP(spirv::UGreaterThanEqualOp); +CHECK_UNSIGNED_OP(spirv::UGreaterThanOp); +CHECK_UNSIGNED_OP(spirv::ULessThanEqualOp); +CHECK_UNSIGNED_OP(spirv::ULessThanOp); +CHECK_UNSIGNED_OP(spirv::UModOp); + +#undef CHECK_UNSIGNED_OP + //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// @@ -178,6 +208,10 @@ public: auto dstType = this->typeConverter.convertType(operation.getType()); if (!dstType) return failure(); + if (isUnsignedOp() && dstType != operation.getType()) { + return operation.emitError( + "bitwidth emulation is not implemented yet on unsigned op"); + } rewriter.template replaceOpWithNewOp(operation, dstType, operands, ArrayRef()); return success(); @@ -581,6 +615,11 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, switch (cmpIOp.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ + if (isUnsignedOp() && \ + operandType != this->typeConverter.convertType(operandType)) { \ + return cmpIOp.emitError( \ + "bitwidth emulation is not implemented yet on unsigned op"); \ + } \ rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(), \ cmpIOpOperands.lhs(), \ cmpIOpOperands.rhs()); \ @@ -661,6 +700,18 @@ IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef operands, Value mask = rewriter.create( loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); result 
= rewriter.create(loc, dstType, result, mask); + + // Apply sign extension on the loaded value unconditionally. The signedness + // semantic is carried in the operator itself; we rely on other patterns to + // handle the casting. + IntegerAttr shiftValueAttr = + rewriter.getIntegerAttr(dstType, dstBits - srcBits); + Value shiftValue = + rewriter.create(loc, dstType, shiftValueAttr); + result = rewriter.create(loc, dstType, result, + shiftValue); + result = rewriter.create(loc, dstType, result, + shiftValue); rewriter.replaceOp(loadOp, result); assert(accessChainOp.use_empty()); diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir index 1663366..bf54dba 100644 --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-std-to-spirv %s -o - | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-std-to-spirv -verify-diagnostics %s -o - | FileCheck %s //===----------------------------------------------------------------------===// // std arithmetic ops //===----------------------------------------------------------------------===// @@ -128,14 +128,12 @@ module attributes { max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> } { -// CHECK-LABEL: @int_vector234 -func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<3xi16>, %arg2: vector<4xi64>) { +// CHECK-LABEL: @int_vector23 +func @int_vector23(%arg0: vector<2xi8>, %arg1: vector<3xi16>) { // CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<2xi32> %0 = divi_signed %arg0, %arg0: vector<2xi8> // CHECK: spv.SRem %{{.*}}, %{{.*}}: vector<3xi32> %1 = remi_signed %arg1, %arg1: vector<3xi16> - // CHECK: spv.UDiv %{{.*}}, %{{.*}}: vector<4xi32> - %2 = divi_unsigned %arg2, %arg2: vector<4xi64> return } @@ -152,6 +150,27 @@ func @float_scalar(%arg0: f16, %arg1: f64) { // ----- +// Check that types are converted to 32-bit when no
special capabilities + are available. +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: @int_vector4_invalid +func @int_vector4_invalid(%arg0: vector<4xi64>) { + // expected-error @+2 {{bitwidth emulation is not implemented yet on unsigned op}} + // expected-error @+1 {{op requires the same type for all operands and results}} + %0 = divi_unsigned %arg0, %arg0: vector<4xi64> + return +} + +} // end module + +// ----- + //===----------------------------------------------------------------------===// // std bit ops //===----------------------------------------------------------------------===// @@ -717,7 +736,10 @@ func @load_i8(%arg0: memref) { // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 // CHECK: %[[MASK:.+]] = spv.constant 255 : i32 - // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T2:.+]] = spv.constant 24 : i32 + // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 + // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 %0 = load %arg0[] : memref return } @@ -738,7 +760,10 @@ func @load_i16(%arg0: memref<10xi16>, %index : index) { // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32 // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 // CHECK: %[[MASK:.+]] = spv.constant 65535 : i32 - // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T2:.+]] = spv.constant 16 : i32 + // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 + // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 %0 = load %arg0[%index] : memref<10xi16> return } @@ 
-852,7 +877,10 @@ func @load_i8(%arg0: memref) { // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 // CHECK: %[[MASK:.+]] = spv.constant 255 : i32 - // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T2:.+]] = spv.constant 24 : i32 + // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 + // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 %0 = load %arg0[] : memref return } -- 2.7.4