[mlir][arith] Add shli support to WIE
authorJakub Kuderski <kubak@google.com>
Wed, 5 Oct 2022 19:09:45 +0000 (15:09 -0400)
committerJakub Kuderski <kubak@google.com>
Wed, 5 Oct 2022 19:09:58 +0000 (15:09 -0400)
Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D135234

mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
mlir/test/Dialect/Arith/emulate-wide-int.mlir
mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir
mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shli-i16.mlir [new file with mode: 0644]

index ea13207..c53abbc 100644 (file)
@@ -486,6 +486,95 @@ struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> {
 };
 
 //===----------------------------------------------------------------------===//
+// ConvertShLI
+//===----------------------------------------------------------------------===//
+
+struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+
+    Type oldTy = op.getType();
+    auto newTy =
+        getTypeConverter()->convertType(oldTy).dyn_cast_or_null<VectorType>();
+    if (!newTy)
+      return rewriter.notifyMatchFailure(loc, "unsupported type");
+
+    Type newOperandTy = reduceInnermostDim(newTy);
+    // `oldBitWidth` == `2 * newBitWidth`
+    unsigned newBitWidth = newTy.getElementTypeBitWidth();
+
+    auto [lhsElem0, lhsElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getLhs());
+    Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
+
+    // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
+    // high halves of the results separately:
+    //   1. low := LHS.low shli RHS
+    //
+    //   2. high := a or b or c, where:
+    //     a) Bits from LHS.high, shifted by the RHS.
+    //     b) Bits from LHS.low, shifted right. These come into play when
+    //        RHS < newBitWidth, e.g.:
+    //         [0000][llll] shli 3 --> [0lll][l000]
+    //                                    ^
+    //                                    |
+    //                           [llll] shrui (4 - 3)
+    //     c) Bits from LHS.low, shifted left. These matter when
+    //        RHS > newBitWidth, e.g.:
+    //         [0000][llll] shli 7 --> [l000][0000]
+    //                                   ^
+    //                                   |
+    //                          [llll] shli (7 - 4)
+    //
+    // Because shifts by values >= newBitWidth are undefined, we ignore the high
+    // half of RHS, and introduce 'bounds checks' to account for
+    // RHS.low > newBitWidth.
+    //
+    // TODO: Explore possible optimizations.
+    Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0);
+    Value elemBitWidth =
+        createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
+
+    Value illegalElemShift = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
+
+    Value shiftedElem0 =
+        rewriter.create<arith::ShLIOp>(loc, lhsElem0, rhsElem0);
+    Value resElem0 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
+                                                      zeroCst, shiftedElem0);
+
+    Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
+        loc, illegalElemShift, elemBitWidth, rhsElem0);
+    Value rightShiftAmount =
+        rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
+    Value shiftedRight =
+        rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rightShiftAmount);
+    Value overshotShiftAmount =
+        rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
+    Value shiftedLeft =
+        rewriter.create<arith::ShLIOp>(loc, lhsElem0, overshotShiftAmount);
+
+    Value shiftedElem1 =
+        rewriter.create<arith::ShLIOp>(loc, lhsElem1, rhsElem0);
+    Value resElem1High = rewriter.create<arith::SelectOp>(
+        loc, illegalElemShift, zeroCst, shiftedElem1);
+    Value resElem1Low = rewriter.create<arith::SelectOp>(
+        loc, illegalElemShift, shiftedLeft, shiftedRight);
+    Value resElem1 =
+        rewriter.create<arith::OrIOp>(loc, resElem1Low, resElem1High);
+
+    Value resultVec =
+        constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
+    rewriter.replaceOp(op, resultVec);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
 // ConvertShRUI
 //===----------------------------------------------------------------------===//
 
@@ -498,8 +587,13 @@ struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> {
     Location loc = op->getLoc();
 
     Type oldTy = op.getType();
-    auto newTy = getTypeConverter()->convertType(oldTy).cast<VectorType>();
+    auto newTy =
+        getTypeConverter()->convertType(oldTy).dyn_cast_or_null<VectorType>();
+    if (!newTy)
+      return rewriter.notifyMatchFailure(loc, "unsupported type");
+
     Type newOperandTy = reduceInnermostDim(newTy);
+    // `oldBitWidth` == `2 * newBitWidth`
     unsigned newBitWidth = newTy.getElementTypeBitWidth();
 
     auto [lhsElem0, lhsElem1] =
@@ -727,7 +821,7 @@ void arith::populateWideIntEmulationPatterns(
       // Misc ops.
       ConvertConstant, ConvertVectorPrint,
       // Binary ops.
-      ConvertAddI, ConvertMulI, ConvertShRUI,
+      ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRUI,
       // Bitwise binary ops.
       ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
       ConvertBitwiseBinary<arith::XOrIOp>,
index 59451f5..eebf1d6 100644 (file)
@@ -278,6 +278,46 @@ func.func @muli_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64>
     return %m : vector<3xi64>
 }
 
+// CHECK-LABEL: func.func @shli_scalar
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT:    [[LOW0:%.+]]     = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HIGH0:%.+]]    = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT:    [[LOW1:%.+]]     = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT:    [[CST0:%.+]]     = arith.constant 0 : i32
+// CHECK-NEXT:    [[CST32:%.+]]    = arith.constant 32 : i32
+// CHECK-NEXT:    [[OOB:%.+]]      = arith.cmpi uge, [[LOW1]], [[CST32]] : i32
+// CHECK-NEXT:    [[SHLOW0:%.+]]   = arith.shli [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT:    [[RES0:%.+]]     = arith.select [[OOB]], [[CST0]], [[SHLOW0]] : i32
+// CHECK-NEXT:    [[SHAMT:%.+]]    = arith.select [[OOB]], [[CST32]], [[LOW1]] : i32
+// CHECK-NEXT:    [[RSHAMT:%.+]]   = arith.subi [[CST32]], [[SHAMT]] : i32
+// CHECK-NEXT:    [[SHRHIGH0:%.+]] = arith.shrui [[LOW0]], [[RSHAMT]] : i32
+// CHECK-NEXT:    [[LSHAMT:%.+]]   = arith.subi [[LOW1]], [[CST32]] : i32
+// CHECK-NEXT:    [[SHLHIGH0:%.+]] = arith.shli [[LOW0]], [[LSHAMT]] : i32
+// CHECK-NEXT:    [[SHLHIGH1:%.+]] = arith.shli [[HIGH0]], [[LOW1]] : i32
+// CHECK-NEXT:    [[RES1HIGH:%.+]] = arith.select [[OOB]], [[CST0]], [[SHLHIGH1]] : i32
+// CHECK-NEXT:    [[RES1LOW:%.+]]  = arith.select [[OOB]], [[SHLHIGH0]], [[SHRHIGH0]] : i32
+// CHECK-NEXT:    [[RES1:%.+]]     = arith.ori [[RES1LOW]], [[RES1HIGH]] : i32
+// CHECK-NEXT:    [[VZ:%.+]]       = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT:    [[INS0:%.+]]     = vector.insert [[RES0]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK-NEXT:    [[INS1:%.+]]     = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT:    return [[INS1]] : vector<2xi32>
+func.func @shli_scalar(%a : i64, %b : i64) -> i64 {
+    %c = arith.shli %a, %b : i64
+    return %c : i64
+}
+
+// CHECK-LABEL: func.func @shli_vector
+// CHECK-SAME:    ({{%.+}}: vector<3x2xi32>, {{%.+}}: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK:         {{%.+}} = arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK:         {{%.+}} = arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK:         {{%.+}} = arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK:         {{%.+}} = arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK:       return {{%.+}} : vector<3x2xi32>
+func.func @shli_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
+    %m = arith.shli %a, %b : vector<3xi64>
+    return %m : vector<3xi64>
+}
+
 // CHECK-LABEL: func.func @shrui_scalar
 // CHECK-SAME:    ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
 // CHECK-NEXT:    [[LOW0:%.+]]     = vector.extract [[ARG0]][0] : vector<2xi32>
@@ -326,6 +366,10 @@ func.func @shrui_scalar_cst_36(%a : i64) -> i64 {
 
 // CHECK-LABEL: func.func @shrui_vector
 // CHECK-SAME:    ({{%.+}}: vector<3x2xi32>, {{%.+}}: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK:         {{%.+}} = arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK:         {{%.+}} = arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK:         {{%.+}} = arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK:         {{%.+}} = arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
 // CHECK:       return {{%.+}} : vector<3x2xi32>
 func.func @shrui_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
     %m = arith.shrui %a, %b : vector<3xi64>
index 6ca2790..16e8634 100644 (file)
@@ -157,6 +157,53 @@ func.func @test_muli() -> () {
 }
 
 //===----------------------------------------------------------------------===//
+// Test arith.shli
+//===----------------------------------------------------------------------===//
+
+// Ops in this function will be emulated using i8 ops.
+func.func @emulate_shli(%lhs : i16, %rhs : i16) -> (i16) {
+  %res = arith.shli %lhs, %rhs : i16
+  return %res : i16
+}
+
+// Performs both wide and emulated `arith.shli`, and checks that the results
+// match.
+func.func @check_shli(%lhs : i16, %rhs : i16) -> () {
+  %wide = arith.shli %lhs, %rhs : i16
+  %emulated = func.call @emulate_shli(%lhs, %rhs) : (i16, i16) -> (i16)
+  func.call @check_results(%lhs, %rhs, %wide, %emulated) : (i16, i16, i16, i16) -> ()
+  return
+}
+
+// Checks that `arith.shli` is emulated properly by sampling the input space.
+// Checks all valid shift amounts for i16: 0 to 15.
+// In total, this test function checks 100 * 16 = 1.6k input pairs.
+func.func @test_shli() -> () {
+  %idx0 = arith.constant 0 : index
+  %idx1 = arith.constant 1 : index
+  %idx16 = arith.constant 16 : index
+  %idx100 = arith.constant 100 : index
+
+  %cst0 = arith.constant 0 : i16
+  %cst1 = arith.constant 1 : i16
+
+  scf.for %lhs_idx = %idx0 to %idx100 step %idx1 iter_args(%lhs = %cst0) -> (i16) {
+    %arg_lhs = func.call @xhash(%lhs) : (i16) -> (i16)
+
+    scf.for %rhs_idx = %idx0 to %idx16 step %idx1 iter_args(%rhs = %cst0) -> (i16) {
+        func.call @check_shli(%arg_lhs, %rhs) : (i16, i16) -> ()
+        %rhs_next = arith.addi %rhs, %cst1 : i16
+        scf.yield %rhs_next : i16
+    }
+
+    %lhs_next = arith.addi %lhs, %cst1 : i16
+    scf.yield %lhs_next : i16
+  }
+
+  return
+}
+
+//===----------------------------------------------------------------------===//
 // Test arith.shrui
 //===----------------------------------------------------------------------===//
 
@@ -210,6 +257,7 @@ func.func @test_shrui() -> () {
 func.func @entry() {
   func.call @test_addi() : () -> ()
   func.call @test_muli() : () -> ()
+  func.call @test_shli() : () -> ()
   func.call @test_shrui() : () -> ()
   return
 }
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shli-i16.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shli-i16.mlir
new file mode 100644 (file)
index 0000000..1e32d18
--- /dev/null
@@ -0,0 +1,73 @@
+// Check that the wide integer `arith.shli` emulation produces the same result as wide
+// `arith.shli`. Emulate i16 ops with i8 ops.
+
+// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:                   --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN:   FileCheck %s --match-full-lines
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=8" \
+// RUN:             --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:                   --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN:   FileCheck %s --match-full-lines
+
+// Ops in this function *only* will be emulated using i8 types.
+func.func @emulate_shli(%lhs : i16, %rhs : i16) -> (i16) {
+  %res = arith.shli %lhs, %rhs : i16
+  return %res : i16
+}
+
+func.func @check_shli(%lhs : i16, %rhs : i16) -> () {
+  %res = func.call @emulate_shli(%lhs, %rhs) : (i16, i16) -> (i16)
+  vector.print %res : i16
+  return
+}
+
+func.func @entry() {
+  %cst0 = arith.constant 0 : i16
+  %cst1 = arith.constant 1 : i16
+  %cst2 = arith.constant 2 : i16
+  %cst7 = arith.constant 7 : i16
+  %cst8 = arith.constant 8 : i16
+  %cst9 = arith.constant 9 : i16
+  %cst15 = arith.constant 15 : i16
+
+  %cst_n1 = arith.constant -1 : i16
+
+  %cst1337 = arith.constant 1337 : i16
+
+  %cst_i16_min = arith.constant -32768 : i16
+
+  // CHECK:      0
+  // CHECK-NEXT: 0
+  // CHECK-NEXT: 1
+  // CHECK-NEXT: 2
+  // CHECK-NEXT: -2
+  // CHECK-NEXT: -32768
+  func.call @check_shli(%cst0, %cst0) : (i16, i16) -> ()
+  func.call @check_shli(%cst0, %cst1) : (i16, i16) -> ()
+  func.call @check_shli(%cst1, %cst0) : (i16, i16) -> ()
+  func.call @check_shli(%cst1, %cst1) : (i16, i16) -> ()
+  func.call @check_shli(%cst_n1, %cst1) : (i16, i16) -> ()
+  func.call @check_shli(%cst_n1, %cst15) : (i16, i16) -> ()
+
+  // CHECK-NEXT: 1337
+  // CHECK-NEXT: 5348
+  // CHECK-NEXT: -25472
+  // CHECK-NEXT: 14592
+  // CHECK-NEXT: 29184
+  // CHECK-NEXT: -32768
+  // CHECK-NEXT: 0
+  func.call @check_shli(%cst1337, %cst0) : (i16, i16) -> ()
+  func.call @check_shli(%cst1337, %cst2) : (i16, i16) -> ()
+  func.call @check_shli(%cst1337, %cst7) : (i16, i16) -> ()
+  func.call @check_shli(%cst1337, %cst8) : (i16, i16) -> ()
+  func.call @check_shli(%cst1337, %cst9) : (i16, i16) -> ()
+  func.call @check_shli(%cst1337, %cst15) : (i16, i16) -> ()
+  func.call @check_shli(%cst_i16_min, %cst1) : (i16, i16) -> ()
+
+  return
+}