Enable llvm's isa/cast/dyn_cast on MemAccInst.
authorHongbin Zheng <etherzhhb@gmail.com>
Sat, 27 Feb 2016 01:49:58 +0000 (01:49 +0000)
committerHongbin Zheng <etherzhhb@gmail.com>
Sat, 27 Feb 2016 01:49:58 +0000 (01:49 +0000)
Differential Revision: http://reviews.llvm.org/D17250

llvm-svn: 262100

polly/include/polly/Support/ScopHelper.h
polly/lib/Analysis/ScopDetection.cpp
polly/lib/Analysis/ScopInfo.cpp

index aa1a662..165f1f9 100644 (file)
@@ -242,6 +242,10 @@ public:
 
   bool isNull() const { return !I; }
   bool isInstruction() const { return I; }
+
+  llvm::Instruction *asInstruction() const { return I; }
+
+private:
   bool isLoad() const { return I && llvm::isa<llvm::LoadInst>(I); }
   bool isStore() const { return I && llvm::isa<llvm::StoreInst>(I); }
   bool isCallInst() const { return I && llvm::isa<llvm::CallInst>(I); }
@@ -251,7 +255,6 @@ public:
     return I && llvm::isa<llvm::MemTransferInst>(I);
   }
 
-  llvm::Instruction *asInstruction() const { return I; }
   llvm::LoadInst *asLoad() const { return llvm::cast<llvm::LoadInst>(I); }
   llvm::StoreInst *asStore() const { return llvm::cast<llvm::StoreInst>(I); }
   llvm::CallInst *asCallInst() const { return llvm::cast<llvm::CallInst>(I); }
@@ -265,6 +268,20 @@ public:
     return llvm::cast<llvm::MemTransferInst>(I);
   }
 };
+}
+
+namespace llvm {
+/// @brief Specialize simplify_type for MemAccInst to enable dyn_cast and cast
+///        from a MemAccInst object.
+template <> struct simplify_type<polly::MemAccInst> {
+  typedef Instruction *SimpleType;
+  static SimpleType getSimplifiedValue(polly::MemAccInst &I) {
+    return I.asInstruction();
+  }
+};
+}
+
+namespace polly {
 
 /// @brief Check if the PHINode has any incoming Invoke edge.
 ///
index 7cfb03c..f621b2d 100644 (file)
@@ -989,8 +989,8 @@ bool ScopDetection::isValidInstruction(Instruction &Inst,
 
   // Check the access function.
   if (auto MemInst = MemAccInst::dyn_cast(Inst)) {
-    Context.hasStores |= MemInst.isStore();
-    Context.hasLoads |= MemInst.isLoad();
+    Context.hasStores |= isa<StoreInst>(MemInst);
+    Context.hasLoads |= isa<LoadInst>(MemInst);
     if (!MemInst.isSimple())
       return invalid<ReportNonSimpleMemoryAccess>(Context, /*Assert=*/true,
                                                   &Inst);
index 71f1b62..fba2ab9 100644 (file)
@@ -616,9 +616,7 @@ void MemoryAccess::assumeNoOutOfBound() {
 }
 
 void MemoryAccess::buildMemIntrinsicAccessRelation() {
-  auto MAI = MemAccInst(getAccessInstruction());
-  (void)MAI;
-  assert(MAI.isMemIntrinsic());
+  assert(isa<MemIntrinsic>(getAccessInstruction()));
   assert(Subscripts.size() == 2 && Sizes.size() == 0);
 
   auto *SubscriptPWA = Statement->getPwAff(Subscripts[0]);
@@ -646,7 +644,7 @@ void MemoryAccess::computeBoundsOnAccessRelation(unsigned ElementSize) {
   ScalarEvolution *SE = Statement->getParent()->getSE();
 
   auto MAI = MemAccInst(getAccessInstruction());
-  if (MAI.isMemIntrinsic())
+  if (isa<MemIntrinsic>(MAI))
     return;
 
   Value *Ptr = MAI.getPointerOperand();
@@ -2613,8 +2611,8 @@ bool Scop::buildAliasGroups(AliasAnalysis &AA) {
       if (!MA->isRead())
         HasWriteAccess.insert(MA->getBaseAddr());
       MemAccInst Acc(MA->getAccessInstruction());
-      if (MA->isRead() && Acc.isMemTransferInst())
-        PtrToAcc[Acc.asMemTransferInst()->getSource()] = MA;
+      if (MA->isRead() && isa<MemTransferInst>(Acc))
+        PtrToAcc[cast<MemTransferInst>(Acc)->getSource()] = MA;
       else
         PtrToAcc[Acc.getPointerOperand()] = MA;
       AST.add(Acc);
@@ -3850,7 +3848,7 @@ bool ScopInfo::buildAccessMultiDimFixed(
   const SCEVUnknown *BasePointer =
       dyn_cast<SCEVUnknown>(SE->getPointerBase(AccessFunction));
   enum MemoryAccess::AccessType Type =
-      Inst.isLoad() ? MemoryAccess::READ : MemoryAccess::MUST_WRITE;
+      isa<LoadInst>(Inst) ? MemoryAccess::READ : MemoryAccess::MUST_WRITE;
 
   if (auto *BitCast = dyn_cast<BitCastInst>(Address)) {
     auto *Src = BitCast->getOperand(0);
@@ -3905,7 +3903,7 @@ bool ScopInfo::buildAccessMultiDimParam(
   Type *ElementType = Val->getType();
   unsigned ElementSize = DL->getTypeAllocSize(ElementType);
   enum MemoryAccess::AccessType Type =
-      Inst.isLoad() ? MemoryAccess::READ : MemoryAccess::MUST_WRITE;
+      isa<LoadInst>(Inst) ? MemoryAccess::READ : MemoryAccess::MUST_WRITE;
 
   const SCEV *AccessFunction = SE->getSCEVAtScope(Address, L);
   const SCEVUnknown *BasePointer =
@@ -3942,10 +3940,12 @@ bool ScopInfo::buildAccessMemIntrinsic(
     MemAccInst Inst, Loop *L, Region *R,
     const ScopDetection::BoxedLoopsSetTy *BoxedLoops,
     const InvariantLoadsSetTy &ScopRIL) {
-  if (!Inst.isMemIntrinsic())
+  auto *MemIntr = dyn_cast_or_null<MemIntrinsic>(Inst);
+
+  if (MemIntr == nullptr)
     return false;
 
-  auto *LengthVal = SE->getSCEVAtScope(Inst.asMemIntrinsic()->getLength(), L);
+  auto *LengthVal = SE->getSCEVAtScope(MemIntr->getLength(), L);
   assert(LengthVal);
 
   // Check if the length val is actually affine or if we overapproximate it
@@ -3957,7 +3957,7 @@ bool ScopInfo::buildAccessMemIntrinsic(
   if (!LengthIsAffine)
     LengthVal = nullptr;
 
-  auto *DestPtrVal = Inst.asMemIntrinsic()->getDest();
+  auto *DestPtrVal = MemIntr->getDest();
   assert(DestPtrVal);
   auto *DestAccFunc = SE->getSCEVAtScope(DestPtrVal, L);
   assert(DestAccFunc);
@@ -3968,10 +3968,11 @@ bool ScopInfo::buildAccessMemIntrinsic(
                  IntegerType::getInt8Ty(DestPtrVal->getContext()), false,
                  {DestAccFunc, LengthVal}, {}, Inst.getValueOperand());
 
-  if (!Inst.isMemTransferInst())
+  auto *MemTrans = dyn_cast<MemTransferInst>(MemIntr);
+  if (!MemTrans)
     return true;
 
-  auto *SrcPtrVal = Inst.asMemTransferInst()->getSource();
+  auto *SrcPtrVal = MemTrans->getSource();
   assert(SrcPtrVal);
   auto *SrcAccFunc = SE->getSCEVAtScope(SrcPtrVal, L);
   assert(SrcAccFunc);
@@ -3989,30 +3990,31 @@ bool ScopInfo::buildAccessCallInst(
     MemAccInst Inst, Loop *L, Region *R,
     const ScopDetection::BoxedLoopsSetTy *BoxedLoops,
     const InvariantLoadsSetTy &ScopRIL) {
-  if (!Inst.isCallInst())
+  auto *CI = dyn_cast_or_null<CallInst>(Inst);
+
+  if (CI == nullptr)
     return false;
 
-  auto &CI = *Inst.asCallInst();
-  if (CI.doesNotAccessMemory() || isIgnoredIntrinsic(&CI))
+  if (CI->doesNotAccessMemory() || isIgnoredIntrinsic(CI))
     return true;
 
   bool ReadOnly = false;
-  auto *AF = SE->getConstant(IntegerType::getInt64Ty(CI.getContext()), 0);
-  auto *CalledFunction = CI.getCalledFunction();
+  auto *AF = SE->getConstant(IntegerType::getInt64Ty(CI->getContext()), 0);
+  auto *CalledFunction = CI->getCalledFunction();
   switch (AA->getModRefBehavior(CalledFunction)) {
   case llvm::FMRB_UnknownModRefBehavior:
     llvm_unreachable("Unknown mod ref behaviour cannot be represented.");
   case llvm::FMRB_DoesNotAccessMemory:
     return true;
   case llvm::FMRB_OnlyReadsMemory:
-    GlobalReads.push_back(&CI);
+    GlobalReads.push_back(CI);
     return true;
   case llvm::FMRB_OnlyReadsArgumentPointees:
     ReadOnly = true;
   // Fall through
   case llvm::FMRB_OnlyAccessesArgumentPointees:
     auto AccType = ReadOnly ? MemoryAccess::READ : MemoryAccess::MAY_WRITE;
-    for (const auto &Arg : CI.arg_operands()) {
+    for (const auto &Arg : CI->arg_operands()) {
       if (!Arg->getType()->isPointerTy())
         continue;
 
@@ -4022,7 +4024,7 @@ bool ScopInfo::buildAccessCallInst(
 
       auto *ArgBasePtr = cast<SCEVUnknown>(SE->getPointerBase(ArgSCEV));
       addArrayAccess(Inst, AccType, ArgBasePtr->getValue(),
-                     ArgBasePtr->getType(), false, {AF}, {}, &CI);
+                     ArgBasePtr->getType(), false, {AF}, {}, CI);
     }
     return true;
   }
@@ -4038,7 +4040,7 @@ void ScopInfo::buildAccessSingleDim(
   Value *Val = Inst.getValueOperand();
   Type *ElementType = Val->getType();
   enum MemoryAccess::AccessType Type =
-      Inst.isLoad() ? MemoryAccess::READ : MemoryAccess::MUST_WRITE;
+      isa<LoadInst>(Inst) ? MemoryAccess::READ : MemoryAccess::MUST_WRITE;
 
   const SCEV *AccessFunction = SE->getSCEVAtScope(Address, L);
   const SCEVUnknown *BasePointer =