From 33a448f935286de94f00a631277340b8c471a4c0 Mon Sep 17 00:00:00 2001
From: Sanjay Patel
Date: Wed, 5 Dec 2018 17:10:30 +0000
Subject: [PATCH] [DAGCombiner] don't try to extract a fraction of a vector
 binop and crash (PR39893)

Because we're potentially peeking through a bitcast in this transform,
we need to use overall bitwidths rather than number of elements to
determine when it's safe to proceed.

Should fix:
https://bugs.llvm.org/show_bug.cgi?id=39893

llvm-svn: 348383
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 24 ++++++++++++++----------
 llvm/test/CodeGen/X86/vector-narrow-binop.ll  | 69 +++++++++++++++++++++++++++
 2 files changed, 83 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index b6be6d1..326b08d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -16660,29 +16660,33 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG) {
   if (!ISD::isBinaryOp(BinOp.getNode()))
     return SDValue();
 
-  // The binop must be a vector type, so we can chop it in half.
+  // The binop must be a vector type, so we can extract some fraction of it.
   EVT WideBVT = BinOp.getValueType();
   if (!WideBVT.isVector())
     return SDValue();
 
   EVT VT = Extract->getValueType(0);
-  unsigned NumElems = VT.getVectorNumElements();
   unsigned ExtractIndex = ExtractIndexC->getZExtValue();
-  assert(ExtractIndex % NumElems == 0 &&
+  assert(ExtractIndex % VT.getVectorNumElements() == 0 &&
          "Extract index is not a multiple of the vector length.");
-  EVT SrcVT = Extract->getOperand(0).getValueType();
 
   // Bail out if this is not a proper multiple width extraction.
-  unsigned NumSrcElems = SrcVT.getVectorNumElements();
-  if (NumSrcElems % NumElems != 0)
+  unsigned WideWidth = WideBVT.getSizeInBits();
+  unsigned NarrowWidth = VT.getSizeInBits();
+  if (WideWidth % NarrowWidth != 0)
     return SDValue();
 
-  // Bail out if the target does not support a narrower version of the binop.
-  unsigned NarrowingRatio = NumSrcElems / NumElems;
-  unsigned BOpcode = BinOp.getOpcode();
+  // Bail out if we are extracting a fraction of a single operation. This can
+  // occur because we potentially looked through a bitcast of the binop.
+  unsigned NarrowingRatio = WideWidth / NarrowWidth;
   unsigned WideNumElts = WideBVT.getVectorNumElements();
+  if (WideNumElts % NarrowingRatio != 0)
+    return SDValue();
+
+  // Bail out if the target does not support a narrower version of the binop.
   EVT NarrowBVT = EVT::getVectorVT(*DAG.getContext(), WideBVT.getScalarType(),
                                    WideNumElts / NarrowingRatio);
+  unsigned BOpcode = BinOp.getOpcode();
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   if (!TLI.isOperationLegalOrCustomOrPromote(BOpcode, NarrowBVT))
     return SDValue();
@@ -16691,7 +16695,7 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG) {
   // for concat ops. The narrow binop alone makes this transform profitable.
   // We can't just reuse the original extract index operand because we may have
   // bitcasted.
-  unsigned ConcatOpNum = ExtractIndex / NumElems;
+  unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements();
   unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
   EVT ExtBOIdxVT = Extract->getOperand(1).getValueType();
   if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) &&
diff --git a/llvm/test/CodeGen/X86/vector-narrow-binop.ll b/llvm/test/CodeGen/X86/vector-narrow-binop.ll
index c20dc09..4836b49 100644
--- a/llvm/test/CodeGen/X86/vector-narrow-binop.ll
+++ b/llvm/test/CodeGen/X86/vector-narrow-binop.ll
@@ -98,3 +98,72 @@ define <3 x float> @PR39511(<4 x float> %t0, <3 x float>* %b) {
   ret <3 x float> %ext
 }
 
+; When extracting from a vector binop, we need to be extracting
+; by a width of at least 1 of the original vector elements.
+; https://bugs.llvm.org/show_bug.cgi?id=39893
+
+define <2 x i8> @PR39893(<2 x i32> %x, <8 x i8> %y) {
+; SSE-LABEL: PR39893:
+; SSE:       # %bb.0:
+; SSE-NEXT:    pxor %xmm2, %xmm2
+; SSE-NEXT:    psubd %xmm0, %xmm2
+; SSE-NEXT:    pshufd {{.*#+}} xmm0 = xmm2[0,2,2,3]
+; SSE-NEXT:    psrld $16, %xmm0
+; SSE-NEXT:    movsd {{.*#+}} xmm1 = xmm0[0],xmm1[1]
+; SSE-NEXT:    movapd %xmm1, %xmm0
+; SSE-NEXT:    retq
+;
+; AVX1-LABEL: PR39893:
+; AVX1:       # %bb.0:
+; AVX1-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX1-NEXT:    vpsubd %xmm0, %xmm2, %xmm0
+; AVX1-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
+; AVX1-NEXT:    vpsrld $16, %xmm0, %xmm0
+; AVX1-NEXT:    vpblendw {{.*#+}} xmm0 = xmm0[0,1,2,3],xmm1[4,5,6,7]
+; AVX1-NEXT:    retq
+;
+; AVX2-LABEL: PR39893:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX2-NEXT:    vpsubd %xmm0, %xmm2, %xmm0
+; AVX2-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
+; AVX2-NEXT:    vpsrld $16, %xmm0, %xmm0
+; AVX2-NEXT:    vpblendd {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3]
+; AVX2-NEXT:    retq
+;
+; AVX512-LABEL: PR39893:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX512-NEXT:    vpsubd %xmm0, %xmm2, %xmm0
+; AVX512-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
+; AVX512-NEXT:    vpsrld $16, %xmm0, %xmm0
+; AVX512-NEXT:    vpblendd {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3]
+; AVX512-NEXT:    retq
+  %sub = sub <2 x i32> <i32 0, i32 undef>, %x
+  %bc = bitcast <2 x i32> %sub to <8 x i8>
+  %shuffle = shufflevector <8 x i8> %y, <8 x i8> %bc, <2 x i32> <i32 10, i32 4>
+  ret <2 x i8> %shuffle
+}
+
+define <2 x i8> @PR39893_2(<2 x float> %x) {
+; SSE-LABEL: PR39893_2:
+; SSE:       # %bb.0:
+; SSE-NEXT:    xorps %xmm1, %xmm1
+; SSE-NEXT:    subps %xmm0, %xmm1
+; SSE-NEXT:    punpcklbw {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3],xmm1[4],xmm0[4],xmm1[5],xmm0[5],xmm1[6],xmm0[6],xmm1[7],xmm0[7]
+; SSE-NEXT:    punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3]
+; SSE-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[0,1,1,3]
+; SSE-NEXT:    retq
+;
+; AVX-LABEL: PR39893_2:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vxorps %xmm1, %xmm1, %xmm1
+; AVX-NEXT:    vsubps %xmm0, %xmm1, %xmm0
+; AVX-NEXT:    vpmovzxbq {{.*#+}} xmm0 = xmm0[0],zero,zero,zero,zero,zero,zero,zero,xmm0[1],zero,zero,zero,zero,zero,zero,zero
+; AVX-NEXT:    retq
+  %fsub = fsub <2 x float> zeroinitializer, %x
+  %bc = bitcast <2 x float> %fsub to <8 x i8>
+  %shuffle = shufflevector <8 x i8> %bc, <8 x i8> undef, <2 x i32> <i32 0, i32 1>
+  ret <2 x i8> %shuffle
+}
+
--
2.7.4
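
As a quick illustration of the bitwidth check added in this patch, here is a minimal
standalone sketch (plain C++ with ordinary integers standing in for LLVM's EVT API;
the helper name canNarrow is invented for this example, not part of LLVM). It replays
the arithmetic on the PR39893 types: the binop is v2i32 (64 bits wide, 2 elements) and
the extract wants a v2i8 slice (16 bits), so NarrowingRatio is 4 and 2 % 4 != 0 forces
a bail-out instead of building a zero-element narrow vector type.

    #include <cstdio>

    // Mirrors the bail-out logic of narrowExtractedVectorBinOp after the patch,
    // using plain integers in place of EVT queries.
    static bool canNarrow(unsigned WideWidth, unsigned NarrowWidth,
                          unsigned WideNumElts) {
      // Bail out if this is not a proper multiple width extraction.
      if (WideWidth % NarrowWidth != 0)
        return false;
      // Bail out if we would be extracting a fraction of a single wide element;
      // before the fix this produced a 0-element narrow type and crashed.
      unsigned NarrowingRatio = WideWidth / NarrowWidth;
      return WideNumElts % NarrowingRatio == 0;
    }

    int main() {
      // PR39893: v2i8 extracted through a bitcast of a v2i32 binop -> bail out.
      std::printf("v2i8 from v2i32 binop: %s\n",
                  canNarrow(64, 16, 2) ? "narrow" : "bail out");
      // A case the transform can still handle: a v2i32 half of a v4i32 binop.
      std::printf("v2i32 from v4i32 binop: %s\n",
                  canNarrow(128, 64, 4) ? "narrow" : "bail out");
      return 0;
    }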