[WebAssembly] Skip implied bitmask operation in LowerShift
authorJun Ma <JunMa@linux.alibaba.com>
Thu, 23 Feb 2023 07:45:48 +0000 (15:45 +0800)
committerJun Ma <JunMa@linux.alibaba.com>
Thu, 2 Mar 2023 01:37:25 +0000 (09:37 +0800)
This patch skips redundant explicit masks of the shift count since
it is implied inside wasm shift instruction.

Differential Revision: https://reviews.llvm.org/D144619

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
llvm/test/CodeGen/WebAssembly/masked-shifts.ll

index 9454480..32d0b01 100644 (file)
@@ -2287,10 +2287,43 @@ SDValue WebAssemblyTargetLowering::LowerShift(SDValue Op,
   // Only manually lower vector shifts
   assert(Op.getSimpleValueType().isVector());
 
-  auto ShiftVal = DAG.getSplatValue(Op.getOperand(1));
+  uint64_t LaneBits = Op.getValueType().getScalarSizeInBits();
+  auto ShiftVal = Op.getOperand(1);
+
+  // Try to skip bitmask operation since it is implied inside shift instruction
+  auto SkipImpliedMask = [](SDValue MaskOp, uint64_t MaskBits) {
+    if (MaskOp.getOpcode() != ISD::AND)
+      return MaskOp;
+    SDValue LHS = MaskOp.getOperand(0);
+    SDValue RHS = MaskOp.getOperand(1);
+    if (MaskOp.getValueType().isVector()) {
+      APInt MaskVal;
+      if (!ISD::isConstantSplatVector(RHS.getNode(), MaskVal))
+        std::swap(LHS, RHS);
+
+      if (ISD::isConstantSplatVector(RHS.getNode(), MaskVal) &&
+          MaskVal == MaskBits)
+        MaskOp = LHS;
+    } else {
+      if (!isa<ConstantSDNode>(RHS.getNode()))
+        std::swap(LHS, RHS);
+
+      auto ConstantRHS = dyn_cast<ConstantSDNode>(RHS.getNode());
+      if (ConstantRHS && ConstantRHS->getAPIntValue() == MaskBits)
+        MaskOp = LHS;
+    }
+
+    return MaskOp;
+  };
+
+  // Skip vector and operation
+  ShiftVal = SkipImpliedMask(ShiftVal, LaneBits - 1);
+  ShiftVal = DAG.getSplatValue(ShiftVal);
   if (!ShiftVal)
     return unrollVectorShift(Op, DAG);
 
+  // Skip scalar and operation
+  ShiftVal = SkipImpliedMask(ShiftVal, LaneBits - 1);
   // Use anyext because none of the high bits can affect the shift
   ShiftVal = DAG.getAnyExtOrTrunc(ShiftVal, DL, MVT::i32);
 
index 56e6119..5bcb023 100644 (file)
@@ -106,10 +106,6 @@ define <16 x i8> @shl_v16i8_late(<16 x i8> %v, i8 %x) {
 ; CHECK-NEXT:  # %bb.0:
 ; CHECK-NEXT:    local.get 0
 ; CHECK-NEXT:    local.get 1
-; CHECK-NEXT:    i8x16.splat
-; CHECK-NEXT:    v128.const 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
-; CHECK-NEXT:    v128.and
-; CHECK-NEXT:    i8x16.extract_lane_u 0
 ; CHECK-NEXT:    i8x16.shl
 ; CHECK-NEXT:    # fallthrough-return
   %t = insertelement <16 x i8> undef, i8 %x, i32 0
@@ -145,10 +141,6 @@ define <16 x i8> @ashr_v16i8_late(<16 x i8> %v, i8 %x) {
 ; CHECK-NEXT:  # %bb.0:
 ; CHECK-NEXT:    local.get 0
 ; CHECK-NEXT:    local.get 1
-; CHECK-NEXT:    i8x16.splat
-; CHECK-NEXT:    v128.const 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
-; CHECK-NEXT:    v128.and
-; CHECK-NEXT:    i8x16.extract_lane_u 0
 ; CHECK-NEXT:    i8x16.shr_s
 ; CHECK-NEXT:    # fallthrough-return
   %t = insertelement <16 x i8> undef, i8 %x, i32 0
@@ -184,10 +176,6 @@ define <16 x i8> @lshr_v16i8_late(<16 x i8> %v, i8 %x) {
 ; CHECK-NEXT:  # %bb.0:
 ; CHECK-NEXT:    local.get 0
 ; CHECK-NEXT:    local.get 1
-; CHECK-NEXT:    i8x16.splat
-; CHECK-NEXT:    v128.const 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
-; CHECK-NEXT:    v128.and
-; CHECK-NEXT:    i8x16.extract_lane_u 0
 ; CHECK-NEXT:    i8x16.shr_u
 ; CHECK-NEXT:    # fallthrough-return
   %t = insertelement <16 x i8> undef, i8 %x, i32 0
@@ -222,10 +210,6 @@ define <8 x i16> @shl_v8i16_late(<8 x i16> %v, i16 %x) {
 ; CHECK-NEXT:  # %bb.0:
 ; CHECK-NEXT:    local.get 0
 ; CHECK-NEXT:    local.get 1
-; CHECK-NEXT:    i16x8.splat
-; CHECK-NEXT:    v128.const 15, 15, 15, 15, 15, 15, 15, 15
-; CHECK-NEXT:    v128.and
-; CHECK-NEXT:    i16x8.extract_lane_u 0
 ; CHECK-NEXT:    i16x8.shl
 ; CHECK-NEXT:    # fallthrough-return
   %t = insertelement <8 x i16> undef, i16 %x, i32 0
@@ -259,10 +243,6 @@ define <8 x i16> @ashr_v8i16_late(<8 x i16> %v, i16 %x) {
 ; CHECK-NEXT:  # %bb.0:
 ; CHECK-NEXT:    local.get 0
 ; CHECK-NEXT:    local.get 1
-; CHECK-NEXT:    i16x8.splat
-; CHECK-NEXT:    v128.const 15, 15, 15, 15, 15, 15, 15, 15
-; CHECK-NEXT:    v128.and
-; CHECK-NEXT:    i16x8.extract_lane_u 0
 ; CHECK-NEXT:    i16x8.shr_s
 ; CHECK-NEXT:    # fallthrough-return
   %t = insertelement <8 x i16> undef, i16 %x, i32 0
@@ -296,10 +276,6 @@ define <8 x i16> @lshr_v8i16_late(<8 x i16> %v, i16 %x) {
 ; CHECK-NEXT:  # %bb.0:
 ; CHECK-NEXT:    local.get 0
 ; CHECK-NEXT:    local.get 1
-; CHECK-NEXT:    i16x8.splat
-; CHECK-NEXT:    v128.const 15, 15, 15, 15, 15, 15, 15, 15
-; CHECK-NEXT:    v128.and
-; CHECK-NEXT:    i16x8.extract_lane_u 0
 ; CHECK-NEXT:    i16x8.shr_u
 ; CHECK-NEXT:    # fallthrough-return
   %t = insertelement <8 x i16> undef, i16 %x, i32 0
@@ -519,6 +495,22 @@ define <2 x i64> @shl_v2i64_i32(<2 x i64> %v, i32 %x) {
   ret <2 x i64> %a
 }
 
+define <2 x i64> @shl_v2i64_i32_late(<2 x i64> %v, i32 %x) {
+; CHECK-LABEL: shl_v2i64_i32_late:
+; CHECK:         .functype shl_v2i64_i32_late (v128, i32) -> (v128)
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i64x2.shl
+; CHECK-NEXT:    # fallthrough-return
+  %z = zext i32 %x to i64
+  %t = insertelement <2 x i64> undef, i64 %z, i32 0
+  %s = shufflevector <2 x i64> %t, <2 x i64> undef, <2 x i32> <i32 0, i32 0>
+  %m = and <2 x i64> %s, <i64 63, i64 63>
+  %a = shl <2 x i64> %v, %m
+  ret <2 x i64> %a
+}
+
 define <2 x i64> @ashr_v2i64_i32(<2 x i64> %v, i32 %x) {
 ; CHECK-LABEL: ashr_v2i64_i32:
 ; CHECK:         .functype ashr_v2i64_i32 (v128, i32) -> (v128)
@@ -535,6 +527,22 @@ define <2 x i64> @ashr_v2i64_i32(<2 x i64> %v, i32 %x) {
   ret <2 x i64> %a
 }
 
+define <2 x i64> @ashr_v2i64_i32_late(<2 x i64> %v, i32 %x) {
+; CHECK-LABEL: ashr_v2i64_i32_late:
+; CHECK:         .functype ashr_v2i64_i32_late (v128, i32) -> (v128)
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i64x2.shr_s
+; CHECK-NEXT:    # fallthrough-return
+  %z = zext i32 %x to i64
+  %t = insertelement <2 x i64> undef, i64 %z, i32 0
+  %s = shufflevector <2 x i64> %t, <2 x i64> undef, <2 x i32> <i32 0, i32 0>
+  %m = and <2 x i64> %s, <i64 63, i64 63>
+  %a = ashr <2 x i64> %v, %m
+  ret <2 x i64> %a
+}
+
 define <2 x i64> @lshr_v2i64_i32(<2 x i64> %v, i32 %x) {
 ; CHECK-LABEL: lshr_v2i64_i32:
 ; CHECK:         .functype lshr_v2i64_i32 (v128, i32) -> (v128)
@@ -551,3 +559,18 @@ define <2 x i64> @lshr_v2i64_i32(<2 x i64> %v, i32 %x) {
   ret <2 x i64> %a
 }
 
+define <2 x i64> @lshr_v2i64_i32_late(<2 x i64> %v, i32 %x) {
+; CHECK-LABEL: lshr_v2i64_i32_late:
+; CHECK:         .functype lshr_v2i64_i32_late (v128, i32) -> (v128)
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i64x2.shr_u
+; CHECK-NEXT:    # fallthrough-return
+  %z = zext i32 %x to i64
+  %t = insertelement <2 x i64> undef, i64 %z, i32 0
+  %s = shufflevector <2 x i64> %t, <2 x i64> undef, <2 x i32> <i32 0, i32 0>
+  %m = and <2 x i64> %s, <i64 63, i64 63>
+  %a = lshr <2 x i64> %v, %m
+  ret <2 x i64> %a
+}