def SPV_OC_OpSMod : I32EnumAttrCase<"OpSMod", 139>;
def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
+def SPV_OC_OpSelect : I32EnumAttrCase<"OpSelect", 169>;
def SPV_OC_OpIEqual : I32EnumAttrCase<"OpIEqual", 170>;
def SPV_OC_OpINotEqual : I32EnumAttrCase<"OpINotEqual", 171>;
def SPV_OC_OpUGreaterThan : I32EnumAttrCase<"OpUGreaterThan", 172>;
SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd,
SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul,
SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem,
- SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpIEqual,
+ SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, SPV_OC_OpIEqual,
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct
]>;
-class SPV_ScalarOrVectorOf<Type type> :
- Type<Or<[type.predicate, VectorOf<[type]>.predicate]>,
- "scalar/vector of " # type.description>;
+class SPV_ScalarOrVectorOf<Type type> : AnyTypeOf<[type, VectorOf<[type]>]>;
// TODO(antiagainst): Use a more appropriate way to model optional operands
class SPV_Optional<Type type> : Variadic<type>;
-def SPV_IsEntryPointType :
- CPred<"$_self.isa<::mlir::spirv::EntryPointType>()">;
-def SPV_EntryPoint : Type<SPV_IsEntryPointType, "SPIR-V entry point type">;
+// TODO(ravishankarm): From 1.4, this should also include Composite type.
+def SPV_SelectType : AnyTypeOf<[SPV_Scalar, SPV_Vector, SPV_AnyPtr]>;
//===----------------------------------------------------------------------===//
// SPIR-V extension definitions
// -----
+def SPV_SelectOp : SPV_Op<"Select", []> {
+ let summary = [{
+ Select between two objects. Before version 1.4, results are only
+ computed per component.
+ }];
+
+ let description = [{
+ Before version 1.4, Result Type must be a pointer, scalar, or vector.
+
+ The types of Object 1 and Object 2 must be the same as Result Type.
+
+ Condition must be a scalar or vector of Boolean type.
+
+ If Condition is a scalar and true, the result is Object 1. If Condition
+ is a scalar and false, the result is Object 2.
+
+ If Condition is a vector, Result Type must be a vector with the same
+ number of components as Condition and the result is a mix of Object 1
+ and Object 2: When a component of Condition is true, the corresponding
+ component in the result is taken from Object 1, otherwise it is taken
+ from Object 2.
+
+ ### Custom assembly form
+
+ ``` {.ebnf}
+ scalar-type ::= integer-type | float-type | boolean-type
+ select-object-type ::= scalar-type
+ | `vector<` integer-literal `x` scalar-type `>`
+ | pointer-type
+ select-condition-type ::= boolean-type
+ | `vector<` integer-literal `x` boolean-type `>`
+ select-op ::= ssa-id `=` `spv.Select` ssa-use, ssa-use, ssa-use
+ `:` select-condition-type `,` select-object-type
+ ```
+
+ For example:
+
+ ```
+ %3 = spv.Select %0, %1, %2 : i1, f32
+ %3 = spv.Select %0, %1, %2 : i1, vector<3xi32>
+ %3 = spv.Select %0, %1, %2 : vector<3xi1>, vector<3xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ SPV_ScalarOrVectorOf<SPV_Bool>:$condition,
+ SPV_SelectType:$true_value,
+ SPV_SelectType:$false_value
+ );
+
+ let results = (outs
+ SPV_SelectType:$result
+ );
+}
+
+// -----
+
def SPV_UGreaterThanOp : SPV_LogicalOp<"UGreaterThan", SPV_Integer, []> {
let summary = [{
Unsigned-integer comparison if Operand 1 is greater than Operand 2.
}
//===----------------------------------------------------------------------===//
+// spv.Select
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *state) {
+ OpAsmParser::OperandType condition;
+ SmallVector<OpAsmParser::OperandType, 2> operands;
+ SmallVector<Type, 2> types;
+ auto loc = parser->getCurrentLocation();
+ if (parser->parseOperand(condition) || parser->parseComma() ||
+ parser->parseOperandList(operands, 2) ||
+ parser->parseColonTypeList(types)) {
+ return failure();
+ }
+ if (types.size() != 2) {
+ return parser->emitError(
+ loc, "need exactly two trailing types for select condition and object");
+ }
+ if (parser->resolveOperand(condition, types[0], state->operands) ||
+ parser->resolveOperands(operands, types[1], state->operands)) {
+ return failure();
+ }
+ return parser->addTypesToList(types[1], state->types);
+}
+
+static void print(spirv::SelectOp op, OpAsmPrinter *printer) {
+ *printer << spirv::SelectOp::getOperationName() << " ";
+
+ // Print the operands.
+ printer->printOperands(op.getOperands());
+
+ // Print colon and types.
+ *printer << " : " << op.condition()->getType() << ", "
+ << op.result()->getType();
+}
+
+static LogicalResult verify(spirv::SelectOp op) {
+ auto resultTy = op.result()->getType();
+ if (op.true_value()->getType() != resultTy) {
+ return op.emitOpError("result type and true value type must be the same");
+ }
+ if (op.false_value()->getType() != resultTy) {
+ return op.emitOpError("result type and false value type must be the same");
+ }
+ if (auto conditionTy = op.condition()->getType().dyn_cast<VectorType>()) {
+ auto resultVectorTy = resultTy.dyn_cast<VectorType>();
+ if (!resultVectorTy) {
+ return op.emitOpError("result expected to be of vector type when "
+ "condition is of vector type");
+ }
+ if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
+ return op.emitOpError("result should have the same number of elements as "
+ "the condition when condition is of vector type");
+ }
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// spv.specConstant
//===----------------------------------------------------------------------===//
--- /dev/null
+// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
+
+spv.module "Logical" "VulkanKHR" {
+ spv.specConstant @condition_scalar = true
+ func @select() -> () {
+ %0 = spv.constant 4.0 : f32
+ %1 = spv.constant 5.0 : f32
+ %2 = spv._reference_of @condition_scalar : i1
+ // CHECK: spv.Select {{.*}}, {{.*}}, {{.*}} : i1, f32
+ %3 = spv.Select %2, %0, %1 : i1, f32
+ %4 = spv.constant dense<[2.0, 3.0, 4.0, 5.0]> : vector<4xf32>
+ %5 = spv.constant dense<[6.0, 7.0, 8.0, 9.0]> : vector<4xf32>
+ // CHECK: spv.Select {{.*}}, {{.*}}, {{.*}} : i1, vector<4xf32>
+ %6 = spv.Select %2, %4, %5 : i1, vector<4xf32>
+ %7 = spv.constant dense<[true, true, true, true]> : vector<4xi1>
+ // CHECK: spv.Select {{.*}}, {{.*}}, {{.*}} : vector<4xi1>, vector<4xf32>
+ %8 = spv.Select %7, %4, %5 : vector<4xi1>, vector<4xf32>
+ spv.Return
+ }
+}
\ No newline at end of file
// -----
func @fmul_i32(%arg: i32) -> i32 {
- // expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}}
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
%0 = spv.FMul %arg, %arg : i32
return %0 : i32
}
// -----
func @fmul_bf16(%arg: bf16) -> bf16 {
- // expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}}
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
%0 = spv.FMul %arg, %arg : bf16
return %0 : bf16
}
// -----
func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
- // expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}}
+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
%0 = spv.FMul %arg, %arg : tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
//===----------------------------------------------------------------------===//
+// spv.SelectOp
+//===----------------------------------------------------------------------===//
+
+func @select_op_bool(%arg0: i1) -> () {
+ %0 = spv.constant true
+ %1 = spv.constant false
+ // CHECK : spv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, i1
+ %2 = spv.Select %arg0, %0, %1 : i1, i1
+ return
+}
+
+func @select_op_int(%arg0: i1) -> () {
+ %0 = spv.constant 2 : i32
+ %1 = spv.constant 3 : i32
+ // CHECK : spv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, i32
+ %2 = spv.Select %arg0, %0, %1 : i1, i32
+ return
+}
+
+func @select_op_float(%arg0: i1) -> () {
+ %0 = spv.constant 2.0 : f32
+ %1 = spv.constant 3.0 : f32
+ // CHECK : spv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, f32
+ %2 = spv.Select %arg0, %0, %1 : i1, f32
+ return
+}
+
+func @select_op_ptr(%arg0: i1) -> () {
+ %0 = spv.Variable : !spv.ptr<f32, Function>
+ %1 = spv.Variable : !spv.ptr<f32, Function>
+ // CHECK : spv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, !spv.ptr<f32, Function>
+ %2 = spv.Select %arg0, %0, %1 : i1, !spv.ptr<f32, Function>
+ return
+}
+
+func @select_op_vec(%arg0: i1) -> () {
+ %0 = spv.constant dense<[2.0, 3.0, 4.0]> : vector<3xf32>
+ %1 = spv.constant dense<[5.0, 6.0, 7.0]> : vector<3xf32>
+ // CHECK : spv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, vector<3xf32>
+ %2 = spv.Select %arg0, %0, %1 : i1, vector<3xf32>
+ return
+}
+
+func @select_op_vec_condn_vec(%arg0: vector<3xi1>) -> () {
+ %0 = spv.constant dense<[2.0, 3.0, 4.0]> : vector<3xf32>
+ %1 = spv.constant dense<[5.0, 6.0, 7.0]> : vector<3xf32>
+ // CHECK : spv.Select {{%.*}}, {{%.*}}, {{%.*}} : vector<3xi1>, vector<3xf32>
+ %2 = spv.Select %arg0, %0, %1 : vector<3xi1>, vector<3xf32>
+ return
+}
+
+// -----
+
+func @select_op(%arg0: i1) -> () {
+ %0 = spv.constant 2 : i32
+ %1 = spv.constant 3 : i32
+ // expected-error @+1 {{need exactly two trailing types for select condition and object}}
+ %2 = spv.Select %arg0, %0, %1 : i1
+ return
+}
+
+// -----
+
+func @select_op(%arg1: vector<3xi1>) -> () {
+ %0 = spv.constant 2 : i32
+ %1 = spv.constant 3 : i32
+ // expected-error @+1 {{result expected to be of vector type when condition is of vector type}}
+ %2 = spv.Select %arg1, %0, %1 : vector<3xi1>, i32
+ return
+}
+
+// -----
+
+func @select_op(%arg1: vector<4xi1>) -> () {
+ %0 = spv.constant dense<[2, 3, 4]> : vector<3xi32>
+ %1 = spv.constant dense<[5, 6, 7]> : vector<3xi32>
+ // expected-error @+1 {{result should have the same number of elements as the condition when condition is of vector type}}
+ %2 = spv.Select %arg1, %0, %1 : vector<4xi1>, vector<3xi32>
+ return
+}
+
+// -----
+
+func @select_op(%arg1: vector<4xi1>) -> () {
+ %0 = spv.constant dense<[2.0, 3.0, 4.0]> : vector<3xf32>
+ %1 = spv.constant dense<[5, 6, 7]> : vector<3xi32>
+ // expected-error @+1 {{op result type and true value type must be the same}}
+ %2 = "spv.Select"(%arg1, %0, %1) : (vector<4xi1>, vector<3xf32>, vector<3xi32>) -> vector<3xi32>
+ return
+}
+
+// -----
+
+func @select_op(%arg1: vector<4xi1>) -> () {
+ %0 = spv.constant dense<[2.0, 3.0, 4.0]> : vector<3xf32>
+ %1 = spv.constant dense<[5, 6, 7]> : vector<3xi32>
+ // expected-error @+1 {{op result type and false value type must be the same}}
+ %2 = "spv.Select"(%arg1, %1, %0) : (vector<4xi1>, vector<3xi32>, vector<3xf32>) -> vector<3xi32>
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
// spv.StoreOp
//===----------------------------------------------------------------------===//