}
/// Returns the shifted `targetBits`-bit value with the given offset.
-Value shiftValue(Location loc, Value value, Value offset, Value mask,
- int targetBits, OpBuilder &builder) {
+static Value shiftValue(Location loc, Value value, Value offset, Value mask,
+ int targetBits, OpBuilder &builder) {
Type targetType = builder.getIntegerType(targetBits);
Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
offset);
}
+/// Returns true if the operator is operating on unsigned integers.
+/// TODO: Have a TreatOperandsAsUnsignedInteger trait and bake the information
+/// to the ops themselves.
+template <typename SPIRVOp>
+bool isUnsignedOp() {
+ return false;
+}
+
+#define CHECK_UNSIGNED_OP(SPIRVOp) \
+ template <> \
+ bool isUnsignedOp<SPIRVOp>() { \
+ return true; \
+ }
+
+CHECK_UNSIGNED_OP(spirv::AtomicUMaxOp);
+CHECK_UNSIGNED_OP(spirv::AtomicUMinOp);
+CHECK_UNSIGNED_OP(spirv::BitFieldUExtractOp);
+CHECK_UNSIGNED_OP(spirv::ConvertUToFOp);
+CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMaxOp);
+CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMinOp);
+CHECK_UNSIGNED_OP(spirv::UConvertOp);
+CHECK_UNSIGNED_OP(spirv::UDivOp);
+CHECK_UNSIGNED_OP(spirv::UGreaterThanEqualOp);
+CHECK_UNSIGNED_OP(spirv::UGreaterThanOp);
+CHECK_UNSIGNED_OP(spirv::ULessThanEqualOp);
+CHECK_UNSIGNED_OP(spirv::ULessThanOp);
+CHECK_UNSIGNED_OP(spirv::UModOp);
+
+#undef CHECK_UNSIGNED_OP
+
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
auto dstType = this->typeConverter.convertType(operation.getType());
if (!dstType)
return failure();
+ if (isUnsignedOp<SPIRVOp>() && dstType != operation.getType()) {
+ return operation.emitError(
+ "bitwidth emulation is not implemented yet on unsigned op");
+ }
rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands,
ArrayRef<NamedAttribute>());
return success();
switch (cmpIOp.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp) \
case cmpPredicate: \
+ if (isUnsignedOp<spirvOp>() && \
+ operandType != this->typeConverter.convertType(operandType)) { \
+ return cmpIOp.emitError( \
+ "bitwidth emulation is not implemented yet on unsigned op"); \
+ } \
rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
cmpIOpOperands.lhs(), \
cmpIOpOperands.rhs()); \
Value mask = rewriter.create<spirv::ConstantOp>(
loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
+
+ // Apply sign extension on the loading value unconditionally. The signedness
+ // semantic is carried in the operator itself, we relies other pattern to
+ // handle the casting.
+ IntegerAttr shiftValueAttr =
+ rewriter.getIntegerAttr(dstType, dstBits - srcBits);
+ Value shiftValue =
+ rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
+ result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
+ shiftValue);
+ result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
+ shiftValue);
rewriter.replaceOp(loadOp, result);
assert(accessChainOp.use_empty());
-// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-std-to-spirv %s -o - | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-std-to-spirv -verify-diagnostics %s -o - | FileCheck %s
//===----------------------------------------------------------------------===//
// std arithmetic ops
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
} {
-// CHECK-LABEL: @int_vector234
-func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<3xi16>, %arg2: vector<4xi64>) {
+// CHECK-LABEL: @int_vector23
+func @int_vector23(%arg0: vector<2xi8>, %arg1: vector<3xi16>) {
// CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<2xi32>
%0 = divi_signed %arg0, %arg0: vector<2xi8>
// CHECK: spv.SRem %{{.*}}, %{{.*}}: vector<3xi32>
%1 = remi_signed %arg1, %arg1: vector<3xi16>
- // CHECK: spv.UDiv %{{.*}}, %{{.*}}: vector<4xi32>
- %2 = divi_unsigned %arg2, %arg2: vector<4xi64>
return
}
// -----
+// Check that types are converted to 32-bit when no special capabilities that
+// are not supported.
+module attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.0, [], []>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LEBEL: @int_vector4_invalid
+func @int_vector4_invalid(%arg0: vector<4xi64>) {
+ // expected-error @+2 {{bitwidth emulation is not implemented yet on unsigned op}}
+ // expected-error @+1 {{op requires the same type for all operands and results}}
+ %0 = divi_unsigned %arg0, %arg0: vector<4xi64>
+ return
+}
+
+} // end module
+
+// -----
+
//===----------------------------------------------------------------------===//
// std bit ops
//===----------------------------------------------------------------------===//
// CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
// CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
// CHECK: %[[MASK:.+]] = spv.constant 255 : i32
- // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ // CHECK: %[[T2:.+]] = spv.constant 24 : i32
+ // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+ // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
%0 = load %arg0[] : memref<i8>
return
}
// CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32
// CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
// CHECK: %[[MASK:.+]] = spv.constant 65535 : i32
- // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ // CHECK: %[[T2:.+]] = spv.constant 16 : i32
+ // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+ // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
%0 = load %arg0[%index] : memref<10xi16>
return
}
// CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
// CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
// CHECK: %[[MASK:.+]] = spv.constant 255 : i32
- // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ // CHECK: %[[T2:.+]] = spv.constant 24 : i32
+ // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+ // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
%0 = load %arg0[] : memref<i8>
return
}