From a3fb6d0da307ebee9c4459d4296669362f0c7579 Mon Sep 17 00:00:00 2001
From: Alex Zinenko
Date: Wed, 28 Nov 2018 07:08:55 -0800
Subject: [PATCH] StandardOps: introduce 'select'.

The semantics of 'select' is conventional: return the second operand if the
first operand is true (1 : i1) and the third operand otherwise. It is
applicable to vectors and tensors element-wise, similarly to the LLVM 'select'
instruction. This operation is necessary to implement min/max in order to
lower 'for' loops with complex bounds to CFG functions and to support ternary
operations in ML functions. It is preferred to first-class min/max because of
its simplicity, e.g. it is not concerned with signedness.

PiperOrigin-RevId: 223160860
---
 mlir/g3doc/LangRef.md                       | 36 +++++++++-
 mlir/g3doc/Rationale.md                     | 31 ++++++++-
 mlir/include/mlir/StandardOps/StandardOps.h | 27 +++++++-
 mlir/lib/StandardOps/StandardOps.cpp        | 73 ++++++++++++++++++++-
 mlir/test/IR/core-ops.mlir                  | 15 +++++
 mlir/test/IR/invalid-ops.mlir               | 48 ++++++++++++++
 6 files changed, 224 insertions(+), 6 deletions(-)

diff --git a/mlir/g3doc/LangRef.md b/mlir/g3doc/LangRef.md
index dc41e4bb8078..cc2d14b9230e 100644
--- a/mlir/g3doc/LangRef.md
+++ b/mlir/g3doc/LangRef.md
@@ -1900,11 +1900,45 @@
 TODO: In the distant future, this will accept optional attributes for fast
 math, contraction, rounding mode, and other controls.
 
-#### 'tensor_cast' operation {#'tensor_cast'-operation}
+#### 'select' operation {#'select'-operation}
 
 Syntax:
 
+``` {.ebnf}
+operation ::= ssa-id `=` `select` ssa-use, ssa-use, ssa-use `:` type
+```
+
+Examples:
+
+```mlir {.mlir}
+// Short-hand notation of scalar selection.
+%x = select %cond, %true, %false : i32
+
+// Long-hand notation of the same operation.
+%x = "select"(%cond, %true, %false) : (i1, i32, i32) -> i32
+
+// Vector selection is element-wise.
+%vx = "select"(%vcond, %vtrue, %vfalse)
+    : (vector<42xi1>, vector<42xf32>, vector<42xf32>) -> vector<42xf32>
+```
+
+The `select` operation chooses one value based on a binary condition supplied
+as its first operand. If the value of the first operand is `1`, the second
+operand is chosen; otherwise the third operand is chosen. The second and the
+third operands must have the same type.
+
+The operation applies to vectors and tensors elementwise, provided the _shape_
+of all operands is identical. The choice is made for each element individually
+based on the value at the same position as the element in the condition
+operand.
+
+The `select` operation combined with [`cmpi`](#'cmpi'-operation) can be used
+to implement `min` and `max` with signed or unsigned comparison semantics.
+
+#### 'tensor_cast' operation {#'tensor_cast'-operation}
+
+Syntax:
+
 ``` {.ebnf}
 operation ::= ssa-id `=` `tensor_cast` ssa-use `:` type `to` type
 ```
diff --git a/mlir/g3doc/Rationale.md b/mlir/g3doc/Rationale.md
index d84d76f8dd31..d428d940c3c7 100644
--- a/mlir/g3doc/Rationale.md
+++ b/mlir/g3doc/Rationale.md
@@ -294,12 +294,39 @@
 readability by humans, short-hand notation accepts string literals that are
 mapped to the underlying integer values: `cmpi "eq", %lhs, %rhs` better
 implies integer equality comparison than `cmpi 0, %lhs, %rhs` where it is
 unclear what gets compared to what else. This syntactic sugar is possible
 thanks to parser
-logic redifinitions for short-hand notation of non-builtin operations.
+logic redefinitions for short-hand notation of non-builtin operations.
 Supporting it in the full notation would have required changing how the main
 parsing algorithm works and may have unexpected repercussions.
 While it had been possible to store the predicate as string attribute, it
 would have rendered impossible to implement switching logic based on the
 comparison kind and made
-attribute validity checks (one out of ten possibile kinds) more complex.
+attribute validity checks (one out of ten possible kinds) more complex.
+
+### 'select' operation to implement min/max {#select-operation}
+
+Although `min` and `max` operations are likely to occur as a result of
+transforming affine loops in ML functions, we did not make them first-class
+operations. Instead, we provide the `select` operation that can be combined
+with `cmpi` to implement the minimum and maximum computation. Although these
+computations now require two operations each, they are likely to be emitted
+automatically during the transformation inside MLIR. On the other hand, there
+are multiple benefits of introducing `select`: standalone min/max would
+concern themselves with the signedness of the comparison, which is already
+taken into account by `cmpi`; `select` can support floats transparently if
+used after a float-comparison operation; and the lower-level targets provide
+`select`-like instructions, making the translation trivial.
+
+This operation could have been implemented with additional control flow: `%r =
+select %cond, %t, %f` is equivalent to
+
+```mlir
+bb0:
+  br_cond %cond, bb1(%t), bb1(%f)
+bb1(%r):
+```
+
+However, this control flow granularity is not available in the ML functions
+where min/max, and thus `select`, are likely to appear. In addition, simpler
+control flow may be beneficial for optimization in general.
 
 ### Quantized integer operations {#quantized-integer-operations}
diff --git a/mlir/include/mlir/StandardOps/StandardOps.h b/mlir/include/mlir/StandardOps/StandardOps.h
index a96b1a4ee2ea..ac1a8c7cfb4c 100644
--- a/mlir/include/mlir/StandardOps/StandardOps.h
+++ b/mlir/include/mlir/StandardOps/StandardOps.h
@@ -26,7 +26,6 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
-#include "third_party/llvm/llvm/projects/google-mlir/include/mlir/IR/OpDefinition.h"
 
 namespace mlir {
 class Builder;
@@ -638,6 +637,32 @@
 private:
   explicit MulIOp(const Operation *state) : BinaryOp(state) {}
 };
+class SelectOp : public Op<SelectOp, OpTrait::NOperands<3>::Impl,
+                           OpTrait::OneResult, OpTrait::HasNoSideEffect> {
+public:
+  static StringRef getOperationName() { return "select"; }
+  static void build(Builder *builder, OperationState *result,
+                    SSAValue *condition, SSAValue *trueValue,
+                    SSAValue *falseValue);
+  static bool parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p) const;
+  bool verify() const;
+
+  SSAValue *getCondition() { return getOperand(0); }
+  const SSAValue *getCondition() const { return getOperand(0); }
+  SSAValue *getTrueValue() { return getOperand(1); }
+  const SSAValue *getTrueValue() const { return getOperand(1); }
+  SSAValue *getFalseValue() { return getOperand(2); }
+  const SSAValue *getFalseValue() const { return getOperand(2); }
+
+  Attribute constantFold(ArrayRef<Attribute> operands,
+                         MLIRContext *context) const;
+
+private:
+  friend class Operation;
+  explicit SelectOp(const Operation *state) : Op(state) {}
+};
+
 /// The "store" op writes an element to a memref specified by an index list.
 /// The arity of indices is the rank of the memref (i.e.
 /// if the memref being stored to is of rank 3, then 3 indices are required
 /// for the store following
diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp
index 978835e88ab7..32a02e7b3835 100644
--- a/mlir/lib/StandardOps/StandardOps.cpp
+++ b/mlir/lib/StandardOps/StandardOps.cpp
@@ -39,8 +39,8 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
     : Dialect(/*opPrefix=*/"", context) {
   addOperations();
+                LoadOp, MemRefCastOp, MulFOp, MulIOp, SelectOp, StoreOp, SubFOp,
+                SubIOp, TensorCastOp>();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1085,6 +1085,75 @@ void MulIOp::getCanonicalizationPatterns(OwningPatternList &results,
   results.push_back(std::make_unique(context));
 }
 
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+void SelectOp::build(Builder *builder, OperationState *result,
+                     SSAValue *condition, SSAValue *trueValue,
+                     SSAValue *falseValue) {
+  result->addOperands({condition, trueValue, falseValue});
+  result->addTypes(trueValue->getType());
+}
+
+bool SelectOp::parse(OpAsmParser *parser, OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 3> ops;
+  SmallVector<NamedAttribute, 4> attrs;
+  Type type;
+
+  if (parser->parseOperandList(ops, 3) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(type))
+    return true;
+
+  auto i1Type = getI1SameShape(&parser->getBuilder(), type);
+  SmallVector<Type, 3> types = {i1Type, type, type};
+  return parser->resolveOperands(ops, types, parser->getNameLoc(),
+                                 result->operands) ||
+         parser->addTypeToList(type, result->types);
+}
+
+void SelectOp::print(OpAsmPrinter *p) const {
+  *p << getOperationName() << ' ';
+  p->printOperands(getOperation()->getOperands());
+  *p << " : " << getTrueValue()->getType();
+  p->printOptionalAttrDict(getAttrs());
+}
+
+bool SelectOp::verify() const {
+  auto conditionType = getCondition()->getType();
+  auto trueType = getTrueValue()->getType();
+  auto falseType = getFalseValue()->getType();
+
+  if (trueType != falseType)
+    return emitOpError(
+        "requires 'true' and 'false' arguments to be of the same type");
+
+  if (checkI1SameShape(trueType, conditionType))
+    return emitOpError("requires the condition to have the same shape as "
+                       "arguments with elemental type i1");
+
+  return false;
+}
+
+Attribute SelectOp::constantFold(ArrayRef<Attribute> operands,
+                                 MLIRContext *context) const {
+  assert(operands.size() == 3 && "select takes three operands");
+
+  // select true, %0, %1 => %0
+  // select false, %0, %1 => %1
+  auto cond = operands[0].dyn_cast_or_null<IntegerAttr>();
+  if (!cond)
+    return {};
+
+  if (cond.getValue().isNullValue()) {
+    return operands[2];
+  } else if (cond.getValue().isOneValue()) {
+    return operands[1];
+  }
+
+  llvm_unreachable("first argument of select must be i1");
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 3e876e6a57f0..f9e0e5f94041 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -124,6 +124,21 @@ bb42(%t: tensor<4x4x?xf32>, %f: f32, %i: i32, %idx : index):
   // CHECK: %{{[0-9]+}} = cmpi "eq", %cst_5, %cst_5 : vector<42xindex>
   %20 = cmpi "eq", %cidx, %cidx : vector<42 x index>
+  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index
+  %21 = select %18, %idx, %idx : index
+
+  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xindex>
+  %22 = select %19, %tidx, %tidx : tensor<42 x index>
+
+  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_5, %cst_5 : vector<42xindex>
+  %23 = select %20, %cidx, %cidx : vector<42 x index>
+
+  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index
+  %24 = "select"(%18, %idx, %idx) : (i1, index, index) -> index
+
+  // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xindex>
+  %25 = "select"(%19, %tidx, %tidx) : (tensor<42 x i1>, tensor<42 x index>, tensor<42 x index>) -> tensor<42 x index>
+
   return
 }
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 5983d5c25530..a6ce86511849 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -230,3 +230,51 @@ bb0:
   %r = "cmpi"(%c, %c) {predicate: 0} : (vector<42 x i32>, vector<42 x i32>) -> vector<42 x i32>
 }
+// -----
+
+cfgfunc @cfgfunc_with_ops(i32, i32, i32) {
+bb0(%cond : i32, %t : i32, %f : i32):
+  // expected-error@+2 {{different type than prior uses}}
+  // expected-error@-2 {{prior use here}}
+  %r = select %cond, %t, %f : i32
+}
+
+// -----
+
+cfgfunc @cfgfunc_with_ops(i32, i32, i32) {
+bb0(%cond : i32, %t : i32, %f : i32):
+  // expected-error@+1 {{elemental type i1}}
+  %r = "select"(%cond, %t, %f) : (i32, i32, i32) -> i32
+}
+
+// -----
+
+cfgfunc @cfgfunc_with_ops(i1, i32, i64) {
+bb0(%cond : i1, %t : i32, %f : i64):
+  // expected-error@+1 {{'true' and 'false' arguments to be of the same type}}
+  %r = "select"(%cond, %t, %f) : (i1, i32, i64) -> i32
+}
+
+// -----
+
+cfgfunc @cfgfunc_with_ops(i1, vector<42xi32>, vector<42xi32>) {
+bb0(%cond : i1, %t : vector<42xi32>, %f : vector<42xi32>):
+  // expected-error@+1 {{requires the condition to have the same shape as arguments}}
+  %r = "select"(%cond, %t, %f) : (i1, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
+}
+
+// -----
+
+cfgfunc @cfgfunc_with_ops(i1, tensor<42xi32>, tensor) {
+bb0(%cond : i1, %t : tensor<42xi32>, %f : tensor):
+  // expected-error@+1 {{'true' and 'false' arguments to be of the same type}}
+  %r = "select"(%cond, %t, %f) : (i1, tensor<42xi32>, tensor) -> tensor<42xi32>
+}
+
+// -----
+
+cfgfunc @cfgfunc_with_ops(tensor, tensor<42xi32>, tensor<42xi32>) {
+bb0(%cond : tensor, %t : tensor<42xi32>, %f : tensor<42xi32>):
+  // expected-error@+1 {{requires the condition to have the same shape as arguments}}
+  %r = "select"(%cond, %t, %f) : (tensor, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
+}
-- 
2.34.1
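
As a usage illustration of the pattern mentioned in LangRef.md and Rationale.md above (min/max built from `cmpi` plus `select`), here is a minimal sketch in the CFG-function syntax used by the tests in this patch. The function name `@min_max` and the value names are hypothetical; `"slt"` and `"ugt"` are two of the ten comparison kinds accepted by `cmpi`.

```mlir
cfgfunc @min_max(i32, i32, vector<4xi32>, vector<4xi32>) {
bb0(%a : i32, %b : i32, %va : vector<4xi32>, %vb : vector<4xi32>):
  // Signed minimum: keep %a when %a < %b, otherwise %b.
  %lt = cmpi "slt", %a, %b : i32
  %min = select %lt, %a, %b : i32

  // Unsigned maximum of the same two values.
  %gt = cmpi "ugt", %a, %b : i32
  %max = select %gt, %a, %b : i32

  // Element-wise variant: cmpi yields a vector of i1 that selects per element.
  %vlt = cmpi "slt", %va, %vb : vector<4xi32>
  %vmin = select %vlt, %va, %vb : vector<4xi32>

  return
}
```

Because the comparison kind lives entirely in the compare operation, the same `select` covers both signed and unsigned minima and maxima, which is the simplicity argument made in the rationale.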