[Acc] Do not statically dispatch into IslNodeBuilder's createFor
authorPhilip Pfaffe <philip.pfaffe@gmail.com>
Sun, 29 Oct 2017 21:36:34 +0000 (21:36 +0000)
committerPhilip Pfaffe <philip.pfaffe@gmail.com>
Sun, 29 Oct 2017 21:36:34 +0000 (21:36 +0000)
Summary:
When GPUNodeBuilder creates loops inside the kernel, it dispatches to
IslNodeBuilder. This however is surprisingly dangerous, since it accesses the
AST Node's user through the wrong type. This patch fixes this problem by
overriding createFor correctly.

This fixes PR35010.

Reviewers: grosser, bollu, Meinersbur

Reviewed By: Meinersbur

Subscribers: Meinersbur, nemanjai, pollydev, llvm-commits, kbarton

Differential Revision: https://reviews.llvm.org/D39364

llvm-svn: 316872

polly/include/polly/CodeGen/IslNodeBuilder.h
polly/lib/CodeGen/IslNodeBuilder.cpp
polly/lib/CodeGen/PPCGCodeGeneration.cpp

index 7b3ce6cdeb8370ba22980ba049aba973ab4b09a3..cae93eef6a7b8f9932c1ae74b14b6a5c2befa90e 100644 (file)
@@ -351,7 +351,7 @@ protected:
   bool preloadInvariantEquivClass(InvariantEquivClassTy &IAClass);
 
   void createForVector(__isl_take isl_ast_node *For, int VectorWidth);
-  void createForSequential(__isl_take isl_ast_node *For, bool KnownParallel);
+  void createForSequential(__isl_take isl_ast_node *For, bool MarkParallel);
 
   /// Create LLVM-IR that executes a for node thread parallel.
   ///
index 35210b714be278ef90b6ec3ee5c4023d4a5b7bc4..d6dafcbcb4a2dc024d38fd2d7ae1d4edc45b449b 100644 (file)
@@ -530,7 +530,7 @@ static bool IsLoopVectorizerDisabled(isl::ast_node Node) {
 }
 
 void IslNodeBuilder::createForSequential(__isl_take isl_ast_node *For,
-                                         bool KnownParallel) {
+                                         bool MarkParallel) {
   isl_ast_node *Body;
   isl_ast_expr *Init, *Inc, *Iterator, *UB;
   isl_id *IteratorID;
@@ -539,10 +539,6 @@ void IslNodeBuilder::createForSequential(__isl_take isl_ast_node *For,
   BasicBlock *ExitBlock;
   Value *IV;
   CmpInst::Predicate Predicate;
-  bool Parallel;
-
-  Parallel = KnownParallel || (IslAstInfo::isParallel(For) &&
-                               !IslAstInfo::isReductionParallel(For));
 
   bool LoopVectorizerDisabled =
       IsLoopVectorizerDisabled(isl::manage(isl_ast_node_copy(For)));
@@ -582,13 +578,13 @@ void IslNodeBuilder::createForSequential(__isl_take isl_ast_node *For,
   bool UseGuardBB =
       !SE.isKnownPredicate(Predicate, SE.getSCEV(ValueLB), SE.getSCEV(ValueUB));
   IV = createLoop(ValueLB, ValueUB, ValueInc, Builder, LI, DT, ExitBlock,
-                  Predicate, &Annotator, Parallel, UseGuardBB,
+                  Predicate, &Annotator, MarkParallel, UseGuardBB,
                   LoopVectorizerDisabled);
   IDToValue[IteratorID] = IV;
 
   create(Body);
 
-  Annotator.popLoop(Parallel);
+  Annotator.popLoop(MarkParallel);
 
   IDToValue.erase(IDToValue.find(IteratorID));
 
@@ -795,7 +791,9 @@ void IslNodeBuilder::createFor(__isl_take isl_ast_node *For) {
     createForParallel(For);
     return;
   }
-  createForSequential(For, false);
+  bool Parallel =
+      (IslAstInfo::isParallel(For) && !IslAstInfo::isReductionParallel(For));
+  createForSequential(For, Parallel);
 }
 
 void IslNodeBuilder::createIf(__isl_take isl_ast_node *If) {
index ed45e0738ffb08fc5d58153c53bf881e429031b9..720e2564d3e9fe0ecdfe438e75248ad31115411b 100644 (file)
@@ -415,6 +415,8 @@ private:
   /// @param UserStmt The ast node to generate code for.
   virtual void createUser(__isl_take isl_ast_node *UserStmt);
 
+  virtual void createFor(__isl_take isl_ast_node *Node);
+
   enum DataDirection { HOST_TO_DEVICE, DEVICE_TO_HOST };
 
   /// Create code for a data transfer statement
@@ -1286,6 +1288,11 @@ void GPUNodeBuilder::createUser(__isl_take isl_ast_node *UserStmt) {
   isl_ast_node_free(UserStmt);
   return;
 }
+
+void GPUNodeBuilder::createFor(__isl_take isl_ast_node *Node) {
+  createForSequential(Node, false);
+}
+
 void GPUNodeBuilder::createKernelCopy(ppcg_kernel_stmt *KernelStmt) {
   isl_ast_expr *LocalIndex = isl_ast_expr_copy(KernelStmt->u.c.local_index);
   LocalIndex = isl_ast_expr_address_of(LocalIndex);