[FuncSpec] Support specialising recursive functions
authorSjoerd Meijer <sjoerd.meijer@arm.com>
Tue, 3 Aug 2021 19:42:09 +0000 (20:42 +0100)
committerSjoerd Meijer <sjoerd.meijer@arm.com>
Wed, 4 Aug 2021 07:07:04 +0000 (08:07 +0100)
This adds support for specialising recursive functions. For example:

    int Global = 1;
    void recursiveFunc(int *arg) {
      if (*arg < 4) {
        print(*arg);
        recursiveFunc(*arg + 1);
      }
    }
    void main() {
      recursiveFunc(&Global);
    }

After 3 iterations of function specialisation, followed by inlining of the
specialised versions of recursiveFunc, the main function looks like this:

    void main() {
      print(1);
      print(2);
      print(3);
    }

To support this, the following has been added:
- Update the solver and state of the new specialised functions,
- An optimisation to propagate constant stack values after each iteration of
  function specialisation, which is necessary for the next iteration to
  recognise the constant values and trigger.

Specialising recursive functions is (at the moment) controlled by option
-func-specialization-max-iters and is opt-in for compile-time reasons. I.e.,
the default is -func-specialization-max-iters=1, but for the example above we
would need to use -func-specialization-max-iters=3. Future work is to see if we
can increase the default, or improve the cost-model/heuristics to control
compile-times.

Differential Revision: https://reviews.llvm.org/D106426

llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll
llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive2.ll [new file with mode: 0644]
llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive3.ll [new file with mode: 0644]
llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive4.ll [new file with mode: 0644]

index f61f431..87fe6cd 100644 (file)
@@ -11,7 +11,6 @@
 // are propagated to the callee by specializing the function.
 //
 // Current limitations:
-// - It does not handle specialization of recursive functions,
 // - It does not yet handle integer ranges.
 // - Only 1 argument per function is specialised,
 // - The cost-model could be further looked into,
@@ -68,9 +67,142 @@ static cl::opt<bool> EnableSpecializationForLiteralConstant(
     "function-specialization-for-literal-constant", cl::init(false), cl::Hidden,
     cl::desc("Make function specialization available for literal constant."));
 
+// Helper to check if \p LV is either a constant or a constant
+// range with a single element. This should cover exactly the same cases as the
+// old ValueLatticeElement::isConstant() and is intended to be used in the
+// transition to ValueLatticeElement.
+static bool isConstant(const ValueLatticeElement &LV) {
+  return LV.isConstant() ||
+         (LV.isConstantRange() && LV.getConstantRange().isSingleElement());
+}
+
 // Helper to check if \p LV is either overdefined or a constant int.
 static bool isOverdefined(const ValueLatticeElement &LV) {
-  return !LV.isUnknownOrUndef() && !LV.isConstant();
+  return !LV.isUnknownOrUndef() && !isConstant(LV);
+}
+
+static Constant *getPromotableAlloca(AllocaInst *Alloca, CallInst *Call) {
+  Value *StoreValue = nullptr;
+  for (auto *User : Alloca->users()) {
+    // We can't use llvm::isAllocaPromotable() as that would fail because of
+    // the usage in the CallInst, which is what we check here.
+    if (User == Call)
+      continue;
+    if (auto *Bitcast = dyn_cast<BitCastInst>(User)) {
+      if (!Bitcast->hasOneUse() || *Bitcast->user_begin() != Call)
+        return nullptr;
+      continue;
+    }
+
+    if (auto *Store = dyn_cast<StoreInst>(User)) {
+      // This is a duplicate store, bail out.
+      if (StoreValue || Store->isVolatile())
+        return nullptr;
+      StoreValue = Store->getValueOperand();
+      continue;
+    }
+    // Bail if there is any other unknown usage.
+    return nullptr;
+  }
+  return dyn_cast_or_null<Constant>(StoreValue);
+}
+
+// A constant stack value is an AllocaInst that has a single constant
+// value stored to it. Return this constant if such an alloca stack value
+// is a function argument.
+static Constant *getConstantStackValue(CallInst *Call, Value *Val,
+                                       SCCPSolver &Solver) {
+  if (!Val)
+    return nullptr;
+  Val = Val->stripPointerCasts();
+  if (auto *ConstVal = dyn_cast<ConstantInt>(Val))
+    return ConstVal;
+  auto *Alloca = dyn_cast<AllocaInst>(Val);
+  if (!Alloca || !Alloca->getAllocatedType()->isIntegerTy())
+    return nullptr;
+  return getPromotableAlloca(Alloca, Call);
+}
+
+// To support specializing recursive functions, it is important to propagate
+// constant arguments because after a first iteration of specialisation, a
+// reduced example may look like this:
+//
+//     define internal void @RecursiveFn(i32* arg1) {
+//       %temp = alloca i32, align 4
+//       store i32 2 i32* %temp, align 4
+//       call void @RecursiveFn.1(i32* nonnull %temp)
+//       ret void
+//     }
+//
+// Before a next iteration, we need to propagate the constant like so
+// which allows further specialization in next iterations.
+//
+//     @funcspec.arg = internal constant i32 2
+//
+//     define internal void @someFunc(i32* arg1) {
+//       call void @otherFunc(i32* nonnull @funcspec.arg)
+//       ret void
+//     }
+//
+static void constantArgPropagation(SmallVectorImpl<Function *> &WorkList,
+                                   Module &M, SCCPSolver &Solver) {
+  // Iterate over the argument tracked functions see if there
+  // are any new constant values for the call instruction via
+  // stack variables.
+  for (auto *F : WorkList) {
+    // TODO: Generalize for any read only arguments.
+    if (F->arg_size() != 1)
+      continue;
+
+    auto &Arg = *F->arg_begin();
+    if (!Arg.onlyReadsMemory() || !Arg.getType()->isPointerTy())
+      continue;
+
+    for (auto *User : F->users()) {
+      auto *Call = dyn_cast<CallInst>(User);
+      if (!Call)
+        break;
+      auto *ArgOp = Call->getArgOperand(0);
+      auto *ArgOpType = ArgOp->getType();
+      auto *ConstVal = getConstantStackValue(Call, ArgOp, Solver);
+      if (!ConstVal)
+        break;
+
+      Value *GV = new GlobalVariable(M, ConstVal->getType(), true,
+                                     GlobalValue::InternalLinkage, ConstVal,
+                                     "funcspec.arg");
+
+      if (ArgOpType != ConstVal->getType())
+        GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOp->getType());
+
+      Call->setArgOperand(0, GV);
+
+      // Add the changed CallInst to Solver Worklist
+      Solver.visitCall(*Call);
+    }
+  }
+}
+
+// ssa_copy intrinsics are introduced by the SCCP solver. These intrinsics
+// interfere with the constantArgPropagation optimization.
+static void removeSSACopy(Function &F) {
+  for (BasicBlock &BB : F) {
+    for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) {
+      Instruction *Inst = &*BI++;
+      auto *II = dyn_cast<IntrinsicInst>(Inst);
+      if (!II)
+        continue;
+      if (II->getIntrinsicID() != Intrinsic::ssa_copy)
+        continue;
+      Inst->replaceAllUsesWith(II->getOperand(0));
+      Inst->eraseFromParent();
+    }
+  }
+}
+
+static void removeSSACopy(Module &M) {
+  for (Function &F : M)
+    removeSSACopy(F);
 }
 
 class FunctionSpecializer {
@@ -115,9 +247,14 @@ public:
     for (auto *SpecializedFunc : CurrentSpecializations) {
       SpecializedFuncs.insert(SpecializedFunc);
 
-      // TODO: If we want to support specializing specialized functions,
-      // initialize here the state of the newly created functions, marking
-      // them argument-tracked and executable.
+      // Initialize the state of the newly created functions, marking them
+      // argument-tracked and executable.
+      if (SpecializedFunc->hasExactDefinition() &&
+          !SpecializedFunc->hasFnAttribute(Attribute::Naked))
+        Solver.addTrackedFunction(SpecializedFunc);
+      Solver.addArgumentTrackedFunction(SpecializedFunc);
+      FuncDecls.push_back(SpecializedFunc);
+      Solver.markBlockExecutable(&SpecializedFunc->front());
 
       // Replace the function arguments for the specialized functions.
       for (Argument &Arg : SpecializedFunc->args())
@@ -138,12 +275,22 @@ public:
     const ValueLatticeElement &IV = Solver.getLatticeValueFor(V);
     if (isOverdefined(IV))
       return false;
-    auto *Const = IV.isConstant() ? Solver.getConstant(IV)
-                                  : UndefValue::get(V->getType());
+    auto *Const =
+        isConstant(IV) ? Solver.getConstant(IV) : UndefValue::get(V->getType());
     V->replaceAllUsesWith(Const);
 
-    // TODO: Update the solver here if we want to specialize specialized
-    // functions.
+    for (auto *U : Const->users())
+      if (auto *I = dyn_cast<Instruction>(U))
+        if (Solver.isBlockExecutable(I->getParent()))
+          Solver.visit(I);
+
+    // Remove the instruction from Block and Solver.
+    if (auto *I = dyn_cast<Instruction>(V)) {
+      if (I->isSafeToRemove()) {
+        I->eraseFromParent();
+        Solver.removeLatticeValueFor(I);
+      }
+    }
     return true;
   }
 
@@ -152,6 +299,15 @@ private:
   // also in the cost model.
   unsigned NbFunctionsSpecialized = 0;
 
+  /// Clone the function \p F and remove the ssa_copy intrinsics added by
+  /// the SCCPSolver in the cloned version.
+  Function *cloneCandidateFunction(Function *F) {
+    ValueToValueMapTy EmptyMap;
+    Function *Clone = CloneFunction(F, EmptyMap);
+    removeSSACopy(*Clone);
+    return Clone;
+  }
+
   /// This function decides whether to specialize function \p F based on the
   /// known constant values its arguments can take on. Specialization is
   /// performed on the first interesting argument. Specializations based on
@@ -214,8 +370,7 @@ private:
       for (auto *C : Constants) {
         // Clone the function. We leave the ValueToValueMap empty to allow
         // IPSCCP to propagate the constant arguments.
-        ValueToValueMapTy EmptyMap;
-        Function *Clone = CloneFunction(F, EmptyMap);
+        Function *Clone = cloneCandidateFunction(F);
         Argument *ClonedArg = Clone->arg_begin() + A.getArgNo();
 
         // Rewrite calls to the function so that they call the clone instead.
@@ -231,9 +386,10 @@ private:
         NbFunctionsSpecialized++;
       }
 
-      // TODO: if we want to support specialize specialized functions, and if
-      // the function has been completely specialized, the original function is
-      // no longer needed, so we would need to mark it unreachable here.
+      // If the function has been completely specialized, the original function
+      // is no longer needed. Mark it unreachable.
+      if (!IsPartial)
+        Solver.markFunctionUnreachable(F);
 
       // FIXME: Only one argument per function.
       return true;
@@ -528,24 +684,6 @@ private:
   }
 };
 
-/// Function to clean up the left over intrinsics from SCCP util.
-static void cleanup(Module &M) {
-  for (Function &F : M) {
-    for (BasicBlock &BB : F) {
-      for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) {
-        Instruction *Inst = &*BI++;
-        if (auto *II = dyn_cast<IntrinsicInst>(Inst)) {
-          if (II->getIntrinsicID() == Intrinsic::ssa_copy) {
-            Value *Op = II->getOperand(0);
-            Inst->replaceAllUsesWith(Op);
-            Inst->eraseFromParent();
-          }
-        }
-      }
-    }
-  }
-}
-
 bool llvm::runFunctionSpecialization(
     Module &M, const DataLayout &DL,
     std::function<TargetLibraryInfo &(Function &)> GetTLI,
@@ -637,14 +775,18 @@ bool llvm::runFunctionSpecialization(
   unsigned I = 0;
   while (FuncSpecializationMaxIters != I++ &&
          FS.specializeFunctions(FuncDecls, CurrentSpecializations)) {
-    // TODO: run the solver here for the specialized functions only if we want
-    // to specialize recursively.
+
+    // Run the solver for the specialized functions.
+    RunSCCPSolver(CurrentSpecializations);
+
+    // Replace some unresolved constant arguments
+    constantArgPropagation(FuncDecls, M, Solver);
 
     CurrentSpecializations.clear();
     Changed = true;
   }
 
   // Clean up the IR by removing ssa_copy intrinsics.
-  cleanup(M);
+  removeSSACopy(M);
   return Changed;
 }
index 521fec4..0ad3586 100644 (file)
@@ -1,26 +1,10 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt -function-specialization -inline -instcombine -S < %s | FileCheck %s
-
-; TODO: this is a case that would be interesting to support, but we don't yet
-; at the moment.
+; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=2 -inline -instcombine -S < %s | FileCheck %s --check-prefix=ITERS2
+; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=3 -inline -instcombine -S < %s | FileCheck %s --check-prefix=ITERS3
+; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=4 -inline -instcombine -S < %s | FileCheck %s --check-prefix=ITERS4
 
 @Global = internal constant i32 1, align 4
 
 define internal void @recursiveFunc(i32* nocapture readonly %arg) {
-; CHECK-LABEL: @recursiveFunc(
-; CHECK-NEXT:    [[TEMP:%.*]] = alloca i32, align 4
-; CHECK-NEXT:    [[ARG_LOAD:%.*]] = load i32, i32* [[ARG:%.*]], align 4
-; CHECK-NEXT:    [[ARG_CMP:%.*]] = icmp slt i32 [[ARG_LOAD]], 4
-; CHECK-NEXT:    br i1 [[ARG_CMP]], label [[BLOCK6:%.*]], label [[RET_BLOCK:%.*]]
-; CHECK:       block6:
-; CHECK-NEXT:    call void @print_val(i32 [[ARG_LOAD]])
-; CHECK-NEXT:    [[ARG_ADD:%.*]] = add nsw i32 [[ARG_LOAD]], 1
-; CHECK-NEXT:    store i32 [[ARG_ADD]], i32* [[TEMP]], align 4
-; CHECK-NEXT:    call void @recursiveFunc(i32* nonnull [[TEMP]])
-; CHECK-NEXT:    br label [[RET_BLOCK]]
-; CHECK:       ret.block:
-; CHECK-NEXT:    ret void
-;
   %temp = alloca i32, align 4
   %arg.load = load i32, i32* %arg, align 4
   %arg.cmp = icmp slt i32 %arg.load, 4
@@ -37,10 +21,28 @@ ret.block:
   ret void
 }
 
+; ITERS2:  @funcspec.arg.3 = internal constant i32 3
+; ITERS3:  @funcspec.arg.5 = internal constant i32 4
+
 define i32 @main() {
-; CHECK-LABEL: @main(
-; CHECK-NEXT:    call void @recursiveFunc(i32* nonnull @Global)
-; CHECK-NEXT:    ret i32 0
+; ITERS2-LABEL: @main(
+; ITERS2-NEXT:    call void @print_val(i32 1)
+; ITERS2-NEXT:    call void @print_val(i32 2)
+; ITERS2-NEXT:    call void @recursiveFunc(i32* nonnull @funcspec.arg.3)
+; ITERS2-NEXT:    ret i32 0
+;
+; ITERS3-LABEL: @main(
+; ITERS3-NEXT:    call void @print_val(i32 1)
+; ITERS3-NEXT:    call void @print_val(i32 2)
+; ITERS3-NEXT:    call void @print_val(i32 3)
+; ITERS3-NEXT:    call void @recursiveFunc(i32* nonnull @funcspec.arg.5)
+; ITERS3-NEXT:    ret i32 0
+;
+; ITERS4-LABEL: @main(
+; ITERS4-NEXT:    call void @print_val(i32 1)
+; ITERS4-NEXT:    call void @print_val(i32 2)
+; ITERS4-NEXT:    call void @print_val(i32 3)
+; ITERS4-NEXT:    ret i32 0
 ;
   call void @recursiveFunc(i32* nonnull @Global)
   ret i32 0
diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive2.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive2.ll
new file mode 100644 (file)
index 0000000..8bac3cf
--- /dev/null
@@ -0,0 +1,32 @@
+; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=2 -S < %s | FileCheck %s
+
+; Volatile store preventing recursive specialisation:
+;
+; CHECK:     @recursiveFunc.1
+; CHECK-NOT: @recursiveFunc.2
+
+@Global = internal constant i32 1, align 4
+
+define internal void @recursiveFunc(i32* nocapture readonly %arg) {
+  %temp = alloca i32, align 4
+  %arg.load = load i32, i32* %arg, align 4
+  %arg.cmp = icmp slt i32 %arg.load, 4
+  br i1 %arg.cmp, label %block6, label %ret.block
+
+block6:
+  call void @print_val(i32 %arg.load)
+  %arg.add = add nsw i32 %arg.load, 1
+  store volatile i32 %arg.add, i32* %temp, align 4
+  call void @recursiveFunc(i32* nonnull %temp)
+  br label %ret.block
+
+ret.block:
+  ret void
+}
+
+define i32 @main() {
+  call void @recursiveFunc(i32* nonnull @Global)
+  ret i32 0
+}
+
+declare dso_local void @print_val(i32)
diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive3.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive3.ll
new file mode 100644 (file)
index 0000000..46b7761
--- /dev/null
@@ -0,0 +1,34 @@
+; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=2 -S < %s | FileCheck %s
+
+; Duplicate store preventing recursive specialisation:
+;
+; CHECK:     @recursiveFunc.1
+; CHECK-NOT: @recursiveFunc.2
+
+@Global = internal constant i32 1, align 4
+
+define internal void @recursiveFunc(i32* nocapture readonly %arg) {
+  %temp = alloca i32, align 4
+  %arg.load = load i32, i32* %arg, align 4
+  %arg.cmp = icmp slt i32 %arg.load, 4
+  br i1 %arg.cmp, label %block6, label %ret.block
+
+block6:
+  call void @print_val(i32 %arg.load)
+  %arg.add = add nsw i32 %arg.load, 1
+  store i32 %arg.add, i32* %temp, align 4
+  store i32 %arg.add, i32* %temp, align 4
+  call void @recursiveFunc(i32* nonnull %temp)
+  br label %ret.block
+
+ret.block:
+  ret void
+}
+
+
+define i32 @main() {
+  call void @recursiveFunc(i32* nonnull @Global)
+  ret i32 0
+}
+
+declare dso_local void @print_val(i32)
diff --git a/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive4.ll b/llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive4.ll
new file mode 100644 (file)
index 0000000..294c6e0
--- /dev/null
@@ -0,0 +1,32 @@
+; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=2 -S < %s | FileCheck %s
+
+; Alloca is not an integer type:
+;
+; CHECK:     @recursiveFunc.1
+; CHECK-NOT: @recursiveFunc.2
+
+@Global = internal constant i32 1, align 4
+
+define internal void @recursiveFunc(i32* nocapture readonly %arg) {
+  %temp = alloca float, align 4
+  %arg.load = load i32, i32* %arg, align 4
+  %arg.cmp = icmp slt i32 %arg.load, 4
+  br i1 %arg.cmp, label %block6, label %ret.block
+
+block6:
+  call void @print_val(i32 %arg.load)
+  %arg.add = add nsw i32 %arg.load, 1
+  %bc = bitcast float* %temp to i32*
+  call void @recursiveFunc(i32* nonnull %bc)
+  br label %ret.block
+
+ret.block:
+  ret void
+}
+
+define i32 @main() {
+  call void @recursiveFunc(i32* nonnull @Global)
+  ret i32 0
+}
+
+declare dso_local void @print_val(i32)