[WebAssembly] Fix SIMD shift unrolling to avoid assertion failure
authorThomas Lively <tlively@google.com>
Thu, 12 Mar 2020 01:08:46 +0000 (18:08 -0700)
committerThomas Lively <tlively@google.com>
Thu, 12 Mar 2020 19:20:14 +0000 (12:20 -0700)
Summary:
Using the default DAG.UnrollVectorOp on v16i8 and v8i16 vectors
results in i8 or i16 nodes being inserted into the SelectionDAG. Since
those are illegal types, this causes a legalization assertion failure
for some code patterns, as uncovered by PR45178. This change unrolls
shifts manually to avoid this issue by adding and using a new optional
EVT argument to DAG.ExtractVectorElements to control the type of the
extract_element nodes.

Reviewers: aheejin, dschuff

Subscribers: sbc100, jgravelle-google, hiraditya, sunfish, zzheng, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D76043

llvm/include/llvm/CodeGen/SelectionDAG.h
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
llvm/test/CodeGen/WebAssembly/simd-shift-unroll.ll [new file with mode: 0644]

index 7c1dd5d..993a3c8 100644 (file)
@@ -1740,10 +1740,13 @@ public:
   /// Widen the vector up to the next power of two using INSERT_SUBVECTOR.
   SDValue WidenVector(const SDValue &N, const SDLoc &DL);
 
-  /// Append the extracted elements from Start to Count out of the vector Op
-  /// in Args. If Count is 0, all of the elements will be extracted.
+  /// Append the extracted elements from Start to Count out of the vector Op in
+  /// Args. If Count is 0, all of the elements will be extracted. The extracted
+  /// elements will have type EVT if it is provided, and otherwise their type
+  /// will be Op's element type.
   void ExtractVectorElements(SDValue Op, SmallVectorImpl<SDValue> &Args,
-                             unsigned Start = 0, unsigned Count = 0);
+                             unsigned Start = 0, unsigned Count = 0,
+                             EVT EltVT = EVT());
 
   /// Compute the default alignment value for the given type.
   unsigned getEVTAlignment(EVT MemoryVT) const;
index e6d431a..dcd072d 100644 (file)
@@ -9501,12 +9501,13 @@ SDValue SelectionDAG::WidenVector(const SDValue &N, const SDLoc &DL) {
 
 void SelectionDAG::ExtractVectorElements(SDValue Op,
                                          SmallVectorImpl<SDValue> &Args,
-                                         unsigned Start, unsigned Count) {
+                                         unsigned Start, unsigned Count,
+                                         EVT EltVT) {
   EVT VT = Op.getValueType();
   if (Count == 0)
     Count = VT.getVectorNumElements();
-
-  EVT EltVT = VT.getVectorElementType();
+  if (EltVT == EVT())
+    EltVT = VT.getVectorElementType();
   SDLoc SL(Op);
   for (unsigned i = Start, e = Start + Count; i != e; ++i) {
     Args.push_back(getNode(ISD::EXTRACT_VECTOR_ELT, SL, EltVT, Op,
index b1e0de5..68e9aa6 100644 (file)
@@ -1582,22 +1582,25 @@ static SDValue unrollVectorShift(SDValue Op, SelectionDAG &DAG) {
     return DAG.UnrollVectorOp(Op.getNode());
   // Otherwise mask the shift value to get proper semantics from 32-bit shift
   SDLoc DL(Op);
-  SDValue ShiftVal = Op.getOperand(1);
-  uint64_t MaskVal = LaneT.getSizeInBits() - 1;
-  SDValue MaskedShiftVal = DAG.getNode(
-      ISD::AND,                    // mask opcode
-      DL, ShiftVal.getValueType(), // masked value type
-      ShiftVal,                    // original shift value operand
-      DAG.getConstant(MaskVal, DL, ShiftVal.getValueType()) // mask operand
-  );
-
-  return DAG.UnrollVectorOp(
-      DAG.getNode(Op.getOpcode(),        // original shift opcode
-                  DL, Op.getValueType(), // original return type
-                  Op.getOperand(0),      // original vector operand,
-                  MaskedShiftVal         // new masked shift value operand
-                  )
-          .getNode());
+  size_t NumLanes = Op.getSimpleValueType().getVectorNumElements();
+  SDValue Mask = DAG.getConstant(LaneT.getSizeInBits() - 1, DL, MVT::i32);
+  unsigned ShiftOpcode = Op.getOpcode();
+  SmallVector<SDValue, 16> ShiftedElements;
+  DAG.ExtractVectorElements(Op.getOperand(0), ShiftedElements, 0, 0, MVT::i32);
+  SmallVector<SDValue, 16> ShiftElements;
+  DAG.ExtractVectorElements(Op.getOperand(1), ShiftElements, 0, 0, MVT::i32);
+  SmallVector<SDValue, 16> UnrolledOps;
+  for (size_t i = 0; i < NumLanes; ++i) {
+    SDValue MaskedShiftValue =
+        DAG.getNode(ISD::AND, DL, MVT::i32, ShiftElements[i], Mask);
+    SDValue ShiftedValue = ShiftedElements[i];
+    if (ShiftOpcode == ISD::SRA)
+      ShiftedValue = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i32,
+                                 ShiftedValue, DAG.getValueType(LaneT));
+    UnrolledOps.push_back(
+        DAG.getNode(ShiftOpcode, DL, MVT::i32, ShiftedValue, MaskedShiftValue));
+  }
+  return DAG.getBuildVector(Op.getValueType(), DL, UnrolledOps);
 }
 
 SDValue WebAssemblyTargetLowering::LowerShift(SDValue Op,
diff --git a/llvm/test/CodeGen/WebAssembly/simd-shift-unroll.ll b/llvm/test/CodeGen/WebAssembly/simd-shift-unroll.ll
new file mode 100644 (file)
index 0000000..2a5422c
--- /dev/null
@@ -0,0 +1,128 @@
+; RUN: llc < %s -asm-verbose=false -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+unimplemented-simd128 | FileCheck %s --check-prefixes CHECK,SIMD128,SIMD128-SLOW
+
+;; Test that the custom shift unrolling works correctly in cases that
+;; cause assertion failures due to illegal types when using
+;; DAG.UnrollVectorOp. Regression test for PR45178.
+
+target datalayout = "e-m:e-p:32:32-i64:64-n32:64-S128"
+target triple = "wasm32-unknown-unknown"
+
+; CHECK-LABEL: shl_v16i8:
+; CHECK-NEXT: .functype       shl_v16i8 (v128) -> (v128)
+; CHECK-NEXT: i8x16.extract_lane_u    $push0=, $0, 0
+; CHECK-NEXT: i32.const       $push1=, 3
+; CHECK-NEXT: i32.shl         $push2=, $pop0, $pop1
+; CHECK-NEXT: i8x16.splat     $push3=, $pop2
+; CHECK-NEXT: i8x16.extract_lane_u    $push4=, $0, 1
+; CHECK-NEXT: i8x16.replace_lane      $push5=, $pop3, 1, $pop4
+; ...
+; CHECK:      i8x16.extract_lane_u    $push32=, $0, 15
+; CHECK-NEXT: i8x16.replace_lane      $push33=, $pop31, 15, $pop32
+; CHECK-NEXT: v8x16.shuffle   $push34=, $pop33, $0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+; CHECK-NEXT: return  $pop34
+define <16 x i8> @shl_v16i8(<16 x i8> %in) {
+  %out = shl <16 x i8> %in,
+    <i8 3, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0,
+     i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0>
+  %ret = shufflevector <16 x i8> %out, <16 x i8> undef, <16 x i32> zeroinitializer
+  ret <16 x i8> %ret
+}
+
+; CHECK-LABEL: shr_s_v16i8:
+; CHECK-NEXT: functype       shr_s_v16i8 (v128) -> (v128)
+; CHECK-NEXT: i8x16.extract_lane_s    $push0=, $0, 0
+; CHECK-NEXT: i32.const       $push1=, 3
+; CHECK-NEXT: i32.shr_s       $push2=, $pop0, $pop1
+; CHECK-NEXT: i8x16.splat     $push3=, $pop2
+; CHECK-NEXT: i8x16.extract_lane_s    $push4=, $0, 1
+; CHECK-NEXT: i8x16.replace_lane      $push5=, $pop3, 1, $pop4
+; ...
+; CHECK:      i8x16.extract_lane_s    $push32=, $0, 15
+; CHECK-NEXT: i8x16.replace_lane      $push33=, $pop31, 15, $pop32
+; CHECK-NEXT: v8x16.shuffle   $push34=, $pop33, $0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+; CHECK-NEXT: return  $pop34
+define <16 x i8> @shr_s_v16i8(<16 x i8> %in) {
+  %out = ashr <16 x i8> %in,
+    <i8 3, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0,
+     i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0>
+  %ret = shufflevector <16 x i8> %out, <16 x i8> undef, <16 x i32> zeroinitializer
+  ret <16 x i8> %ret
+}
+
+; CHECK-LABEL: shr_u_v16i8:
+; CHECK-NEXT: functype       shr_u_v16i8 (v128) -> (v128)
+; CHECK-NEXT: i8x16.extract_lane_u    $push0=, $0, 0
+; CHECK-NEXT: i32.const       $push1=, 3
+; CHECK-NEXT: i32.shr_u       $push2=, $pop0, $pop1
+; CHECK-NEXT: i8x16.splat     $push3=, $pop2
+; CHECK-NEXT: i8x16.extract_lane_u    $push4=, $0, 1
+; CHECK-NEXT: i8x16.replace_lane      $push5=, $pop3, 1, $pop4
+; ...
+; CHECK:      i8x16.extract_lane_u    $push32=, $0, 15
+; CHECK-NEXT: i8x16.replace_lane      $push33=, $pop31, 15, $pop32
+; CHECK-NEXT: v8x16.shuffle   $push34=, $pop33, $0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+; CHECK-NEXT: return  $pop34
+define <16 x i8> @shr_u_v16i8(<16 x i8> %in) {
+  %out = lshr <16 x i8> %in,
+    <i8 3, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0,
+     i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0>
+  %ret = shufflevector <16 x i8> %out, <16 x i8> undef, <16 x i32> zeroinitializer
+  ret <16 x i8> %ret
+}
+
+; CHECK-LABEL: shl_v8i16:
+; CHECK-NEXT: functype       shl_v8i16 (v128) -> (v128)
+; CHECK-NEXT: i16x8.extract_lane_u    $push0=, $0, 0
+; CHECK-NEXT: i32.const       $push1=, 9
+; CHECK-NEXT: i32.shl         $push2=, $pop0, $pop1
+; CHECK-NEXT: i16x8.splat     $push3=, $pop2
+; CHECK-NEXT: i16x8.extract_lane_u    $push4=, $0, 1
+; CHECK-NEXT: i16x8.replace_lane      $push5=, $pop3, 1, $pop4
+; ...
+; CHECK:      i16x8.extract_lane_u    $push16=, $0, 7
+; CHECK-NEXT: i16x8.replace_lane      $push17=, $pop15, 7, $pop16
+; CHECK-NEXT: v8x16.shuffle   $push18=, $pop17, $0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
+; CHECK-NEXT: return  $pop18
+define <8 x i16> @shl_v8i16(<8 x i16> %in) {
+  %out = shl <8 x i16> %in, <i16 9, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>
+  %ret = shufflevector <8 x i16> %out, <8 x i16> undef, <8 x i32> zeroinitializer
+  ret <8 x i16> %ret
+}
+
+; CHECK-LABEL: shr_s_v8i16:
+; CHECK-NEXT: functype       shr_s_v8i16 (v128) -> (v128)
+; CHECK-NEXT: i16x8.extract_lane_s    $push0=, $0, 0
+; CHECK-NEXT: i32.const       $push1=, 9
+; CHECK-NEXT: i32.shr_s       $push2=, $pop0, $pop1
+; CHECK-NEXT: i16x8.splat     $push3=, $pop2
+; CHECK-NEXT: i16x8.extract_lane_s    $push4=, $0, 1
+; CHECK-NEXT: i16x8.replace_lane      $push5=, $pop3, 1, $pop4
+; ...
+; CHECK:      i16x8.extract_lane_s    $push16=, $0, 7
+; CHECK-NEXT: i16x8.replace_lane      $push17=, $pop15, 7, $pop16
+; CHECK-NEXT: v8x16.shuffle   $push18=, $pop17, $0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
+; CHECK-NEXT: return  $pop18
+define <8 x i16> @shr_s_v8i16(<8 x i16> %in) {
+  %out = ashr <8 x i16> %in, <i16 9, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>
+  %ret = shufflevector <8 x i16> %out, <8 x i16> undef, <8 x i32> zeroinitializer
+  ret <8 x i16> %ret
+}
+
+; CHECK-LABEL: shr_u_v8i16:
+; CHECK-NEXT: functype       shr_u_v8i16 (v128) -> (v128)
+; CHECK-NEXT: i16x8.extract_lane_u    $push0=, $0, 0
+; CHECK-NEXT: i32.const       $push1=, 9
+; CHECK-NEXT: i32.shr_u       $push2=, $pop0, $pop1
+; CHECK-NEXT: i16x8.splat     $push3=, $pop2
+; CHECK-NEXT: i16x8.extract_lane_u    $push4=, $0, 1
+; CHECK-NEXT: i16x8.replace_lane      $push5=, $pop3, 1, $pop4
+; ...
+; CHECK:      i16x8.extract_lane_u    $push16=, $0, 7
+; CHECK-NEXT: i16x8.replace_lane      $push17=, $pop15, 7, $pop16
+; CHECK-NEXT: v8x16.shuffle   $push18=, $pop17, $0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
+; CHECK-NEXT: return  $pop18
+define <8 x i16> @shr_u_v8i16(<8 x i16> %in) {
+  %out = lshr <8 x i16> %in, <i16 9, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>
+  %ret = shufflevector <8 x i16> %out, <8 x i16> undef, <8 x i32> zeroinitializer
+  ret <8 x i16> %ret
+}