From f3bc0fccd68a1208d568360b0e3f6483759bae4a Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Wed, 15 Jun 2022 20:38:47 -0400 Subject: [PATCH] [mlir][spirv] Define spv.ISubBorrowOp Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D127909 --- .../mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td | 51 +++++++++++++++++++++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td | 53 +++++++++++----------- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 53 ++++++++++++++++++++++ mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir | 52 +++++++++++++++++++++ 4 files changed, 183 insertions(+), 26 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td index 4689341..30fd95f 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -390,6 +390,57 @@ def SPV_ISubOp : SPV_ArithmeticBinaryOp<"ISub", // ----- +def SPV_ISubBorrowOp : SPV_BinaryOp<"ISubBorrow", SPV_AnyStruct, SPV_Integer, + [NoSideEffect]> { + let summary = [{ + Result is the unsigned integer subtraction of Operand 2 from Operand 1, + and what it needed to borrow. + }]; + + let description = [{ + Result Type must be from OpTypeStruct. The struct must have two + members, and the two members must be the same type. The member type + must be a scalar or vector of integer type, whose Signedness operand is + 0. + + Operand 1 and Operand 2 must have the same type as the members of Result + Type. These are consumed as unsigned integers. + + Results are computed per component. + + Member 0 of the result gets the low-order bits (full component width) of + the subtraction. That is, if Operand 1 is larger than Operand 2, member + 0 gets the full value of the subtraction; if Operand 2 is larger than + Operand 1, member 0 gets 2w + Operand 1 - Operand 2, where w is the + component width. + + Member 1 of the result gets 0 if Operand 1 ≥ Operand 2, and gets 1 + otherwise. + + + + #### Example: + + ```mlir + %2 = spv.ISubBorrow %0, %1 : !spv.struct<(i32, i32)> + %2 = spv.ISubBorrow %0, %1 : !spv.struct<(vector<2xi32>, vector<2xi32>)> + ``` + }]; + + let arguments = (ins + SPV_ScalarOrVectorOf:$operand1, + SPV_ScalarOrVectorOf:$operand2 + ); + + let results = (outs + SPV_AnyStruct:$result + ); + + let hasVerifier = 1; +} + +// ----- + def SPV_SDivOp : SPV_ArithmeticBinaryOp<"SDiv", SPV_Integer, [UsableInSpecConstantOp]> { diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 0b29e47..79d1222 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4090,6 +4090,7 @@ def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>; def SPV_OC_OpVectorTimesScalar : I32EnumAttrCase<"OpVectorTimesScalar", 142>; def SPV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>; def SPV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>; +def SPV_OC_OpISubBorrow : I32EnumAttrCase<"OpISubBorrow", 150>; def SPV_OC_OpIsNan : I32EnumAttrCase<"OpIsNan", 156>; def SPV_OC_OpIsInf : I32EnumAttrCase<"OpIsInf", 157>; def SPV_OC_OpOrdered : I32EnumAttrCase<"OpOrdered", 162>; @@ -4214,32 +4215,32 @@ def SPV_OpcodeAttr : SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpVectorTimesScalar, - SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpIsNan, - SPV_OC_OpIsInf, SPV_OC_OpOrdered, SPV_OC_OpUnordered, SPV_OC_OpLogicalEqual, - SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, - SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, - SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, - SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan, - SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, - SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, - SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, - SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual, - SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual, - SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical, - SPV_OC_OpShiftRightArithmetic, SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, - SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, - SPV_OC_OpBitFieldInsert, SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, - SPV_OC_OpBitReverse, SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, - SPV_OC_OpMemoryBarrier, SPV_OC_OpAtomicExchange, - SPV_OC_OpAtomicCompareExchange, SPV_OC_OpAtomicCompareExchangeWeak, - SPV_OC_OpAtomicIIncrement, SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, - SPV_OC_OpAtomicISub, SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, - SPV_OC_OpAtomicSMax, SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, - SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, SPV_OC_OpPhi, SPV_OC_OpLoopMerge, - SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch, - SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue, - SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast, SPV_OC_OpNoLine, - SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect, + SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpISubBorrow, + SPV_OC_OpIsNan, SPV_OC_OpIsInf, SPV_OC_OpOrdered, SPV_OC_OpUnordered, + SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, + SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, + SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, + SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, + SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, + SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, + SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, + SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, + SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, + SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, + SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic, + SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor, + SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert, + SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse, + SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, + SPV_OC_OpAtomicExchange, SPV_OC_OpAtomicCompareExchange, + SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement, + SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub, + SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax, + SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, + SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, + SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn, + SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast, + SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect, SPV_OC_OpGroupNonUniformBroadcast, SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul, diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 5d7d943..123f707 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -28,6 +28,7 @@ #include "mlir/Interfaces/CallInterfaces.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/bit.h" @@ -2840,6 +2841,58 @@ void spirv::GroupNonUniformUMinOp::print(OpAsmPrinter &p) { } //===----------------------------------------------------------------------===// +// spv.ISubBorrowOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::ISubBorrowOp::verify() { + auto resultType = getType().cast(); + if (resultType.getNumElements() != 2) + return emitOpError("expected result struct type containing two members"); + + SmallVector types; + types.push_back(operand1().getType()); + types.push_back(operand2().getType()); + types.push_back(resultType.getElementType(0)); + types.push_back(resultType.getElementType(1)); + if (!llvm::is_splat(types)) + return emitOpError( + "expected all operand types and struct member types are the same"); + + return success(); +} + +ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser, + OperationState &state) { + SmallVector operands; + if (parser.parseOptionalAttrDict(state.attributes) || + parser.parseOperandList(operands) || parser.parseColon()) + return failure(); + + Type resultType; + auto loc = parser.getCurrentLocation(); + if (parser.parseType(resultType)) + return failure(); + + auto structType = resultType.dyn_cast(); + if (!structType || structType.getNumElements() != 2) + return parser.emitError(loc, "expected spv.struct type with two members"); + + SmallVector operandTypes(2, structType.getElementType(0)); + if (parser.resolveOperands(operands, operandTypes, loc, state.operands)) + return failure(); + + state.addTypes(resultType); + return success(); +} + +void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) { + printer << ' '; + printer.printOptionalAttrDict((*this)->getAttrs()); + printer.printOperands((*this)->getOperands()); + printer << " : " << getType(); +} + +//===----------------------------------------------------------------------===// // spv.LoadOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir index 214f755..22159b0 100644 --- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir @@ -151,6 +151,58 @@ func.func @isub_scalar(%arg: i32) -> i32 { // ----- //===----------------------------------------------------------------------===// +// spv.ISubBorrow +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @isub_borrow_scalar +func.func @isub_borrow_scalar(%arg: i32) -> !spv.struct<(i32, i32)> { + // CHECK: spv.ISubBorrow %{{.+}}, %{{.+}} : !spv.struct<(i32, i32)> + %0 = spv.ISubBorrow %arg, %arg : !spv.struct<(i32, i32)> + return %0 : !spv.struct<(i32, i32)> +} + +// CHECK-LABEL: @isub_borrow_vector +func.func @isub_borrow_vector(%arg: vector<3xi32>) -> !spv.struct<(vector<3xi32>, vector<3xi32>)> { + // CHECK: spv.ISubBorrow %{{.+}}, %{{.+}} : !spv.struct<(vector<3xi32>, vector<3xi32>)> + %0 = spv.ISubBorrow %arg, %arg : !spv.struct<(vector<3xi32>, vector<3xi32>)> + return %0 : !spv.struct<(vector<3xi32>, vector<3xi32>)> +} + +// ----- + +func.func @isub_borrow(%arg: i32) -> !spv.struct<(i32, i32, i32)> { + // expected-error @+1 {{expected spv.struct type with two members}} + %0 = spv.ISubBorrow %arg, %arg : !spv.struct<(i32, i32, i32)> + return %0 : !spv.struct<(i32, i32, i32)> +} + +// ----- + +func.func @isub_borrow(%arg: i32) -> !spv.struct<(i32)> { + // expected-error @+1 {{expected result struct type containing two members}} + %0 = "spv.ISubBorrow"(%arg, %arg): (i32, i32) -> !spv.struct<(i32)> + return %0 : !spv.struct<(i32)> +} + +// ----- + +func.func @isub_borrow(%arg: i32) -> !spv.struct<(i32, i64)> { + // expected-error @+1 {{expected all operand types and struct member types are the same}} + %0 = "spv.ISubBorrow"(%arg, %arg): (i32, i32) -> !spv.struct<(i32, i64)> + return %0 : !spv.struct<(i32, i64)> +} + +// ----- + +func.func @isub_borrow(%arg: i64) -> !spv.struct<(i32, i32)> { + // expected-error @+1 {{expected all operand types and struct member types are the same}} + %0 = "spv.ISubBorrow"(%arg, %arg): (i64, i64) -> !spv.struct<(i32, i32)> + return %0 : !spv.struct<(i32, i32)> +} + +// ----- + +//===----------------------------------------------------------------------===// // spv.SDiv //===----------------------------------------------------------------------===// -- 2.7.4