Masked gather and scatter: Added code for SelectionDAG.
authorElena Demikhovsky <elena.demikhovsky@intel.com>
Tue, 28 Apr 2015 07:57:37 +0000 (07:57 +0000)
committerElena Demikhovsky <elena.demikhovsky@intel.com>
Tue, 28 Apr 2015 07:57:37 +0000 (07:57 +0000)
All other patches, including tests will follow.

http://reviews.llvm.org/D7665

llvm-svn: 235970

llvm/include/llvm/CodeGen/ISDOpcodes.h
llvm/include/llvm/CodeGen/SelectionDAG.h
llvm/include/llvm/CodeGen/SelectionDAGNodes.h
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h

index 2d1c8cd..9b7b44b 100644 (file)
@@ -687,9 +687,16 @@ namespace ISD {
     ATOMIC_LOAD_UMIN,
     ATOMIC_LOAD_UMAX,
 
-    // Masked load and store
+    // Masked load and store - consecutive vector load and store operations
+    // with additional mask operand that prevents memory accesses to the
+    // masked-off lanes.
     MLOAD, MSTORE,
 
+    // Masked gather and scatter - load and store operations for a vector of
+    // random addresses with additional mask operand that prevents memory
+    // accesses to the masked-off lanes.
+    MGATHER, MSCATTER,
+
     /// This corresponds to the llvm.lifetime.* intrinsics. The first operand
     /// is the chain and the second operand is the alloca pointer.
     LIFETIME_START, LIFETIME_END,
index 582febd..874842b 100644 (file)
@@ -856,6 +856,10 @@ public:
   SDValue getMaskedStore(SDValue Chain, SDLoc dl, SDValue Val,
                          SDValue Ptr, SDValue Mask, EVT MemVT,
                          MachineMemOperand *MMO, bool IsTrunc);
+  SDValue getMaskedGather(SDVTList VTs, EVT VT, SDLoc dl,
+                          ArrayRef<SDValue> Ops, MachineMemOperand *MMO);
+  SDValue getMaskedScatter(SDVTList VTs, EVT VT, SDLoc dl,
+                           ArrayRef<SDValue> Ops, MachineMemOperand *MMO);
   /// Construct a node to track a Value* through the backend.
   SDValue getSrcValue(const Value *v);
 
index 66f060c..148eaa9 100644 (file)
@@ -1151,6 +1151,8 @@ public:
            N->getOpcode() == ISD::ATOMIC_STORE        ||
            N->getOpcode() == ISD::MLOAD               ||
            N->getOpcode() == ISD::MSTORE              ||
+           N->getOpcode() == ISD::MGATHER             ||
+           N->getOpcode() == ISD::MSCATTER            ||
            N->isMemIntrinsic()                        ||
            N->isTargetMemoryOpcode();
   }
@@ -1987,6 +1989,82 @@ public:
   }
 };
 
+/// This is a base class is used to represent
+/// MGATHER and MSCATTER nodes
+///
+class MaskedGatherScatterSDNode : public MemSDNode {
+  // Operands
+  SDUse Ops[5];
+public:
+  friend class SelectionDAG;
+  MaskedGatherScatterSDNode(ISD::NodeType NodeTy, unsigned Order, DebugLoc dl,
+                            ArrayRef<SDValue> Operands, SDVTList VTs, EVT MemVT,
+                            MachineMemOperand *MMO)
+    : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {
+    assert(Operands.size() == 5 && "Incompatible number of operands");
+    InitOperands(Ops, Operands.data(), Operands.size());
+  }
+
+  // In the both nodes address is Op1, mask is Op2:
+  // MaskedGatherSDNode  (Chain, src0, mask, base, index), src0 is a passthru value
+  // MaskedScatterSDNode (Chain, value, mask, base, index)
+  // Mask is a vector of i1 elements
+  const SDValue &getBasePtr() const { return getOperand(3); }
+  const SDValue &getIndex()   const { return getOperand(4); }
+  const SDValue &getMask()    const { return getOperand(2); }
+  const SDValue &getValue()   const { return getOperand(1); }
+
+  static bool classof(const SDNode *N) {
+    return N->getOpcode() == ISD::MGATHER ||
+           N->getOpcode() == ISD::MSCATTER;
+  }
+};
+
+/// This class is used to represent an MGATHER node
+///
+class MaskedGatherSDNode : public MaskedGatherScatterSDNode {
+public:
+  friend class SelectionDAG;
+  MaskedGatherSDNode(unsigned Order, DebugLoc dl, ArrayRef<SDValue> Operands, 
+                     SDVTList VTs, EVT MemVT, MachineMemOperand *MMO)
+    : MaskedGatherScatterSDNode(ISD::MGATHER, Order, dl, Operands, VTs, MemVT,
+                                MMO) {
+    assert(getValue().getValueType() == getValueType(0) &&
+           "Incompatible type of the PathThru value in MaskedGatherSDNode");
+    assert(getMask().getValueType().getVectorNumElements() == 
+           getValueType(0).getVectorNumElements() && 
+           "Vector width mismatch between mask and data");
+    assert(getMask().getValueType().getScalarType() == MVT::i1 && 
+           "Vector width mismatch between mask and data");
+  }
+
+  static bool classof(const SDNode *N) {
+    return N->getOpcode() == ISD::MGATHER;
+  }
+};
+
+/// This class is used to represent an MSCATTER node
+///
+class MaskedScatterSDNode : public MaskedGatherScatterSDNode {
+
+public:
+  friend class SelectionDAG;
+  MaskedScatterSDNode(unsigned Order, DebugLoc dl,ArrayRef<SDValue> Operands,
+                      SDVTList VTs, EVT MemVT, MachineMemOperand *MMO)
+    : MaskedGatherScatterSDNode(ISD::MSCATTER, Order, dl, Operands, VTs, MemVT,
+                                MMO) {
+    assert(getMask().getValueType().getVectorNumElements() == 
+           getValue().getValueType().getVectorNumElements() && 
+           "Vector width mismatch between mask and data");
+    assert(getMask().getValueType().getScalarType() == MVT::i1 && 
+           "Vector width mismatch between mask and data");
+  }
+
+  static bool classof(const SDNode *N) {
+    return N->getOpcode() == ISD::MSCATTER;
+  }
+};
+
 /// An SDNode that represents everything that will be needed
 /// to construct a MachineInstr. These nodes are created during the
 /// instruction selection proper phase.
@@ -2078,7 +2156,7 @@ template <> struct GraphTraits<SDNode*> {
 };
 
 /// The largest SDNode class.
-typedef AtomicSDNode LargestSDNode;
+typedef MaskedGatherScatterSDNode LargestSDNode;
 
 /// The SDNode class with the greatest alignment requirement.
 typedef GlobalAddressSDNode MostAlignedSDNode;
index 5b76698..3900176 100644 (file)
@@ -5097,6 +5097,55 @@ SDValue SelectionDAG::getMaskedStore(SDValue Chain, SDLoc dl, SDValue Val,
   return SDValue(N, 0);
 }
 
+SDValue
+SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, SDLoc dl,
+                              ArrayRef<SDValue> Ops,
+                              MachineMemOperand *MMO) {
+
+  FoldingSetNodeID ID;
+  AddNodeIDNode(ID, ISD::MGATHER, VTs, Ops);
+  ID.AddInteger(VT.getRawBits());
+  ID.AddInteger(encodeMemSDNodeFlags(ISD::NON_EXTLOAD, ISD::UNINDEXED,
+                                     MMO->isVolatile(),
+                                     MMO->isNonTemporal(),
+                                     MMO->isInvariant()));
+  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  void *IP = nullptr;
+  if (SDNode *E = CSEMap.FindNodeOrInsertPos(ID, IP)) {
+    cast<MaskedGatherSDNode>(E)->refineAlignment(MMO);
+    return SDValue(E, 0);
+  }
+  MaskedGatherSDNode *N = 
+    new (NodeAllocator) MaskedGatherSDNode(dl.getIROrder(), dl.getDebugLoc(),
+                                           Ops, VTs, VT, MMO);
+  CSEMap.InsertNode(N, IP);
+  InsertNode(N);
+  return SDValue(N, 0);
+}
+
+SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, SDLoc dl,
+                                       ArrayRef<SDValue> Ops,
+                                       MachineMemOperand *MMO) {
+  FoldingSetNodeID ID;
+  AddNodeIDNode(ID, ISD::MSCATTER, VTs, Ops);
+  ID.AddInteger(VT.getRawBits());
+  ID.AddInteger(encodeMemSDNodeFlags(false, ISD::UNINDEXED, MMO->isVolatile(),
+                                     MMO->isNonTemporal(),
+                                     MMO->isInvariant()));
+  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  void *IP = nullptr;
+  if (SDNode *E = CSEMap.FindNodeOrInsertPos(ID, IP)) {
+    cast<MaskedScatterSDNode>(E)->refineAlignment(MMO);
+    return SDValue(E, 0);
+  }
+  SDNode *N =
+    new (NodeAllocator) MaskedScatterSDNode(dl.getIROrder(), dl.getDebugLoc(),
+                                            Ops, VTs, VT, MMO);
+  CSEMap.InsertNode(N, IP);
+  InsertNode(N);
+  return SDValue(N, 0);
+}
+
 SDValue SelectionDAG::getVAArg(EVT VT, SDLoc dl,
                                SDValue Chain, SDValue Ptr,
                                SDValue SV,
index 71b221a..cc97fbe 100644 (file)
@@ -1059,6 +1059,12 @@ SDValue SelectionDAGBuilder::getValue(const Value *V) {
   return Val;
 }
 
+// Return true if SDValue exists for the given Value
+bool SelectionDAGBuilder::findValue(const Value *V) const {
+  return (NodeMap.find(V) != NodeMap.end()) ||
+    (FuncInfo.ValueMap.find(V) != FuncInfo.ValueMap.end());
+}
+
 /// getNonRegisterValue - Return an SDValue for the given Value, but
 /// don't look in FuncInfo.ValueMap for a virtual register.
 SDValue SelectionDAGBuilder::getNonRegisterValue(const Value *V) {
@@ -3026,6 +3032,92 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I) {
   setValue(&I, StoreNode);
 }
 
+// Gather/scatter receive a vector of pointers.
+// This vector of pointers may be represented as a base pointer + vector of 
+// indices, it depends on GEP and instruction preceeding GEP
+// that calculates indices
+static bool getUniformBase(Value *& Ptr, SDValue& Base, SDValue& Index,
+                           SelectionDAGBuilder* SDB) {
+
+  assert (Ptr->getType()->isVectorTy() && "Uexpected pointer type");
+  GetElementPtrInst *Gep = dyn_cast<GetElementPtrInst>(Ptr);
+  if (!Gep || Gep->getNumOperands() > 2)
+    return false;
+  ShuffleVectorInst *ShuffleInst = 
+    dyn_cast<ShuffleVectorInst>(Gep->getPointerOperand());
+  if (!ShuffleInst || !ShuffleInst->getMask()->isNullValue() ||
+      cast<Instruction>(ShuffleInst->getOperand(0))->getOpcode() !=
+      Instruction::InsertElement)
+    return false;
+
+  Ptr = cast<InsertElementInst>(ShuffleInst->getOperand(0))->getOperand(1);
+
+  SelectionDAG& DAG = SDB->DAG;
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  // Check is the Ptr is inside current basic block
+  // If not, look for the shuffle instruction
+  if (SDB->findValue(Ptr))
+    Base = SDB->getValue(Ptr);
+  else if (SDB->findValue(ShuffleInst)) {
+    SDValue ShuffleNode = SDB->getValue(ShuffleInst);
+    Base = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(ShuffleNode),
+                       ShuffleNode.getValueType().getScalarType(), ShuffleNode,
+                       DAG.getConstant(0, TLI.getVectorIdxTy()));
+    SDB->setValue(Ptr, Base);
+  }
+  else
+    return false;
+
+  Value *IndexVal = Gep->getOperand(1);
+  if (SDB->findValue(IndexVal)) {
+    Index = SDB->getValue(IndexVal);
+
+    if (SExtInst* Sext = dyn_cast<SExtInst>(IndexVal)) {
+      IndexVal = Sext->getOperand(0);
+      if (SDB->findValue(IndexVal))
+        Index = SDB->getValue(IndexVal);
+    }
+    return true;
+  }
+  return false;
+}
+
+void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
+  SDLoc sdl = getCurSDLoc();
+
+  // llvm.masked.scatter.*(Src0, Ptrs, alignemt, Mask)
+  Value  *Ptr = I.getArgOperand(1);
+  SDValue Src0 = getValue(I.getArgOperand(0));
+  SDValue Mask = getValue(I.getArgOperand(3));
+  EVT VT = Src0.getValueType();
+  unsigned Alignment = (cast<ConstantInt>(I.getArgOperand(2)))->getZExtValue();
+  if (!Alignment)
+    Alignment = DAG.getEVTAlignment(VT);
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+  AAMDNodes AAInfo;
+  I.getAAMetadata(AAInfo);
+
+  SDValue Base;
+  SDValue Index;
+  Value *BasePtr = Ptr;
+  bool UniformBase = getUniformBase(BasePtr, Base, Index, this);
+
+  Value *MemOpBasePtr = UniformBase ? BasePtr : NULL;
+  MachineMemOperand *MMO = DAG.getMachineFunction().
+    getMachineMemOperand(MachinePointerInfo(MemOpBasePtr),
+                         MachineMemOperand::MOStore,  VT.getStoreSize(),
+                         Alignment, AAInfo);
+  if (!UniformBase) {
+    Base = DAG.getTargetConstant(0, TLI.getPointerTy());
+    Index = getValue(Ptr);
+  }
+  SDValue Ops[] = { getRoot(), Src0, Mask, Base, Index };
+  SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl, Ops, MMO);
+  DAG.setRoot(Scatter);
+  setValue(&I, Scatter);
+}
+
 void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I) {
   SDLoc sdl = getCurSDLoc();
 
@@ -3067,6 +3159,60 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I) {
   setValue(&I, Load);
 }
 
+void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
+  SDLoc sdl = getCurSDLoc();
+
+  // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
+  Value  *Ptr = I.getArgOperand(0);
+  SDValue Src0 = getValue(I.getArgOperand(3));
+  SDValue Mask = getValue(I.getArgOperand(2));
+
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  EVT VT = TLI.getValueType(I.getType());
+  unsigned Alignment = (cast<ConstantInt>(I.getArgOperand(1)))->getZExtValue();
+  if (!Alignment)
+    Alignment = DAG.getEVTAlignment(VT);
+
+  AAMDNodes AAInfo;
+  I.getAAMetadata(AAInfo);
+  const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range);
+
+  SDValue Root = DAG.getRoot();
+  SDValue Base;
+  SDValue Index;
+  Value *BasePtr = Ptr;
+  bool UniformBase = getUniformBase(BasePtr, Base, Index, this);
+  bool ConstantMemory = false;
+  if (UniformBase && AA->pointsToConstantMemory(
+      AliasAnalysis::Location(BasePtr,
+                                   AA->getTypeStoreSize(I.getType()),
+                              AAInfo))) {
+    // Do not serialize (non-volatile) loads of constant memory with anything.
+    Root = DAG.getEntryNode();
+    ConstantMemory = true;
+  }
+
+  MachineMemOperand *MMO =
+    DAG.getMachineFunction().
+    getMachineMemOperand(MachinePointerInfo(UniformBase ? BasePtr : NULL),
+                          MachineMemOperand::MOLoad,  VT.getStoreSize(),
+                          Alignment, AAInfo, Ranges);
+
+  if (!UniformBase) {
+    Base = DAG.getTargetConstant(0, TLI.getPointerTy());
+    Index = getValue(Ptr);
+  }
+
+  SDValue Ops[] = { Root, Src0, Mask, Base, Index };
+  SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl,
+                                       Ops, MMO);
+
+  SDValue OutChain = Gather.getValue(1);
+  if (!ConstantMemory)
+    PendingLoads.push_back(OutChain);
+  setValue(&I, Gather);
+}
+
 void SelectionDAGBuilder::visitAtomicCmpXchg(const AtomicCmpXchgInst &I) {
   SDLoc dl = getCurSDLoc();
   AtomicOrdering SuccessOrder = I.getSuccessOrdering();
@@ -4216,9 +4362,13 @@ SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, unsigned Intrinsic) {
     return nullptr;
   }
 
+  case Intrinsic::masked_gather:
+    visitMaskedGather(I);
   case Intrinsic::masked_load:
     visitMaskedLoad(I);
     return nullptr;
+  case Intrinsic::masked_scatter:
+    visitMaskedScatter(I);
   case Intrinsic::masked_store:
     visitMaskedStore(I);
     return nullptr;
index ba6d974..3e232e7 100644 (file)
@@ -667,6 +667,8 @@ public:
   // generate the debug data structures now that we've seen its definition.
   void resolveDanglingDebugInfo(const Value *V, SDValue Val);
   SDValue getValue(const Value *V);
+  bool findValue(const Value *V) const;
+
   SDValue getNonRegisterValue(const Value *V);
   SDValue getValueImpl(const Value *V);
 
@@ -814,6 +816,8 @@ private:
   void visitStore(const StoreInst &I);
   void visitMaskedLoad(const CallInst &I);
   void visitMaskedStore(const CallInst &I);
+  void visitMaskedGather(const CallInst &I);
+  void visitMaskedScatter(const CallInst &I);
   void visitAtomicCmpXchg(const AtomicCmpXchgInst &I);
   void visitAtomicRMW(const AtomicRMWInst &I);
   void visitFence(const FenceInst &I);