From 2acd0dbf05aac71dca030ba6a4141e68ca509916 Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Mon, 2 Sep 2019 21:06:35 -0700 Subject: [PATCH] Add Select operation to SPIR-V dialect. The SelectOp models the semantics of OpSelect from SPIR-V spec. PiperOrigin-RevId: 266849559 --- mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | 12 +-- mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td | 57 +++++++++++ mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 58 ++++++++++++ mlir/test/Dialect/SPIRV/Serialization/select.mlir | 20 ++++ mlir/test/Dialect/SPIRV/arithmetic-ops.mlir | 6 +- mlir/test/Dialect/SPIRV/ops.mlir | 104 +++++++++++++++++++++ 6 files changed, 247 insertions(+), 10 deletions(-) create mode 100644 mlir/test/Dialect/SPIRV/Serialization/select.mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index e90e816..7dea586 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -121,6 +121,7 @@ def SPV_OC_OpSRem : I32EnumAttrCase<"OpSRem", 138>; def SPV_OC_OpSMod : I32EnumAttrCase<"OpSMod", 139>; def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>; def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>; +def SPV_OC_OpSelect : I32EnumAttrCase<"OpSelect", 169>; def SPV_OC_OpIEqual : I32EnumAttrCase<"OpIEqual", 170>; def SPV_OC_OpINotEqual : I32EnumAttrCase<"OpINotEqual", 171>; def SPV_OC_OpUGreaterThan : I32EnumAttrCase<"OpUGreaterThan", 172>; @@ -164,7 +165,7 @@ def SPV_OpcodeAttr : SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, - SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpIEqual, + SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, @@ -217,16 +218,13 @@ def SPV_Type : AnyTypeOf<[ SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct ]>; -class SPV_ScalarOrVectorOf : - Type.predicate]>, - "scalar/vector of " # type.description>; +class SPV_ScalarOrVectorOf : AnyTypeOf<[type, VectorOf<[type]>]>; // TODO(antiagainst): Use a more appropriate way to model optional operands class SPV_Optional : Variadic; -def SPV_IsEntryPointType : - CPred<"$_self.isa<::mlir::spirv::EntryPointType>()">; -def SPV_EntryPoint : Type; +// TODO(ravishankarm): From 1.4, this should also include Composite type. +def SPV_SelectType : AnyTypeOf<[SPV_Scalar, SPV_Vector, SPV_AnyPtr]>; //===----------------------------------------------------------------------===// // SPIR-V extension definitions diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td index 51781d8..1e9a547 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td @@ -634,6 +634,63 @@ def SPV_SLessThanEqualOp : SPV_LogicalOp<"SLessThanEqual", SPV_Integer, []> { // ----- +def SPV_SelectOp : SPV_Op<"Select", []> { + let summary = [{ + Select between two objects. Before version 1.4, results are only + computed per component. + }]; + + let description = [{ + Before version 1.4, Result Type must be a pointer, scalar, or vector. + + The types of Object 1 and Object 2 must be the same as Result Type. + + Condition must be a scalar or vector of Boolean type. + + If Condition is a scalar and true, the result is Object 1. If Condition + is a scalar and false, the result is Object 2. + + If Condition is a vector, Result Type must be a vector with the same + number of components as Condition and the result is a mix of Object 1 + and Object 2: When a component of Condition is true, the corresponding + component in the result is taken from Object 1, otherwise it is taken + from Object 2. + + ### Custom assembly form + + ``` {.ebnf} + scalar-type ::= integer-type | float-type | boolean-type + select-object-type ::= scalar-type + | `vector<` integer-literal `x` scalar-type `>` + | pointer-type + select-condition-type ::= boolean-type + | `vector<` integer-literal `x` boolean-type `>` + select-op ::= ssa-id `=` `spv.Select` ssa-use, ssa-use, ssa-use + `:` select-condition-type `,` select-object-type + ``` + + For example: + + ``` + %3 = spv.Select %0, %1, %2 : i1, f32 + %3 = spv.Select %0, %1, %2 : i1, vector<3xi32> + %3 = spv.Select %0, %1, %2 : vector<3xi1>, vector<3xf32> + ``` + }]; + + let arguments = (ins + SPV_ScalarOrVectorOf:$condition, + SPV_SelectType:$true_value, + SPV_SelectType:$false_value + ); + + let results = (outs + SPV_SelectType:$result + ); +} + +// ----- + def SPV_UGreaterThanOp : SPV_LogicalOp<"UGreaterThan", SPV_Integer, []> { let summary = [{ Unsigned-integer comparison if Operand 1 is greater than Operand 2. diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 2b1248b..66b2f5d 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1260,6 +1260,64 @@ static LogicalResult verify(spirv::ReturnValueOp retValOp) { } //===----------------------------------------------------------------------===// +// spv.Select +//===----------------------------------------------------------------------===// + +static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *state) { + OpAsmParser::OperandType condition; + SmallVector operands; + SmallVector types; + auto loc = parser->getCurrentLocation(); + if (parser->parseOperand(condition) || parser->parseComma() || + parser->parseOperandList(operands, 2) || + parser->parseColonTypeList(types)) { + return failure(); + } + if (types.size() != 2) { + return parser->emitError( + loc, "need exactly two trailing types for select condition and object"); + } + if (parser->resolveOperand(condition, types[0], state->operands) || + parser->resolveOperands(operands, types[1], state->operands)) { + return failure(); + } + return parser->addTypesToList(types[1], state->types); +} + +static void print(spirv::SelectOp op, OpAsmPrinter *printer) { + *printer << spirv::SelectOp::getOperationName() << " "; + + // Print the operands. + printer->printOperands(op.getOperands()); + + // Print colon and types. + *printer << " : " << op.condition()->getType() << ", " + << op.result()->getType(); +} + +static LogicalResult verify(spirv::SelectOp op) { + auto resultTy = op.result()->getType(); + if (op.true_value()->getType() != resultTy) { + return op.emitOpError("result type and true value type must be the same"); + } + if (op.false_value()->getType() != resultTy) { + return op.emitOpError("result type and false value type must be the same"); + } + if (auto conditionTy = op.condition()->getType().dyn_cast()) { + auto resultVectorTy = resultTy.dyn_cast(); + if (!resultVectorTy) { + return op.emitOpError("result expected to be of vector type when " + "condition is of vector type"); + } + if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) { + return op.emitOpError("result should have the same number of elements as " + "the condition when condition is of vector type"); + } + } + return success(); +} + +//===----------------------------------------------------------------------===// // spv.specConstant //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/Serialization/select.mlir b/mlir/test/Dialect/SPIRV/Serialization/select.mlir new file mode 100644 index 0000000..aec39e8 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Serialization/select.mlir @@ -0,0 +1,20 @@ +// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s + +spv.module "Logical" "VulkanKHR" { + spv.specConstant @condition_scalar = true + func @select() -> () { + %0 = spv.constant 4.0 : f32 + %1 = spv.constant 5.0 : f32 + %2 = spv._reference_of @condition_scalar : i1 + // CHECK: spv.Select {{.*}}, {{.*}}, {{.*}} : i1, f32 + %3 = spv.Select %2, %0, %1 : i1, f32 + %4 = spv.constant dense<[2.0, 3.0, 4.0, 5.0]> : vector<4xf32> + %5 = spv.constant dense<[6.0, 7.0, 8.0, 9.0]> : vector<4xf32> + // CHECK: spv.Select {{.*}}, {{.*}}, {{.*}} : i1, vector<4xf32> + %6 = spv.Select %2, %4, %5 : i1, vector<4xf32> + %7 = spv.constant dense<[true, true, true, true]> : vector<4xi1> + // CHECK: spv.Select {{.*}}, {{.*}}, {{.*}} : vector<4xi1>, vector<4xf32> + %8 = spv.Select %7, %4, %5 : vector<4xi1>, vector<4xf32> + spv.Return + } +} \ No newline at end of file diff --git a/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir index ea12268..9369962 100644 --- a/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir +++ b/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir @@ -55,7 +55,7 @@ func @fmul_vector(%arg: vector<4xf32>) -> vector<4xf32> { // ----- func @fmul_i32(%arg: i32) -> i32 { - // expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} %0 = spv.FMul %arg, %arg : i32 return %0 : i32 } @@ -63,7 +63,7 @@ func @fmul_i32(%arg: i32) -> i32 { // ----- func @fmul_bf16(%arg: bf16) -> bf16 { - // expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} %0 = spv.FMul %arg, %arg : bf16 return %0 : bf16 } @@ -71,7 +71,7 @@ func @fmul_bf16(%arg: bf16) -> bf16 { // ----- func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> { - // expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} %0 = spv.FMul %arg, %arg : tensor<4xf32> return %0 : tensor<4xf32> } diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir index 3e32b90..524f1d2 100644 --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -431,6 +431,110 @@ spv.module "Logical" "VulkanKHR" { // ----- //===----------------------------------------------------------------------===// +// spv.SelectOp +//===----------------------------------------------------------------------===// + +func @select_op_bool(%arg0: i1) -> () { + %0 = spv.constant true + %1 = spv.constant false + // CHECK : spv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, i1 + %2 = spv.Select %arg0, %0, %1 : i1, i1 + return +} + +func @select_op_int(%arg0: i1) -> () { + %0 = spv.constant 2 : i32 + %1 = spv.constant 3 : i32 + // CHECK : spv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, i32 + %2 = spv.Select %arg0, %0, %1 : i1, i32 + return +} + +func @select_op_float(%arg0: i1) -> () { + %0 = spv.constant 2.0 : f32 + %1 = spv.constant 3.0 : f32 + // CHECK : spv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, f32 + %2 = spv.Select %arg0, %0, %1 : i1, f32 + return +} + +func @select_op_ptr(%arg0: i1) -> () { + %0 = spv.Variable : !spv.ptr + %1 = spv.Variable : !spv.ptr + // CHECK : spv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, !spv.ptr + %2 = spv.Select %arg0, %0, %1 : i1, !spv.ptr + return +} + +func @select_op_vec(%arg0: i1) -> () { + %0 = spv.constant dense<[2.0, 3.0, 4.0]> : vector<3xf32> + %1 = spv.constant dense<[5.0, 6.0, 7.0]> : vector<3xf32> + // CHECK : spv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, vector<3xf32> + %2 = spv.Select %arg0, %0, %1 : i1, vector<3xf32> + return +} + +func @select_op_vec_condn_vec(%arg0: vector<3xi1>) -> () { + %0 = spv.constant dense<[2.0, 3.0, 4.0]> : vector<3xf32> + %1 = spv.constant dense<[5.0, 6.0, 7.0]> : vector<3xf32> + // CHECK : spv.Select {{%.*}}, {{%.*}}, {{%.*}} : vector<3xi1>, vector<3xf32> + %2 = spv.Select %arg0, %0, %1 : vector<3xi1>, vector<3xf32> + return +} + +// ----- + +func @select_op(%arg0: i1) -> () { + %0 = spv.constant 2 : i32 + %1 = spv.constant 3 : i32 + // expected-error @+1 {{need exactly two trailing types for select condition and object}} + %2 = spv.Select %arg0, %0, %1 : i1 + return +} + +// ----- + +func @select_op(%arg1: vector<3xi1>) -> () { + %0 = spv.constant 2 : i32 + %1 = spv.constant 3 : i32 + // expected-error @+1 {{result expected to be of vector type when condition is of vector type}} + %2 = spv.Select %arg1, %0, %1 : vector<3xi1>, i32 + return +} + +// ----- + +func @select_op(%arg1: vector<4xi1>) -> () { + %0 = spv.constant dense<[2, 3, 4]> : vector<3xi32> + %1 = spv.constant dense<[5, 6, 7]> : vector<3xi32> + // expected-error @+1 {{result should have the same number of elements as the condition when condition is of vector type}} + %2 = spv.Select %arg1, %0, %1 : vector<4xi1>, vector<3xi32> + return +} + +// ----- + +func @select_op(%arg1: vector<4xi1>) -> () { + %0 = spv.constant dense<[2.0, 3.0, 4.0]> : vector<3xf32> + %1 = spv.constant dense<[5, 6, 7]> : vector<3xi32> + // expected-error @+1 {{op result type and true value type must be the same}} + %2 = "spv.Select"(%arg1, %0, %1) : (vector<4xi1>, vector<3xf32>, vector<3xi32>) -> vector<3xi32> + return +} + +// ----- + +func @select_op(%arg1: vector<4xi1>) -> () { + %0 = spv.constant dense<[2.0, 3.0, 4.0]> : vector<3xf32> + %1 = spv.constant dense<[5, 6, 7]> : vector<3xi32> + // expected-error @+1 {{op result type and false value type must be the same}} + %2 = "spv.Select"(%arg1, %1, %0) : (vector<4xi1>, vector<3xi32>, vector<3xf32>) -> vector<3xi32> + return +} + +// ----- + +//===----------------------------------------------------------------------===// // spv.StoreOp //===----------------------------------------------------------------------===// -- 2.7.4