optional attributes for fast math, contraction, rounding mode, and other
controls.
-#### 'tensor_cast' operation {#'tensor_cast'-operation}
+#### 'select' operation {#'select'-operation}
Syntax:
+``` {.ebnf}
+operation ::= ssa-id `=` `select` ssa-use, ssa-use, ssa-use `:` type
+```
+
+Examples:
+
```mlir {.mlir}
+// Short-hand notation of scalar selection.
+%x = select %cond, %true, %false : i32
+
+// Long-hand notation of the same operation.
+%x = "select"(%cond, %true, %false) : (i1, i32, i32) -> i32
+
+// Vector selection is element-wise
+%vx = "select"(%vcond, %vtrue, %vfalse)
+ : (vector<42xi1>, vector<42xf32>, vector<42xf32>) -> vector<42xf32>
+```
+
+The `select` operation chooses one value based on a binary condition supplied as
+its first operand. If the value of the first operand is `1`, the second operand
+is chosen, otherwise the third operand is chosen. The second and the third
+operand must have the same type.
+
+The operation applies to vectors and tensors elementwise given the _shape_ of
+all operands is identical. The choice is made for each element individually
+based on the value at the same position as the element in the condition operand.
+
+The `select` operation combined with [`cmpi`](#'cmpi'-operation) can be used to
+implement `min` and `max` with signed or unsigned comparison semantics.
+
+#### 'tensor_cast' operation {#'tensor_cast'-operation}
+
+Syntax:
+
+``` {.ebnf}
operation ::= ssa-id `=` `tensor_cast` ssa-use `:` type `to` type
```
mapped to the underlying integer values: `cmpi "eq", %lhs, %rhs` better implies
integer equality comparison than `cmpi 0, %lhs, %rhs` where it is unclear what
gets compared to what else. This syntactic sugar is possible thanks to parser
-logic redifinitions for short-hand notation of non-builtin operations.
+logic redefinitions for short-hand notation of non-builtin operations.
Supporting it in the full notation would have required changing how the main
parsing algorithm works and may have unexpected repercussions. While it had been
possible to store the predicate as string attribute, it would have rendered
impossible to implement switching logic based on the comparison kind and made
-attribute validity checks (one out of ten possibile kinds) more complex.
+attribute validity checks (one out of ten possible kinds) more complex.
+
+### 'select' operation to implement min/max {#select-operation}
+
+Although `min` and `max` operations are likely to occur as a result of
+transforming affine loops in ML functions, we did not make them first-class
+operations. Instead, we provide the `select` operation that can be combined with
+`cmpi` to implement the minimum and maximum computation. Although they now
+require two operations, they are likely to be emitted automatically during the
+transformation inside MLIR. On the other hand, there are multiple benefits of
+introducing `select`: standalone min/max would concern themselves with the
+signedness of the comparison, already taken into account by `cmpi`; `select` can
+support floats transparently if used after a float-comparison operation; the
+lower-level targets provide `select`-like instructions making the translation
+trivial.
+
+This operation could have been implemented with additional control flow: `%r =
+select %cond, %t, %f` is equivalent to
+
+```mlir
+bb0:
+ br_cond %cond, bb1(%t), bb1(%f)
+bb1(%r):
+```
+
+However, this control flow granularity is not available in the ML functions
+where min/max, and thus `select`, are likely to appear. In addition, simpler
+control flow may be beneficial for optimization in general.
### Quantized integer operations {#quantized-integer-operations}
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
-#include "third_party/llvm/llvm/projects/google-mlir/include/mlir/IR/OpDefinition.h"
namespace mlir {
class Builder;
explicit MulIOp(const Operation *state) : BinaryOp(state) {}
};
+class SelectOp : public Op<SelectOp, OpTrait::NOperands<3>::Impl,
+ OpTrait::OneResult, OpTrait::HasNoSideEffect> {
+public:
+ static StringRef getOperationName() { return "select"; }
+ static void build(Builder *builder, OperationState *result,
+ SSAValue *condition, SSAValue *trueValue,
+ SSAValue *falseValue);
+ static bool parse(OpAsmParser *parser, OperationState *result);
+ void print(OpAsmPrinter *p) const;
+ bool verify() const;
+
+ SSAValue *getCondition() { return getOperand(0); }
+ const SSAValue *getCondition() const { return getOperand(0); }
+ SSAValue *getTrueValue() { return getOperand(1); }
+ const SSAValue *getTrueValue() const { return getOperand(1); }
+ SSAValue *getFalseValue() { return getOperand(2); }
+ const SSAValue *getFalseValue() const { return getOperand(2); }
+
+ Attribute constantFold(ArrayRef<Attribute> operands,
+ MLIRContext *context) const;
+
+private:
+ friend class Operation;
+ explicit SelectOp(const Operation *state) : Op(state) {}
+};
+
/// The "store" op writes an element to a memref specified by an index list.
/// The arity of indices is the rank of the memref (i.e. if the memref being
/// stored to is of rank 3, then 3 indices are required for the store following
: Dialect(/*opPrefix=*/"", context) {
addOperations<AddFOp, AddIOp, AllocOp, CallOp, CallIndirectOp, CmpIOp,
DeallocOp, DimOp, DmaStartOp, DmaWaitOp, ExtractElementOp,
- LoadOp, MemRefCastOp, MulFOp, MulIOp, StoreOp, SubFOp, SubIOp,
- TensorCastOp>();
+ LoadOp, MemRefCastOp, MulFOp, MulIOp, SelectOp, StoreOp, SubFOp,
+ SubIOp, TensorCastOp>();
}
//===----------------------------------------------------------------------===//
results.push_back(std::make_unique<SimplifyMulX1>(context));
}
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+void SelectOp::build(Builder *builder, OperationState *result,
+ SSAValue *condition, SSAValue *trueValue,
+ SSAValue *falseValue) {
+ result->addOperands({condition, trueValue, falseValue});
+ result->addTypes(trueValue->getType());
+}
+
+bool SelectOp::parse(OpAsmParser *parser, OperationState *result) {
+ SmallVector<OpAsmParser::OperandType, 3> ops;
+ SmallVector<NamedAttribute, 4> attrs;
+ Type type;
+
+ if (parser->parseOperandList(ops, 3) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type))
+ return true;
+
+ auto i1Type = getI1SameShape(&parser->getBuilder(), type);
+ SmallVector<Type, 3> types = {i1Type, type, type};
+ return parser->resolveOperands(ops, types, parser->getNameLoc(),
+ result->operands) ||
+ parser->addTypeToList(type, result->types);
+}
+
+void SelectOp::print(OpAsmPrinter *p) const {
+ *p << getOperationName() << ' ';
+ p->printOperands(getOperation()->getOperands());
+ *p << " : " << getTrueValue()->getType();
+ p->printOptionalAttrDict(getAttrs());
+}
+
+bool SelectOp::verify() const {
+ auto conditionType = getCondition()->getType();
+ auto trueType = getTrueValue()->getType();
+ auto falseType = getFalseValue()->getType();
+
+ if (trueType != falseType)
+ return emitOpError(
+ "requires 'true' and 'false' arguments to be of the same type");
+
+ if (checkI1SameShape(trueType, conditionType))
+ return emitOpError("requires the condition to have the same shape as "
+ "arguments with elemental type i1");
+
+ return false;
+}
+
+Attribute SelectOp::constantFold(ArrayRef<Attribute> operands,
+ MLIRContext *context) const {
+ assert(operands.size() == 3 && "select takes three operands");
+
+ // select true, %0, %1 => %0
+ // select false, %0, %1 => %1
+ auto cond = operands[0].dyn_cast_or_null<IntegerAttr>();
+ if (!cond)
+ return {};
+
+ if (cond.getValue().isNullValue()) {
+ return operands[2];
+ } else if (cond.getValue().isOneValue()) {
+ return operands[1];
+ }
+
+ llvm_unreachable("first argument of select must be i1");
+}
+
//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//
// CHECK: %{{[0-9]+}} = cmpi "eq", %cst_5, %cst_5 : vector<42xindex>
%20 = cmpi "eq", %cidx, %cidx : vector<42 x index>
+ // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index
+ %21 = select %18, %idx, %idx : index
+
+ // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xindex>
+ %22 = select %19, %tidx, %tidx : tensor<42 x index>
+
+ // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_5, %cst_5 : vector<42xindex>
+ %23 = select %20, %cidx, %cidx : vector<42 x index>
+
+ // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %arg3, %arg3 : index
+ %24 = "select"(%18, %idx, %idx) : (i1, index, index) -> index
+
+ // CHECK: %{{[0-9]+}} = select %{{[0-9]+}}, %cst_4, %cst_4 : tensor<42xindex>
+ %25 = "select"(%19, %tidx, %tidx) : (tensor<42 x i1>, tensor<42 x index>, tensor<42 x index>) -> tensor<42 x index>
+
return
}
%r = "cmpi"(%c, %c) {predicate: 0} : (vector<42 x i32>, vector<42 x i32>) -> vector<42 x i32>
}
+// -----
+
+cfgfunc @cfgfunc_with_ops(i32, i32, i32) {
+bb0(%cond : i32, %t : i32, %f : i32):
+ // expected-error@+2 {{different type than prior uses}}
+ // expected-error@-2 {{prior use here}}
+ %r = select %cond, %t, %f : i32
+}
+
+// -----
+
+cfgfunc @cfgfunc_with_ops(i32, i32, i32) {
+bb0(%cond : i32, %t : i32, %f : i32):
+ // expected-error@+1 {{elemental type i1}}
+ %r = "select"(%cond, %t, %f) : (i32, i32, i32) -> i32
+}
+
+// -----
+
+cfgfunc @cfgfunc_with_ops(i1, i32, i64) {
+bb0(%cond : i1, %t : i32, %f : i64):
+ // expected-error@+1 {{'true' and 'false' arguments to be of the same type}}
+ %r = "select"(%cond, %t, %f) : (i1, i32, i64) -> i32
+}
+
+// -----
+
+cfgfunc @cfgfunc_with_ops(i1, vector<42xi32>, vector<42xi32>) {
+bb0(%cond : i1, %t : vector<42xi32>, %f : vector<42xi32>):
+ // expected-error@+1 {{requires the condition to have the same shape as arguments}}
+ %r = "select"(%cond, %t, %f) : (i1, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
+}
+
+// -----
+
+cfgfunc @cfgfunc_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
+bb0(%cond : i1, %t : tensor<42xi32>, %f : tensor<?xi32>):
+ // expected-error@+1 {{'true' and 'false' arguments to be of the same type}}
+ %r = "select"(%cond, %t, %f) : (i1, tensor<42xi32>, tensor<?xi32>) -> tensor<42xi32>
+}
+
+// -----
+
+cfgfunc @cfgfunc_with_ops(tensor<?xi1>, tensor<42xi32>, tensor<42xi32>) {
+bb0(%cond : tensor<?xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
+ // expected-error@+1 {{requires the condition to have the same shape as arguments}}
+ %r = "select"(%cond, %t, %f) : (tensor<?xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
+}