Teach the legalizer how to handle operands for VSELECT nodes
authorJustin Holewinski <jholewinski@nvidia.com>
Thu, 29 Nov 2012 14:26:28 +0000 (14:26 +0000)
committerJustin Holewinski <jholewinski@nvidia.com>
Thu, 29 Nov 2012 14:26:28 +0000 (14:26 +0000)
If we need to split the operand of a VSELECT, it must be the mask operand. We
split the entire VSELECT operand with EXTRACT_SUBVECTOR.

llvm-svn: 168883

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
llvm/test/CodeGen/NVPTX/vector-select.ll [new file with mode: 0644]

index 20b7ce6..8464b7d 100644 (file)
@@ -578,6 +578,7 @@ private:
 
   // Vector Operand Splitting: <128 x ty> -> 2 x <64 x ty>.
   bool SplitVectorOperand(SDNode *N, unsigned OpNo);
+  SDValue SplitVecOp_VSELECT(SDNode *N, unsigned OpNo);
   SDValue SplitVecOp_UnaryOp(SDNode *N);
 
   SDValue SplitVecOp_BITCAST(SDNode *N);
index d51a6eb..595d83b 100644 (file)
@@ -1030,7 +1030,9 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
     case ISD::STORE:
       Res = SplitVecOp_STORE(cast<StoreSDNode>(N), OpNo);
       break;
-
+    case ISD::VSELECT:
+      Res = SplitVecOp_VSELECT(N, OpNo);
+      break;
     case ISD::CTTZ:
     case ISD::CTLZ:
     case ISD::CTPOP:
@@ -1064,6 +1066,62 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
   return false;
 }
 
+SDValue DAGTypeLegalizer::SplitVecOp_VSELECT(SDNode *N, unsigned OpNo) {
+  // The only possibility for an illegal operand is the mask, since result type
+  // legalization would have handled this node already otherwise.
+  assert(OpNo == 0 && "Illegal operand must be mask");
+
+  SDValue Mask = N->getOperand(0);
+  SDValue Src0 = N->getOperand(1);
+  SDValue Src1 = N->getOperand(2);
+  DebugLoc DL = N->getDebugLoc();
+  EVT MaskVT = Mask.getValueType();
+  assert(MaskVT.isVector() && "VSELECT without a vector mask?");
+
+  SDValue Lo, Hi;
+  GetSplitVector(N->getOperand(0), Lo, Hi);
+
+  unsigned LoNumElts = Lo.getValueType().getVectorNumElements();
+  unsigned HiNumElts = Hi.getValueType().getVectorNumElements();
+  assert(LoNumElts == HiNumElts && "Asymmetric vector split?");
+
+  EVT LoOpVT = EVT::getVectorVT(*DAG.getContext(),
+                                Src0.getValueType().getVectorElementType(),
+                                LoNumElts);
+  EVT LoMaskVT = EVT::getVectorVT(*DAG.getContext(),
+                                  MaskVT.getVectorElementType(),
+                                  LoNumElts);
+  EVT HiOpVT = EVT::getVectorVT(*DAG.getContext(),
+                                Src0.getValueType().getVectorElementType(),
+                                HiNumElts);
+  EVT HiMaskVT = EVT::getVectorVT(*DAG.getContext(),
+                                  MaskVT.getVectorElementType(),
+                                  HiNumElts);
+
+  SDValue LoOp0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, LoOpVT, Src0,
+                              DAG.getIntPtrConstant(0));
+  SDValue LoOp1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, LoOpVT, Src1,
+                              DAG.getIntPtrConstant(0));
+
+  SDValue HiOp0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HiOpVT, Src0,
+                              DAG.getIntPtrConstant(LoNumElts));
+  SDValue HiOp1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HiOpVT, Src1,
+                              DAG.getIntPtrConstant(LoNumElts));
+
+  SDValue LoMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, LoMaskVT, Mask,
+                               DAG.getIntPtrConstant(0));
+  SDValue HiMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HiMaskVT, Mask,
+                               DAG.getIntPtrConstant(LoNumElts));
+
+  SDValue LoSelect = DAG.getNode(ISD::VSELECT, DL, LoOpVT, LoMask, LoOp0,
+                                 LoOp1);
+  SDValue HiSelect = DAG.getNode(ISD::VSELECT, DL, HiOpVT, HiMask, HiOp0,
+                                 HiOp1);
+
+  return DAG.getNode(ISD::CONCAT_VECTORS, DL, Src0.getValueType(), LoSelect,
+                     HiSelect);
+}
+
 SDValue DAGTypeLegalizer::SplitVecOp_UnaryOp(SDNode *N) {
   // The result has a legal vector type, but the input needs splitting.
   EVT ResVT = N->getValueType(0);
diff --git a/llvm/test/CodeGen/NVPTX/vector-select.ll b/llvm/test/CodeGen/NVPTX/vector-select.ll
new file mode 100644 (file)
index 0000000..11893df
--- /dev/null
@@ -0,0 +1,16 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_20
+
+; This test makes sure that vector selects are scalarized by the type legalizer.
+; If not, type legalization will fail.
+
+define void @foo(<2 x i32> addrspace(1)* %def_a, <2 x i32> addrspace(1)* %def_b, <2 x i32> addrspace(1)* %def_c) {
+entry:
+  %tmp4 = load <2 x i32> addrspace(1)* %def_a
+  %tmp6 = load <2 x i32> addrspace(1)* %def_c
+  %tmp8 = load <2 x i32> addrspace(1)* %def_b
+  %0 = icmp sge <2 x i32> %tmp4, zeroinitializer
+  %cond = select <2 x i1> %0, <2 x i32> %tmp6, <2 x i32> %tmp8
+  store <2 x i32> %cond, <2 x i32> addrspace(1)* %def_c
+  ret void
+}