[mlir][arith] Add `select` support to WIE
authorJakub Kuderski <kubak@google.com>
Wed, 9 Nov 2022 01:34:31 +0000 (20:34 -0500)
committerJakub Kuderski <kubak@google.com>
Wed, 9 Nov 2022 01:34:51 +0000 (20:34 -0500)
Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D137589

mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
mlir/test/Dialect/Arith/emulate-wide-int.mlir

index 826c8ee..6d83bfa 100644 (file)
@@ -492,6 +492,41 @@ struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> {
 };
 
 //===----------------------------------------------------------------------===//
+// ConvertSelect
+//===----------------------------------------------------------------------===//
+
+struct ConvertSelect final : OpConversionPattern<arith::SelectOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto newTy = getTypeConverter()
+                     ->convertType(op.getType())
+                     .dyn_cast_or_null<VectorType>();
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", op.getType()));
+
+    auto [trueElem0, trueElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getTrueValue());
+    auto [falseElem0, falseElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getFalseValue());
+    Value cond = appendX1Dim(rewriter, loc, adaptor.getCondition());
+
+    Value resElem0 =
+        rewriter.create<arith::SelectOp>(loc, cond, trueElem0, falseElem0);
+    Value resElem1 =
+        rewriter.create<arith::SelectOp>(loc, cond, trueElem1, falseElem1);
+    Value resultVec =
+        constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
+    rewriter.replaceOp(op, resultVec);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
 // ConvertShLI
 //===----------------------------------------------------------------------===//
 
@@ -828,7 +863,7 @@ void arith::populateArithWideIntEmulationPatterns(
   // Populate `arith.*` conversion patterns.
   patterns.add<
       // Misc ops.
-      ConvertConstant, ConvertVectorPrint,
+      ConvertConstant, ConvertVectorPrint, ConvertSelect,
       // Binary ops.
       ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRUI,
       // Bitwise binary ops.
index eebf1d6..b09aac0 100644 (file)
@@ -224,6 +224,46 @@ func.func @trunci_vector(%a : vector<3xi64>) -> vector<3xi16> {
     return %b : vector<3xi16>
 }
 
+// CHECK-LABEL: func.func @select_scalar
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>, [[ARG2:%.+]]: i1)
+// CHECK-SAME:    -> vector<2xi32>
+// CHECK-NEXT:    [[TLOW:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT:    [[THIGH:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT:    [[FLOW:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT:    [[FHIGH:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32>
+// CHECK-NEXT:    [[SLOW:%.+]] = arith.select [[ARG2]], [[TLOW]], [[FLOW]] : i32
+// CHECK-NEXT:    [[SHIGH:%.+]] = arith.select [[ARG2]], [[THIGH]], [[FHIGH]] : i32
+// CHECK-NEXT:    [[VZ:%.+]]   = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT:    [[INS0:%.+]] = vector.insert [[SLOW]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK-NEXT:    [[INS1:%.+]] = vector.insert [[SHIGH]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK:         return [[INS1]] : vector<2xi32>
+func.func @select_scalar(%a : i64, %b : i64, %c : i1) -> i64 {
+    %r = arith.select %c, %a, %b : i64
+    return %r : i64
+}
+
+// CHECK-LABEL: func.func @select_vector_whole
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<3x2xi32>, [[ARG1:%.+]]: vector<3x2xi32>, [[ARG2:%.+]]: i1)
+// CHECK-SAME:    -> vector<3x2xi32>
+// CHECK:         arith.select {{%.+}}, {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK-NEXT:    arith.select {{%.+}}, {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK:         return {{%.+}} : vector<3x2xi32>
+func.func @select_vector_whole(%a : vector<3xi64>, %b : vector<3xi64>, %c : i1) -> vector<3xi64> {
+    %r = arith.select %c, %a, %b : vector<3xi64>
+    return %r : vector<3xi64>
+}
+
+// CHECK-LABEL: func.func @select_vector_elementwise
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<3x2xi32>, [[ARG1:%.+]]: vector<3x2xi32>, [[ARG2:%.+]]: vector<3xi1>)
+// CHECK-SAME:    -> vector<3x2xi32>
+// CHECK:         arith.select {{%.+}}, {{%.+}}, {{%.+}} : vector<3x1xi1>, vector<3x1xi32>
+// CHECK-NEXT:    arith.select {{%.+}}, {{%.+}}, {{%.+}} : vector<3x1xi1>, vector<3x1xi32>
+// CHECK:         return {{%.+}} : vector<3x2xi32>
+func.func @select_vector_elementwise(%a : vector<3xi64>, %b : vector<3xi64>, %c : vector<3xi1>) -> vector<3xi64> {
+    %r = arith.select %c, %a, %b : vector<3xi1>, vector<3xi64>
+    return %r : vector<3xi64>
+}
+
 // CHECK-LABEL: func.func @muli_scalar
 // CHECK-SAME:    ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
 // CHECK-NEXT:    [[LOW0:%.+]]      = vector.extract [[ARG0]][0] : vector<2xi32>