StandardOps: introduce 'select'.
authorAlex Zinenko <zinenko@google.com>
Wed, 28 Nov 2018 15:08:55 +0000 (07:08 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 21:11:25 +0000 (14:11 -0700)
The semantics of 'select' is conventional: return the second operand if the
first operand is true (1 : i1) and the third operand otherwise.  It is
applicable to vectors and tensors element-wise, similarly to LLVM instruction.
This operation is necessary to implement min/max to lower 'for' loops with
complex bounds to CFG functions and to support ternary operations in ML
functions.  It is preferred to first-class min/max because of its simplicity,
e.g. it is not concered with signedness.

PiperOrigin-RevId: 223160860

mlir/g3doc/LangRef.md
mlir/g3doc/Rationale.md
mlir/include/mlir/StandardOps/StandardOps.h
mlir/lib/StandardOps/StandardOps.cpp
mlir/test/IR/core-ops.mlir
mlir/test/IR/invalid-ops.mlir

index dc41e4bb80788157d2f2ce4ce08a0d061b99083d..cc2d14b9230ecc533b3c91bf006f1ad5d2f7852f 100644 (file)
@@ -1900,11 +1900,45 @@ TODO: In the distant future, this will accept
 optional attributes for fast math, contraction, rounding mode, and other
 controls.
 
-#### 'tensor_cast' operation {#'tensor_cast'-operation}
+#### 'select' operation {#'select'-operation}
 
 Syntax:
 
+``` {.ebnf}
+operation ::= ssa-id `=` `select` ssa-use, ssa-use, ssa-use `:` type
+```
+
+Examples:
+
 ```mlir {.mlir}
+// Short-hand notation of scalar selection.
+%x = select %cond, %true, %false : i32
+
+// Long-hand notation of the same operation.
+%x = "select"(%cond, %true, %false) : (i1, i32, i32) -> i32
+
+// Vector selection is element-wise
+%vx = "select"(%vcond, %vtrue, %vfalse)
+    : (vector<42xi1>, vector<42xf32>, vector<42xf32>) -> vector<42xf32>
+```
+
+The `select` operation chooses one value based on a binary condition supplied as
+its first operand. If the value of the first operand is `1`, the second operand
+is chosen, otherwise the third operand is chosen. The second and the third
+operand must have the same type.
+
+The operation applies to vectors and tensors elementwise given the _shape_ of
+all operands is identical. The choice is made for each element individually
+based on the value at the same position as the element in the condition operand.
+
+The `select` operation combined with [`cmpi`](#'cmpi'-operation) can be used to
+implement `min` and `max` with signed or unsigned comparison semantics.
+
+#### 'tensor_cast' operation {#'tensor_cast'-operation}
+
+Syntax:
+
+``` {.ebnf}
 operation ::= ssa-id `=` `tensor_cast` ssa-use `:` type `to` type
 ```
 
index d84d76f8dd31d373171bd05d27f9c0779892bafc..d428d940c3c7f3e6dfe69d90e976d939877c073e 100644 (file)
@@ -294,12 +294,39 @@ readability by humans, short-hand notation accepts string literals that are
 mapped to the underlying integer values: `cmpi "eq", %lhs, %rhs` better implies
 integer equality comparison than `cmpi 0, %lhs, %rhs` where it is unclear what
 gets compared to what else. This syntactic sugar is possible thanks to parser
-logic redifinitions for short-hand notation of non-builtin operations.
+logic redefinitions for short-hand notation of non-builtin operations.
 Supporting it in the full notation would have required changing how the main
 parsing algorithm works and may have unexpected repercussions. While it had been
 possible to store the predicate as string attribute, it would have rendered
 impossible to implement switching logic based on the comparison kind and made
-attribute validity checks (one out of ten possibile kinds) more complex.
+attribute validity checks (one out of ten possible kinds) more complex.
+
+### 'select' operation to implement min/max {#select-operation}
+
+Although `min` and `max` operations are likely to occur as a result of
+transforming affine loops in ML functions, we did not make them first-class
+operations. Instead, we provide the `select` operation that can be combined with
+`cmpi` to implement the minimum and maximum computation. Although they now
+require two operations, they are likely to be emitted automatically during the
+transformation inside MLIR. On the other hand, there are multiple benefits of
+introducing `select`: standalone min/max would concern themselves with the
+signedness of the comparison, already taken into account by `cmpi`; `select` can
+support floats transparently if used after a float-comparison operation; the
+lower-level targets provide `select`-like instructions making the translation
+trivial.
+
+This operation could have been implemented with additional control flow: `%r =
+select %cond, %t, %f` is equivalent to
+
+```mlir
+bb0:
+  br_cond %cond, bb1(%t), bb1(%f)
+bb1(%r):
+```
+
+However, this control flow granularity is not available in the ML functions
+where min/max, and thus `select`, are likely to appear. In addition, simpler
+control flow may be beneficial for optimization in general.
 
 ### Quantized integer operations {#quantized-integer-operations}
 
index a96b1a4ee2ea3b6753e6ec0c14dee749742901fa..ac1a8c7cfb4cb4874abf523343c549ebbe446bbc 100644 (file)
@@ -26,7 +26,6 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
-#include "third_party/llvm/llvm/projects/google-mlir/include/mlir/IR/OpDefinition.h"
 
 namespace mlir {
 class Builder;
@@ -638,6 +637,32 @@ private:
   explicit MulIOp(const Operation *state) : BinaryOp(state) {}
 };
 
+class SelectOp : public Op<SelectOp, OpTrait::NOperands<3>::Impl,
+                           OpTrait::OneResult, OpTrait::HasNoSideEffect> {
+public:
+  static StringRef getOperationName() { return "select"; }
+  static void build(Builder *builder, OperationState *result,
+                    SSAValue *condition, SSAValue *trueValue,
+                    SSAValue *falseValue);
+  static bool parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p) const;
+  bool verify() const;
+
+  SSAValue *getCondition() { return getOperand(0); }
+  const SSAValue *getCondition() const { return getOperand(0); }
+  SSAValue *getTrueValue() { return getOperand(1); }
+  const SSAValue *getTrueValue() const { return getOperand(1); }
+  SSAValue *getFalseValue() { return getOperand(2); }
+  const SSAValue *getFalseValue() const { return getOperand(2); }
+
+  Attribute constantFold(ArrayRef<Attribute> operands,
+                         MLIRContext *context) const;
+
+private:
+  friend class Operation;
+  explicit SelectOp(const Operation *state) : Op(state) {}
+};
+
 /// The "store" op writes an element to a memref specified by an index list.
 /// The arity of indices is the rank of the memref (i.e. if the memref being
 /// stored to is of rank 3, then 3 indices are required for the store following
index 978835e88ab76c3bb8e4197df129884fda4f6573..32a02e7b38359d78ba1e30c36ae132a5b544414d 100644 (file)
@@ -39,8 +39,8 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
     : Dialect(/*opPrefix=*/"", context) {
   addOperations<AddFOp, AddIOp, AllocOp, CallOp, CallIndirectOp, CmpIOp,
                 DeallocOp, DimOp, DmaStartOp, DmaWaitOp, ExtractElementOp,
-                LoadOp, MemRefCastOp, MulFOp, MulIOp, StoreOp, SubFOp, SubIOp,
-                TensorCastOp>();
+                LoadOp, MemRefCastOp, MulFOp, MulIOp, SelectOp, StoreOp, SubFOp,
+                SubIOp, TensorCastOp>();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1085,6 +1085,75 @@ void MulIOp::getCanonicalizationPatterns(OwningPatternList &results,
   results.push_back(std::make_unique<SimplifyMulX1>(context));
 }
 
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+void SelectOp::build(Builder *builder, OperationState *result,
+                     SSAValue *condition, SSAValue *trueValue,
+                     SSAValue *falseValue) {
+  result->addOperands({condition, trueValue, falseValue});
+  result->addTypes(trueValue->getType());
+}
+
+bool SelectOp::parse(OpAsmParser *parser, OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 3> ops;
+  SmallVector<NamedAttribute, 4> attrs;
+  Type type;
+
+  if (parser->parseOperandList(ops, 3) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(type))
+    return true;
+
+  auto i1Type = getI1SameShape(&parser->getBuilder(), type);
+  SmallVector<Type, 3> types = {i1Type, type, type};
+  return parser->resolveOperands(ops, types, parser->getNameLoc(),
+                                 result->operands) ||
+         parser->addTypeToList(type, result->types);
+}
+
+void SelectOp::print(OpAsmPrinter *p) const {
+  *p << getOperationName() << ' ';
+  p->printOperands(getOperation()->getOperands());
+  *p << " : " << getTrueValue()->getType();
+  p->printOptionalAttrDict(getAttrs());
+}
+
+bool SelectOp::verify() const {
+  auto conditionType = getCondition()->getType();
+  auto trueType = getTrueValue()->getType();
+  auto falseType = getFalseValue()->getType();
+
+  if (trueType != falseType)
+    return emitOpError(
+        "requires 'true' and 'false' arguments to be of the same type");
+
+  if (checkI1SameShape(trueType, conditionType))
+    return emitOpError("requires the condition to have the same shape as "
+                       "arguments with elemental type i1");
+
+  return false;
+}
+
+Attribute SelectOp::constantFold(ArrayRef<Attribute> operands,
+                                 MLIRContext *context) const {
+  assert(operands.size() == 3 && "select takes three operands");
+
+  // select true, %0, %1 => %0
+  // select false, %0, %1 => %1
+  auto cond = operands[0].dyn_cast_or_null<IntegerAttr>();
+  if (!cond)
+    return {};
+
+  if (cond.getValue().isNullValue()) {
+    return operands[2];
+  } else if (cond.getValue().isOneValue()) {
+    return operands[1];
+  }
+
+  llvm_unreachable("first argument of select must be i1");
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
index 3e876e6a57f0f485933ef17def001dd76b64bf82..f9e0e5f9404165cc6a2f64f5940dd6086d281429 100644 (file)
@@ -124,6 +124,21 @@ bb42(%t: tensor<4x4x?xf32>, %f: f32, %i: i32, %idx : index):
   // CHECK: %{{[0-9]+}} = cmpi "eq", %cst_5, %cst_5 : vector<42xindex>
   %20 = cmpi "eq", %cidx, %cidx : vector<42 x index>
 
+  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index
+  %21 = select %18, %idx, %idx : index
+
+  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xindex>
+  %22 = select %19, %tidx, %tidx : tensor<42 x index>
+
+  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_5, %cst_5 : vector<42xindex>
+  %23 = select %20, %cidx, %cidx : vector<42 x index>
+
+  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index
+  %24 = "select"(%18, %idx, %idx) : (i1, index, index) -> index
+
+  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xindex>
+  %25 = "select"(%19, %tidx, %tidx) : (tensor<42 x i1>, tensor<42 x index>, tensor<42 x index>) -> tensor<42 x index>
+
   return
 }
 
index 5983d5c255304da1d65eef5b8188ec66187a1020..a6ce865118491217e9e44ac322aaad17309b55ae 100644 (file)
@@ -230,3 +230,51 @@ bb0:
   %r = "cmpi"(%c, %c) {predicate: 0} : (vector<42 x i32>, vector<42 x i32>) -> vector<42 x i32>
 }
 
+// -----
+
+cfgfunc @cfgfunc_with_ops(i32, i32, i32) {
+bb0(%cond : i32, %t : i32, %f : i32):
+  // expected-error@+2 {{different type than prior uses}}
+  // expected-error@-2 {{prior use here}}
+  %r = select %cond, %t, %f : i32
+}
+
+// -----
+
+cfgfunc @cfgfunc_with_ops(i32, i32, i32) {
+bb0(%cond : i32, %t : i32, %f : i32):
+  // expected-error@+1 {{elemental type i1}}
+  %r = "select"(%cond, %t, %f) : (i32, i32, i32) -> i32
+}
+
+// -----
+
+cfgfunc @cfgfunc_with_ops(i1, i32, i64) {
+bb0(%cond : i1, %t : i32, %f : i64):
+  // expected-error@+1 {{'true' and 'false' arguments to be of the same type}}
+  %r = "select"(%cond, %t, %f) : (i1, i32, i64) -> i32
+}
+
+// -----
+
+cfgfunc @cfgfunc_with_ops(i1, vector<42xi32>, vector<42xi32>) {
+bb0(%cond : i1, %t : vector<42xi32>, %f : vector<42xi32>):
+  // expected-error@+1 {{requires the condition to have the same shape as arguments}}
+  %r = "select"(%cond, %t, %f) : (i1, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
+}
+
+// -----
+
+cfgfunc @cfgfunc_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
+bb0(%cond : i1, %t : tensor<42xi32>, %f : tensor<?xi32>):
+  // expected-error@+1 {{'true' and 'false' arguments to be of the same type}}
+  %r = "select"(%cond, %t, %f) : (i1, tensor<42xi32>, tensor<?xi32>) -> tensor<42xi32>
+}
+
+// -----
+
+cfgfunc @cfgfunc_with_ops(tensor<?xi1>, tensor<42xi32>, tensor<42xi32>) {
+bb0(%cond : tensor<?xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
+  // expected-error@+1 {{requires the condition to have the same shape as arguments}}
+  %r = "select"(%cond, %t, %f) : (tensor<?xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
+}