[WebAssembly] Pattern match SIMD convert_low and promote_low during ISel
authorThomas Lively <tlively@google.com>
Thu, 19 Aug 2021 22:24:28 +0000 (15:24 -0700)
committerThomas Lively <tlively@google.com>
Thu, 19 Aug 2021 22:24:28 +0000 (15:24 -0700)
Since the simplest DAG patterns for convert_low and promote_low instructions
involved v2i32, v2f32, v4i64, and v4f64 types, which are not legal in the
WebAssembly backend and would be eliminated by type legalization, we were
previously matching those patterns in a DAG combine before the type legalization
stage. However in cases where the vectors were wider than 128 bits, the patterns
we matched were not created until the type legalization stage when the wide
vectors were split up. Type legalization would continue to eliminate the illegal
types we were matching as well, so the code ended up scalarized.

To make the ISel for these instructions more robust, match the scalarized
patterns rather than the patterns containing illegal types. Add tests with
double-wide vectors to show that this works as intended.

Fixes PR51098.
Depends on D107502.

Differential Revision: https://reviews.llvm.org/D108266

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
llvm/test/CodeGen/WebAssembly/simd-conversions.ll

index 507895e..c418a38 100644 (file)
@@ -1778,8 +1778,71 @@ WebAssemblyTargetLowering::LowerSIGN_EXTEND_INREG(SDValue Op,
                      Op.getOperand(1));
 }
 
+static SDValue LowerConvertLow(SDValue Op, SelectionDAG &DAG) {
+  if (Op.getValueType() != MVT::v2f64)
+    return SDValue();
+
+  auto GetConvertedLane = [](SDValue Op, unsigned &Opcode, SDValue &SrcVec,
+                             unsigned &Index) -> bool {
+    switch (Op.getOpcode()) {
+    case ISD::SINT_TO_FP:
+      Opcode = WebAssemblyISD::CONVERT_LOW_S;
+      break;
+    case ISD::UINT_TO_FP:
+      Opcode = WebAssemblyISD::CONVERT_LOW_U;
+      break;
+    case ISD::FP_EXTEND:
+      Opcode = WebAssemblyISD::PROMOTE_LOW;
+      break;
+    default:
+      return false;
+    }
+
+    auto ExtractVector = Op.getOperand(0);
+    if (ExtractVector.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+      return false;
+
+    if (!isa<ConstantSDNode>(ExtractVector.getOperand(1).getNode()))
+      return false;
+
+    SrcVec = ExtractVector.getOperand(0);
+    Index = ExtractVector.getConstantOperandVal(1);
+    return true;
+  };
+
+  unsigned LHSOpcode, RHSOpcode, LHSIndex, RHSIndex;
+  SDValue LHSSrcVec, RHSSrcVec;
+  if (!GetConvertedLane(Op.getOperand(0), LHSOpcode, LHSSrcVec, LHSIndex) ||
+      !GetConvertedLane(Op.getOperand(1), RHSOpcode, RHSSrcVec, RHSIndex))
+    return SDValue();
+
+  if (LHSOpcode != RHSOpcode || LHSSrcVec != RHSSrcVec)
+    return SDValue();
+
+  if (LHSIndex != 0 || RHSIndex != 1)
+    return SDValue();
+
+  MVT ExpectedSrcVT;
+  switch (LHSOpcode) {
+  case WebAssemblyISD::CONVERT_LOW_S:
+  case WebAssemblyISD::CONVERT_LOW_U:
+    ExpectedSrcVT = MVT::v4i32;
+    break;
+  case WebAssemblyISD::PROMOTE_LOW:
+    ExpectedSrcVT = MVT::v4f32;
+    break;
+  }
+  if (LHSSrcVec.getValueType() != ExpectedSrcVT)
+    return SDValue();
+
+  return DAG.getNode(LHSOpcode, SDLoc(Op), MVT::v2f64, LHSSrcVec);
+}
+
 SDValue WebAssemblyTargetLowering::LowerBUILD_VECTOR(SDValue Op,
                                                      SelectionDAG &DAG) const {
+  if (auto ConvertLow = LowerConvertLow(Op, DAG))
+    return ConvertLow;
+
   SDLoc DL(Op);
   const EVT VecT = Op.getValueType();
   const EVT LaneT = Op.getOperand(0).getValueType();
@@ -2231,120 +2294,6 @@ performVectorExtendCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
 }
 
 static SDValue
-performVectorConvertLowCombine(SDNode *N,
-                               TargetLowering::DAGCombinerInfo &DCI) {
-  auto &DAG = DCI.DAG;
-
-  EVT ResVT = N->getValueType(0);
-  if (ResVT != MVT::v2f64)
-    return SDValue();
-
-  auto GetWasmConversionOp = [](unsigned Op) {
-    switch (Op) {
-    case ISD::SINT_TO_FP:
-      return WebAssemblyISD::CONVERT_LOW_S;
-    case ISD::UINT_TO_FP:
-      return WebAssemblyISD::CONVERT_LOW_U;
-    case ISD::FP_EXTEND:
-      return WebAssemblyISD::PROMOTE_LOW;
-    }
-    llvm_unreachable("unexpected op");
-  };
-
-  if (N->getOpcode() == ISD::EXTRACT_SUBVECTOR) {
-    // Combine this:
-    //
-    //   (v2f64 (extract_subvector
-    //     (v4f64 ({s,u}int_to_fp (v4i32 $x))), 0))
-    //
-    // into (f64x2.convert_low_i32x4_{s,u} $x).
-    //
-    // Or this:
-    //
-    //  (v2f64 (extract_subvector
-    //    (v4f64 (fp_extend (v4f32 $x))), 0))
-    //
-    // into (f64x2.promote_low_f32x4 $x).
-    auto Conversion = N->getOperand(0);
-    auto ConversionOp = Conversion.getOpcode();
-    MVT ExpectedSourceType;
-    switch (ConversionOp) {
-    case ISD::SINT_TO_FP:
-    case ISD::UINT_TO_FP:
-      ExpectedSourceType = MVT::v4i32;
-      break;
-    case ISD::FP_EXTEND:
-      ExpectedSourceType = MVT::v4f32;
-      break;
-    default:
-      return SDValue();
-    }
-
-    if (Conversion.getValueType() != MVT::v4f64)
-      return SDValue();
-
-    auto Source = Conversion.getOperand(0);
-    if (Source.getValueType() != ExpectedSourceType)
-      return SDValue();
-
-    auto IndexNode = dyn_cast<ConstantSDNode>(N->getOperand(1));
-    if (IndexNode == nullptr || IndexNode->getZExtValue() != 0)
-      return SDValue();
-
-    auto Op = GetWasmConversionOp(ConversionOp);
-    return DAG.getNode(Op, SDLoc(N), ResVT, Source);
-  }
-
-  // Combine this:
-  //
-  //   (v2f64 ({s,u}int_to_fp
-  //     (v2i32 (extract_subvector (v4i32 $x), 0))))
-  //
-  // into (f64x2.convert_low_i32x4_{s,u} $x).
-  //
-  // Or this:
-  //
-  //   (v2f64 (fp_extend
-  //     (v2f32 (extract_subvector (v4f32 $x), 0))))
-  //
-  // into (f64x2.promote_low_f32x4 $x).
-  auto ConversionOp = N->getOpcode();
-  MVT ExpectedExtractType;
-  MVT ExpectedSourceType;
-  switch (ConversionOp) {
-  case ISD::SINT_TO_FP:
-  case ISD::UINT_TO_FP:
-    ExpectedExtractType = MVT::v2i32;
-    ExpectedSourceType = MVT::v4i32;
-    break;
-  case ISD::FP_EXTEND:
-    ExpectedExtractType = MVT::v2f32;
-    ExpectedSourceType = MVT::v4f32;
-    break;
-  default:
-    llvm_unreachable("unexpected opcode");
-  }
-
-  auto Extract = N->getOperand(0);
-  if (Extract.getOpcode() != ISD::EXTRACT_SUBVECTOR)
-    return SDValue();
-
-  if (Extract.getValueType() != ExpectedExtractType)
-    return SDValue();
-
-  auto Source = Extract.getOperand(0);
-  if (Source.getValueType() != ExpectedSourceType)
-    return SDValue();
-
-  auto *IndexNode = dyn_cast<ConstantSDNode>(Extract.getOperand(1));
-  if (IndexNode == nullptr || IndexNode->getZExtValue() != 0)
-    return SDValue();
-
-  unsigned Op = GetWasmConversionOp(ConversionOp);
-  return DAG.getNode(Op, SDLoc(N), ResVT, Source);
-}
-
-static SDValue
 performVectorTruncZeroCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   auto &DAG = DCI.DAG;
 
@@ -2474,11 +2423,6 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::SIGN_EXTEND:
   case ISD::ZERO_EXTEND:
     return performVectorExtendCombine(N, DCI);
-  case ISD::SINT_TO_FP:
-  case ISD::UINT_TO_FP:
-  case ISD::FP_EXTEND:
-  case ISD::EXTRACT_SUBVECTOR:
-    return performVectorConvertLowCombine(N, DCI);
   case ISD::FP_TO_SINT_SAT:
   case ISD::FP_TO_UINT_SAT:
   case ISD::FP_ROUND:
index 1aa0ccc..c624058 100644 (file)
@@ -304,3 +304,150 @@ define <2 x double> @promote_low_v2f64_2(<4 x float> %x) {
   %a = shufflevector <4 x double> %v, <4 x double> undef, <2 x i32> <i32 0, i32 1>
   ret <2 x double> %a
 }
+
+;; Also check with illegally wide vectors
+
+define <4 x double> @convert_low_s_v4f64(<8 x i32> %x) {
+; CHECK-LABEL: convert_low_s_v4f64:
+; CHECK:         .functype convert_low_s_v4f64 (i32, v128, v128) -> ()
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    f64x2.convert_low_i32x4_s
+; CHECK-NEXT:    v128.store 0
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i32x4.extract_lane 2
+; CHECK-NEXT:    f64.convert_i32_s
+; CHECK-NEXT:    f64x2.splat
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i32x4.extract_lane 3
+; CHECK-NEXT:    f64.convert_i32_s
+; CHECK-NEXT:    f64x2.replace_lane 1
+; CHECK-NEXT:    v128.store 16
+; CHECK-NEXT:    # fallthrough-return
+  %v = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %a = sitofp <4 x i32> %v to <4 x double>
+  ret <4 x double> %a
+}
+
+define <4 x double> @convert_low_u_v4f64(<8 x i32> %x) {
+; CHECK-LABEL: convert_low_u_v4f64:
+; CHECK:         .functype convert_low_u_v4f64 (i32, v128, v128) -> ()
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    f64x2.convert_low_i32x4_u
+; CHECK-NEXT:    v128.store 0
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i32x4.extract_lane 2
+; CHECK-NEXT:    f64.convert_i32_u
+; CHECK-NEXT:    f64x2.splat
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i32x4.extract_lane 3
+; CHECK-NEXT:    f64.convert_i32_u
+; CHECK-NEXT:    f64x2.replace_lane 1
+; CHECK-NEXT:    v128.store 16
+; CHECK-NEXT:    # fallthrough-return
+  %v = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %a = uitofp <4 x i32> %v to <4 x double>
+  ret <4 x double> %a
+}
+
+
+define <4 x double> @convert_low_s_v4f64_2(<8 x i32> %x) {
+; CHECK-LABEL: convert_low_s_v4f64_2:
+; CHECK:         .functype convert_low_s_v4f64_2 (i32, v128, v128) -> ()
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    f64x2.convert_low_i32x4_s
+; CHECK-NEXT:    v128.store 0
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i32x4.extract_lane 2
+; CHECK-NEXT:    f64.convert_i32_s
+; CHECK-NEXT:    f64x2.splat
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i32x4.extract_lane 3
+; CHECK-NEXT:    f64.convert_i32_s
+; CHECK-NEXT:    f64x2.replace_lane 1
+; CHECK-NEXT:    v128.store 16
+; CHECK-NEXT:    # fallthrough-return
+  %v = sitofp <8 x i32> %x to <8 x double>
+  %a = shufflevector <8 x double> %v, <8 x double> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  ret <4 x double> %a
+}
+
+define <4 x double> @convert_low_u_v4f64_2(<8 x i32> %x) {
+; CHECK-LABEL: convert_low_u_v4f64_2:
+; CHECK:         .functype convert_low_u_v4f64_2 (i32, v128, v128) -> ()
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    f64x2.convert_low_i32x4_u
+; CHECK-NEXT:    v128.store 0
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i32x4.extract_lane 2
+; CHECK-NEXT:    f64.convert_i32_u
+; CHECK-NEXT:    f64x2.splat
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i32x4.extract_lane 3
+; CHECK-NEXT:    f64.convert_i32_u
+; CHECK-NEXT:    f64x2.replace_lane 1
+; CHECK-NEXT:    v128.store 16
+; CHECK-NEXT:    # fallthrough-return
+  %v = uitofp <8 x i32> %x to <8 x double>
+  %a = shufflevector <8 x double> %v, <8 x double> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  ret <4 x double> %a
+}
+
+define <4 x double> @promote_low_v4f64(<8 x float> %x) {
+; CHECK-LABEL: promote_low_v4f64:
+; CHECK:         .functype promote_low_v4f64 (i32, v128, v128) -> ()
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    f64x2.promote_low_f32x4
+; CHECK-NEXT:    v128.store 0
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    f32x4.extract_lane 2
+; CHECK-NEXT:    f64.promote_f32
+; CHECK-NEXT:    f64x2.splat
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    f32x4.extract_lane 3
+; CHECK-NEXT:    f64.promote_f32
+; CHECK-NEXT:    f64x2.replace_lane 1
+; CHECK-NEXT:    v128.store 16
+; CHECK-NEXT:    # fallthrough-return
+  %v = shufflevector <8 x float> %x, <8 x float> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %a = fpext <4 x float> %v to <4 x double>
+  ret <4 x double> %a
+}
+
+define <4 x double> @promote_low_v4f64_2(<8 x float> %x) {
+; CHECK-LABEL: promote_low_v4f64_2:
+; CHECK:         .functype promote_low_v4f64_2 (i32, v128, v128) -> ()
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    f64x2.promote_low_f32x4
+; CHECK-NEXT:    v128.store 0
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    f32x4.extract_lane 2
+; CHECK-NEXT:    f64.promote_f32
+; CHECK-NEXT:    f64x2.splat
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    f32x4.extract_lane 3
+; CHECK-NEXT:    f64.promote_f32
+; CHECK-NEXT:    f64x2.replace_lane 1
+; CHECK-NEXT:    v128.store 16
+; CHECK-NEXT:    # fallthrough-return
+  %v = fpext <8 x float> %x to <8 x double>
+  %a = shufflevector <8 x double> %v, <8 x double> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  ret <4 x double> %a
+}