[mlir][Standard] Allow select to use an i1 for vector and tensor values
authorRiver Riddle <riddleriver@gmail.com>
Thu, 23 Apr 2020 11:40:42 +0000 (04:40 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Thu, 23 Apr 2020 11:50:09 +0000 (04:50 -0700)
It currently requires that the condition match the shape of the selected value, but this is only really useful for things like masks. This revision allows for the use of i1 to mean that all of the vector/tensor is selected. This also matches the behavior of LLVM select. A benefit of this change is that transformations that want to generate selects, like those on the CFG, don't have to special case vector/tensor. Previously the only way to generate  a select from an i1 was to use a splat, but that doesn't support dynamically shaped/unranked tensors.

Differential Revision: https://reviews.llvm.org/D78690

llvm/include/llvm/ADT/STLExtras.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize-cf.mlir
mlir/test/IR/core-ops.mlir
mlir/test/IR/invalid-ops.mlir

index c10f4d2..d9a1d0c 100644 (file)
@@ -1124,7 +1124,8 @@ public:
 
   /// Compare this range with another.
   template <typename OtherT> bool operator==(const OtherT &other) const {
-    return size() == std::distance(other.begin(), other.end()) &&
+    return size() ==
+               static_cast<size_t>(std::distance(other.begin(), other.end())) &&
            std::equal(begin(), end(), other.begin());
   }
   template <typename OtherT> bool operator!=(const OtherT &other) const {
index 39c9597..54800a5 100644 (file)
@@ -1915,11 +1915,8 @@ def RsqrtOp : FloatUnaryOp<"rsqrt"> {
 // SelectOp
 //===----------------------------------------------------------------------===//
 
-def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
-     AllTypesMatch<["true_value", "false_value", "result"]>,
-     TypesMatchWith<"condition type matches i1 equivalent of result type",
-                     "result", "condition",
-                     "getI1SameShape($_self)">]> {
+def SelectOp : Std_Op<"select", [NoSideEffect,
+     AllTypesMatch<["true_value", "false_value", "result"]>]> {
   let summary = "select operation";
   let description = [{
     The `select` operation chooses one value based on a binary condition
@@ -1930,7 +1927,8 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
     The operation applies to vectors and tensors elementwise given the _shape_
     of all operands is identical. The choice is made for each element
     individually based on the value at the same position as the element in the
-    condition operand.
+    condition operand. If an i1 is provided as the condition, the entire vector
+    or tensor is chosen.
 
     The `select` operation combined with [`cmpi`](#stdcmpi-cmpiop) can be used
     to implement `min` and `max` with signed or unsigned comparison semantics.
@@ -1944,9 +1942,11 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
     // Generic form of the same operation.
     %x = "std.select"(%cond, %true, %false) : (i1, i32, i32) -> i32
 
-    // Vector selection is element-wise
-    %vx = "std.select"(%vcond, %vtrue, %vfalse)
-        : (vector<42xi1>, vector<42xf32>, vector<42xf32>) -> vector<42xf32>
+    // Element-wise vector selection.
+    %vx = std.select %vcond, %vtrue, %vfalse : vector<42xi1>, vector<42xf32>
+
+    // Full vector selection.
+    %vx = std.select %cond, %vtrue, %vfalse : vector<42xf32>
     ```
   }];
 
@@ -1954,7 +1954,6 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
                        AnyType:$true_value,
                        AnyType:$false_value);
   let results = (outs AnyType:$result);
-  let verifier = ?;
 
   let builders = [OpBuilder<
     "Builder *builder, OperationState &result, Value condition,"
@@ -1970,10 +1969,6 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,
   }];
 
   let hasFolder = 1;
-
-  let assemblyFormat = [{
-    $condition `,` $true_value `,` $false_value attr-dict `:` type($result)
-  }];
 }
 
 //===----------------------------------------------------------------------===//
index 3294210..bf4bfc8 100644 (file)
@@ -999,15 +999,6 @@ struct SimplifyCondBranchIdenticalSuccessors
     if (trueDest->getUniquePredecessor() != condbr.getOperation()->getBlock())
       return failure();
 
-    // TODO: ATM Tensor/Vector SelectOp requires that the condition has the same
-    // shape as the operands. We should relax that to allow an i1 to signify
-    // that everything is selected.
-    auto doesntSupportsScalarI1 = [](Type type) {
-      return type.isa<TensorType>() || type.isa<VectorType>();
-    };
-    if (llvm::any_of(trueOperands.getTypes(), doesntSupportsScalarI1))
-      return failure();
-
     // Generate a select for any operands that differ between the two.
     SmallVector<Value, 8> mergedOperands;
     mergedOperands.reserve(trueOperands.size());
@@ -1925,6 +1916,59 @@ OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+static void print(OpAsmPrinter &p, SelectOp op) {
+  p << "select " << op.getOperands();
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : ";
+  if (ShapedType condType = op.getCondition().getType().dyn_cast<ShapedType>())
+    p << condType << ", ";
+  p << op.getType();
+}
+
+static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
+  Type conditionType, resultType;
+  SmallVector<OpAsmParser::OperandType, 3> operands;
+  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(resultType))
+    return failure();
+
+  // Check for the explicit condition type if this is a masked tensor or vector.
+  if (succeeded(parser.parseOptionalComma())) {
+    conditionType = resultType;
+    if (parser.parseType(resultType))
+      return failure();
+  } else {
+    conditionType = parser.getBuilder().getI1Type();
+  }
+
+  result.addTypes(resultType);
+  return parser.resolveOperands(operands,
+                                {conditionType, resultType, resultType},
+                                parser.getNameLoc(), result.operands);
+}
+
+static LogicalResult verify(SelectOp op) {
+  Type conditionType = op.getCondition().getType();
+  if (conditionType.isSignlessInteger(1))
+    return success();
+
+  // If the result type is a vector or tensor, the type can be a mask with the
+  // same elements.
+  Type resultType = op.getType();
+  if (!resultType.isa<TensorType>() && !resultType.isa<VectorType>())
+    return op.emitOpError()
+           << "expected condition to be a signless i1, but got "
+           << conditionType;
+  Type shapedConditionType = getI1SameShape(resultType);
+  if (conditionType != shapedConditionType)
+    return op.emitOpError()
+           << "expected condition type to have the same shape "
+              "as the result type, expected "
+           << shapedConditionType << ", but got " << conditionType;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SignExtendIOp
 //===----------------------------------------------------------------------===//
index 71ee7f1..b0fd844 100644 (file)
@@ -69,39 +69,18 @@ func @cond_br_same_successor(%cond : i1, %a : i32) {
 
 // CHECK-LABEL: func @cond_br_same_successor_insert_select(
 // CHECK-SAME: %[[COND:.*]]: i1, %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
-func @cond_br_same_successor_insert_select(%cond : i1, %a : i32, %b : i32) -> i32 {
+// CHECK-SAME: %[[ARG2:.*]]: tensor<2xi32>, %[[ARG3:.*]]: tensor<2xi32>
+func @cond_br_same_successor_insert_select(
+      %cond : i1, %a : i32, %b : i32, %c : tensor<2xi32>, %d : tensor<2xi32>
+    ) -> (i32, tensor<2xi32>)  {
   // CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG1]]
-  // CHECK: return %[[RES]]
-
-  cond_br %cond, ^bb1(%a : i32), ^bb1(%b : i32)
-
-^bb1(%result : i32):
-  return %result : i32
-}
-
-/// Check that we don't generate a select if the type requires a splat.
-/// TODO: SelectOp should allow for matching a vector/tensor with i1.
-
-// CHECK-LABEL: func @cond_br_same_successor_no_select_tensor(
-func @cond_br_same_successor_no_select_tensor(%cond : i1, %a : tensor<2xi32>,
-                                              %b : tensor<2xi32>) -> tensor<2xi32>{
-  // CHECK: cond_br
-
-  cond_br %cond, ^bb1(%a : tensor<2xi32>), ^bb1(%b : tensor<2xi32>)
-
-^bb1(%result : tensor<2xi32>):
-  return %result : tensor<2xi32>
-}
-
-// CHECK-LABEL: func @cond_br_same_successor_no_select_vector(
-func @cond_br_same_successor_no_select_vector(%cond : i1, %a : vector<2xi32>,
-                                              %b : vector<2xi32>) -> vector<2xi32> {
-  // CHECK: cond_br
+  // CHECK: %[[RES2:.*]] = select %[[COND]], %[[ARG2]], %[[ARG3]]
+  // CHECK: return %[[RES]], %[[RES2]]
 
-  cond_br %cond, ^bb1(%a : vector<2xi32>), ^bb1(%b : vector<2xi32>)
+  cond_br %cond, ^bb1(%a, %c : i32, tensor<2xi32>), ^bb1(%b, %d : i32, tensor<2xi32>)
 
-^bb1(%result : vector<2xi32>):
-  return %result : vector<2xi32>
+^bb1(%result : i32, %result2 : tensor<2xi32>):
+  return %result, %result2 : i32, tensor<2xi32>
 }
 
 /// Test the compound folding of BranchOp and CondBranchOp.
index d19f344..d0a27ec 100644 (file)
@@ -141,17 +141,17 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
   // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index
   %21 = select %18, %idx, %idx : index
 
-  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xi32>
-  %22 = select %19, %tci32, %tci32 : tensor<42 x i32>
+  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xi1>, tensor<42xi32>
+  %22 = select %19, %tci32, %tci32 : tensor<42 x i1>, tensor<42 x i32>
 
-  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_5, %cst_5 : vector<42xi32>
-  %23 = select %20, %vci32, %vci32 : vector<42 x i32>
+  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_5, %cst_5 : vector<42xi1>, vector<42xi32>
+  %23 = select %20, %vci32, %vci32 : vector<42 x i1>, vector<42 x i32>
 
   // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index
   %24 = "std.select"(%18, %idx, %idx) : (i1, index, index) -> index
 
   // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xi32>
-  %25 = "std.select"(%19, %tci32, %tci32) : (tensor<42 x i1>, tensor<42 x i32>, tensor<42 x i32>) -> tensor<42 x i32>
+  %25 = std.select %18, %tci32, %tci32 : tensor<42 x i32>
 
   // CHECK: %{{[0-9]+}} = divi_signed %arg2, %arg2 : i32
   %26 = divi_signed %i, %i : i32
index c7b2905..80fdf33 100644 (file)
@@ -281,18 +281,18 @@ func @func_with_ops(i1, i32, i64) {
 
 // -----
 
-func @func_with_ops(i1, vector<42xi32>, vector<42xi32>) {
-^bb0(%cond : i1, %t : vector<42xi32>, %f : vector<42xi32>):
-  // expected-error@+1 {{requires the same shape for all operands and results}}
-  %r = "std.select"(%cond, %t, %f) : (i1, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
+func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
+^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>):
+  // expected-error@+1 {{expected condition type to have the same shape as the result type, expected 'vector<42xi1>', but got 'vector<12xi1>'}}
+  %r = "std.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
 }
 
 // -----
 
-func @func_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
-^bb0(%cond : i1, %t : tensor<42xi32>, %f : tensor<?xi32>):
-  // expected-error@+1 {{ op requires the same shape for all operands and results}}
-  %r = "std.select"(%cond, %t, %f) : (i1, tensor<42xi32>, tensor<?xi32>) -> tensor<42xi32>
+func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) {
+^bb0(%cond : tensor<12xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
+  // expected-error@+1 {{expected condition type to have the same shape as the result type, expected 'tensor<42xi1>', but got 'tensor<12xi1>'}}
+  %r = "std.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
 }
 
 // -----