From 7f85adb54d1956183630eb43c2f3e578f7366276 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 23 Apr 2020 04:40:42 -0700 Subject: [PATCH] [mlir][Standard] Allow select to use an i1 for vector and tensor values It currently requires that the condition match the shape of the selected value, but this is only really useful for things like masks. This revision allows for the use of i1 to mean that all of the vector/tensor is selected. This also matches the behavior of LLVM select. A benefit of this change is that transformations that want to generate selects, like those on the CFG, don't have to special case vector/tensor. Previously the only way to generate a select from an i1 was to use a splat, but that doesn't support dynamically shaped/unranked tensors. Differential Revision: https://reviews.llvm.org/D78690 --- llvm/include/llvm/ADT/STLExtras.h | 3 +- mlir/include/mlir/Dialect/StandardOps/IR/Ops.td | 23 ++++----- mlir/lib/Dialect/StandardOps/IR/Ops.cpp | 62 +++++++++++++++++++++---- mlir/test/Dialect/Standard/canonicalize-cf.mlir | 39 ++++------------ mlir/test/IR/core-ops.mlir | 10 ++-- mlir/test/IR/invalid-ops.mlir | 16 +++---- 6 files changed, 86 insertions(+), 67 deletions(-) diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h index c10f4d2..d9a1d0c 100644 --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -1124,7 +1124,8 @@ public: /// Compare this range with another. template bool operator==(const OtherT &other) const { - return size() == std::distance(other.begin(), other.end()) && + return size() == + static_cast(std::distance(other.begin(), other.end())) && std::equal(begin(), end(), other.begin()); } template bool operator!=(const OtherT &other) const { diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td index 39c9597..54800a5 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1915,11 +1915,8 @@ def RsqrtOp : FloatUnaryOp<"rsqrt"> { // SelectOp //===----------------------------------------------------------------------===// -def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape, - AllTypesMatch<["true_value", "false_value", "result"]>, - TypesMatchWith<"condition type matches i1 equivalent of result type", - "result", "condition", - "getI1SameShape($_self)">]> { +def SelectOp : Std_Op<"select", [NoSideEffect, + AllTypesMatch<["true_value", "false_value", "result"]>]> { let summary = "select operation"; let description = [{ The `select` operation chooses one value based on a binary condition @@ -1930,7 +1927,8 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape, The operation applies to vectors and tensors elementwise given the _shape_ of all operands is identical. The choice is made for each element individually based on the value at the same position as the element in the - condition operand. + condition operand. If an i1 is provided as the condition, the entire vector + or tensor is chosen. The `select` operation combined with [`cmpi`](#stdcmpi-cmpiop) can be used to implement `min` and `max` with signed or unsigned comparison semantics. @@ -1944,9 +1942,11 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape, // Generic form of the same operation. %x = "std.select"(%cond, %true, %false) : (i1, i32, i32) -> i32 - // Vector selection is element-wise - %vx = "std.select"(%vcond, %vtrue, %vfalse) - : (vector<42xi1>, vector<42xf32>, vector<42xf32>) -> vector<42xf32> + // Element-wise vector selection. + %vx = std.select %vcond, %vtrue, %vfalse : vector<42xi1>, vector<42xf32> + + // Full vector selection. + %vx = std.select %cond, %vtrue, %vfalse : vector<42xf32> ``` }]; @@ -1954,7 +1954,6 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape, AnyType:$true_value, AnyType:$false_value); let results = (outs AnyType:$result); - let verifier = ?; let builders = [OpBuilder< "Builder *builder, OperationState &result, Value condition," @@ -1970,10 +1969,6 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape, }]; let hasFolder = 1; - - let assemblyFormat = [{ - $condition `,` $true_value `,` $false_value attr-dict `:` type($result) - }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 3294210..bf4bfc8 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -999,15 +999,6 @@ struct SimplifyCondBranchIdenticalSuccessors if (trueDest->getUniquePredecessor() != condbr.getOperation()->getBlock()) return failure(); - // TODO: ATM Tensor/Vector SelectOp requires that the condition has the same - // shape as the operands. We should relax that to allow an i1 to signify - // that everything is selected. - auto doesntSupportsScalarI1 = [](Type type) { - return type.isa() || type.isa(); - }; - if (llvm::any_of(trueOperands.getTypes(), doesntSupportsScalarI1)) - return failure(); - // Generate a select for any operands that differ between the two. SmallVector mergedOperands; mergedOperands.reserve(trueOperands.size()); @@ -1925,6 +1916,59 @@ OpFoldResult SelectOp::fold(ArrayRef operands) { return nullptr; } +static void print(OpAsmPrinter &p, SelectOp op) { + p << "select " << op.getOperands(); + p.printOptionalAttrDict(op.getAttrs()); + p << " : "; + if (ShapedType condType = op.getCondition().getType().dyn_cast()) + p << condType << ", "; + p << op.getType(); +} + +static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) { + Type conditionType, resultType; + SmallVector operands; + if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(resultType)) + return failure(); + + // Check for the explicit condition type if this is a masked tensor or vector. + if (succeeded(parser.parseOptionalComma())) { + conditionType = resultType; + if (parser.parseType(resultType)) + return failure(); + } else { + conditionType = parser.getBuilder().getI1Type(); + } + + result.addTypes(resultType); + return parser.resolveOperands(operands, + {conditionType, resultType, resultType}, + parser.getNameLoc(), result.operands); +} + +static LogicalResult verify(SelectOp op) { + Type conditionType = op.getCondition().getType(); + if (conditionType.isSignlessInteger(1)) + return success(); + + // If the result type is a vector or tensor, the type can be a mask with the + // same elements. + Type resultType = op.getType(); + if (!resultType.isa() && !resultType.isa()) + return op.emitOpError() + << "expected condition to be a signless i1, but got " + << conditionType; + Type shapedConditionType = getI1SameShape(resultType); + if (conditionType != shapedConditionType) + return op.emitOpError() + << "expected condition type to have the same shape " + "as the result type, expected " + << shapedConditionType << ", but got " << conditionType; + return success(); +} + //===----------------------------------------------------------------------===// // SignExtendIOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir index 71ee7f1..b0fd844 100644 --- a/mlir/test/Dialect/Standard/canonicalize-cf.mlir +++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir @@ -69,39 +69,18 @@ func @cond_br_same_successor(%cond : i1, %a : i32) { // CHECK-LABEL: func @cond_br_same_successor_insert_select( // CHECK-SAME: %[[COND:.*]]: i1, %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32 -func @cond_br_same_successor_insert_select(%cond : i1, %a : i32, %b : i32) -> i32 { +// CHECK-SAME: %[[ARG2:.*]]: tensor<2xi32>, %[[ARG3:.*]]: tensor<2xi32> +func @cond_br_same_successor_insert_select( + %cond : i1, %a : i32, %b : i32, %c : tensor<2xi32>, %d : tensor<2xi32> + ) -> (i32, tensor<2xi32>) { // CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG1]] - // CHECK: return %[[RES]] - - cond_br %cond, ^bb1(%a : i32), ^bb1(%b : i32) - -^bb1(%result : i32): - return %result : i32 -} - -/// Check that we don't generate a select if the type requires a splat. -/// TODO: SelectOp should allow for matching a vector/tensor with i1. - -// CHECK-LABEL: func @cond_br_same_successor_no_select_tensor( -func @cond_br_same_successor_no_select_tensor(%cond : i1, %a : tensor<2xi32>, - %b : tensor<2xi32>) -> tensor<2xi32>{ - // CHECK: cond_br - - cond_br %cond, ^bb1(%a : tensor<2xi32>), ^bb1(%b : tensor<2xi32>) - -^bb1(%result : tensor<2xi32>): - return %result : tensor<2xi32> -} - -// CHECK-LABEL: func @cond_br_same_successor_no_select_vector( -func @cond_br_same_successor_no_select_vector(%cond : i1, %a : vector<2xi32>, - %b : vector<2xi32>) -> vector<2xi32> { - // CHECK: cond_br + // CHECK: %[[RES2:.*]] = select %[[COND]], %[[ARG2]], %[[ARG3]] + // CHECK: return %[[RES]], %[[RES2]] - cond_br %cond, ^bb1(%a : vector<2xi32>), ^bb1(%b : vector<2xi32>) + cond_br %cond, ^bb1(%a, %c : i32, tensor<2xi32>), ^bb1(%b, %d : i32, tensor<2xi32>) -^bb1(%result : vector<2xi32>): - return %result : vector<2xi32> +^bb1(%result : i32, %result2 : tensor<2xi32>): + return %result, %result2 : i32, tensor<2xi32> } /// Test the compound folding of BranchOp and CondBranchOp. diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir index d19f344..d0a27ec 100644 --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -141,17 +141,17 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) { // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index %21 = select %18, %idx, %idx : index - // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xi32> - %22 = select %19, %tci32, %tci32 : tensor<42 x i32> + // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xi1>, tensor<42xi32> + %22 = select %19, %tci32, %tci32 : tensor<42 x i1>, tensor<42 x i32> - // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_5, %cst_5 : vector<42xi32> - %23 = select %20, %vci32, %vci32 : vector<42 x i32> + // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_5, %cst_5 : vector<42xi1>, vector<42xi32> + %23 = select %20, %vci32, %vci32 : vector<42 x i1>, vector<42 x i32> // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index %24 = "std.select"(%18, %idx, %idx) : (i1, index, index) -> index // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xi32> - %25 = "std.select"(%19, %tci32, %tci32) : (tensor<42 x i1>, tensor<42 x i32>, tensor<42 x i32>) -> tensor<42 x i32> + %25 = std.select %18, %tci32, %tci32 : tensor<42 x i32> // CHECK: %{{[0-9]+}} = divi_signed %arg2, %arg2 : i32 %26 = divi_signed %i, %i : i32 diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index c7b2905..80fdf33 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -281,18 +281,18 @@ func @func_with_ops(i1, i32, i64) { // ----- -func @func_with_ops(i1, vector<42xi32>, vector<42xi32>) { -^bb0(%cond : i1, %t : vector<42xi32>, %f : vector<42xi32>): - // expected-error@+1 {{requires the same shape for all operands and results}} - %r = "std.select"(%cond, %t, %f) : (i1, vector<42xi32>, vector<42xi32>) -> vector<42xi32> +func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) { +^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>): + // expected-error@+1 {{expected condition type to have the same shape as the result type, expected 'vector<42xi1>', but got 'vector<12xi1>'}} + %r = "std.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32> } // ----- -func @func_with_ops(i1, tensor<42xi32>, tensor) { -^bb0(%cond : i1, %t : tensor<42xi32>, %f : tensor): - // expected-error@+1 {{ op requires the same shape for all operands and results}} - %r = "std.select"(%cond, %t, %f) : (i1, tensor<42xi32>, tensor) -> tensor<42xi32> +func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) { +^bb0(%cond : tensor<12xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>): + // expected-error@+1 {{expected condition type to have the same shape as the result type, expected 'tensor<42xi1>', but got 'tensor<12xi1>'}} + %r = "std.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32> } // ----- -- 2.7.4