};
//===----------------------------------------------------------------------===//
+// ConvertSelect
+//===----------------------------------------------------------------------===//
+
+struct ConvertSelect final : OpConversionPattern<arith::SelectOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ auto newTy = getTypeConverter()
+ ->convertType(op.getType())
+ .dyn_cast_or_null<VectorType>();
+ if (!newTy)
+ return rewriter.notifyMatchFailure(
+ loc, llvm::formatv("unsupported type: {0}", op.getType()));
+
+ auto [trueElem0, trueElem1] =
+ extractLastDimHalves(rewriter, loc, adaptor.getTrueValue());
+ auto [falseElem0, falseElem1] =
+ extractLastDimHalves(rewriter, loc, adaptor.getFalseValue());
+ Value cond = appendX1Dim(rewriter, loc, adaptor.getCondition());
+
+ Value resElem0 =
+ rewriter.create<arith::SelectOp>(loc, cond, trueElem0, falseElem0);
+ Value resElem1 =
+ rewriter.create<arith::SelectOp>(loc, cond, trueElem1, falseElem1);
+ Value resultVec =
+ constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
+ rewriter.replaceOp(op, resultVec);
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
// ConvertShLI
//===----------------------------------------------------------------------===//
// Populate `arith.*` conversion patterns.
patterns.add<
// Misc ops.
- ConvertConstant, ConvertVectorPrint,
+ ConvertConstant, ConvertVectorPrint, ConvertSelect,
// Binary ops.
ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRUI,
// Bitwise binary ops.
return %b : vector<3xi16>
}
+// CHECK-LABEL: func.func @select_scalar
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>, [[ARG2:%.+]]: i1)
+// CHECK-SAME: -> vector<2xi32>
+// CHECK-NEXT: [[TLOW:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT: [[THIGH:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT: [[FLOW:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT: [[FHIGH:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32>
+// CHECK-NEXT: [[SLOW:%.+]] = arith.select [[ARG2]], [[TLOW]], [[FLOW]] : i32
+// CHECK-NEXT: [[SHIGH:%.+]] = arith.select [[ARG2]], [[THIGH]], [[FHIGH]] : i32
+// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[SLOW]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[SHIGH]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK: return [[INS1]] : vector<2xi32>
+func.func @select_scalar(%a : i64, %b : i64, %c : i1) -> i64 {
+ %r = arith.select %c, %a, %b : i64
+ return %r : i64
+}
+
+// CHECK-LABEL: func.func @select_vector_whole
+// CHECK-SAME: ([[ARG0:%.+]]: vector<3x2xi32>, [[ARG1:%.+]]: vector<3x2xi32>, [[ARG2:%.+]]: i1)
+// CHECK-SAME: -> vector<3x2xi32>
+// CHECK: arith.select {{%.+}}, {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK-NEXT: arith.select {{%.+}}, {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: return {{%.+}} : vector<3x2xi32>
+func.func @select_vector_whole(%a : vector<3xi64>, %b : vector<3xi64>, %c : i1) -> vector<3xi64> {
+ %r = arith.select %c, %a, %b : vector<3xi64>
+ return %r : vector<3xi64>
+}
+
+// CHECK-LABEL: func.func @select_vector_elementwise
+// CHECK-SAME: ([[ARG0:%.+]]: vector<3x2xi32>, [[ARG1:%.+]]: vector<3x2xi32>, [[ARG2:%.+]]: vector<3xi1>)
+// CHECK-SAME: -> vector<3x2xi32>
+// CHECK: arith.select {{%.+}}, {{%.+}}, {{%.+}} : vector<3x1xi1>, vector<3x1xi32>
+// CHECK-NEXT: arith.select {{%.+}}, {{%.+}}, {{%.+}} : vector<3x1xi1>, vector<3x1xi32>
+// CHECK: return {{%.+}} : vector<3x2xi32>
+func.func @select_vector_elementwise(%a : vector<3xi64>, %b : vector<3xi64>, %c : vector<3xi1>) -> vector<3xi64> {
+ %r = arith.select %c, %a, %b : vector<3xi1>, vector<3xi64>
+ return %r : vector<3xi64>
+}
+
// CHECK-LABEL: func.func @muli_scalar
// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>