}
}
-// Move all constant arguments of the given kernel function into the function,
-// thereby reducing the number of kernel arguments.
-static gpu::LaunchFuncOp inlineConstants(FuncOp kernelFunc,
- gpu::LaunchFuncOp launch) {
+static bool isInliningBeneficiary(Operation *op) {
+ return isa<ConstantOp>(op) || isa<DimOp>(op);
+}
+
+// Move arguments of the given kernel function into the function if this reduces
+// the number of kernel arguments.
+static gpu::LaunchFuncOp inlineBeneficiaryOps(FuncOp kernelFunc,
+ gpu::LaunchFuncOp launch) {
OpBuilder kernelBuilder(kernelFunc.getBody());
auto &firstBlock = kernelFunc.getBody().front();
llvm::SmallVector<Value *, 8> newLaunchArgs;
+ BlockAndValueMapping map;
+ for (int i = 0, e = launch.getNumKernelOperands(); i < e; ++i) {
+ map.map(launch.getKernelOperand(i), kernelFunc.getArgument(i));
+ }
for (int i = launch.getNumKernelOperands() - 1; i >= 0; --i) {
auto operandOp = launch.getKernelOperand(i)->getDefiningOp();
- auto constant = dyn_cast_or_null<ConstantOp>(operandOp);
- if (!constant) {
+ if (!operandOp || !isInliningBeneficiary(operandOp)) {
newLaunchArgs.push_back(launch.getKernelOperand(i));
continue;
}
- auto newConstant = kernelBuilder.clone(*operandOp);
- firstBlock.getArgument(i)->replaceAllUsesWith(newConstant->getResult(0));
+ // Only inline operations that do not create new arguments.
+ if (!llvm::all_of(operandOp->getOperands(),
+ [map](Value *value) { return map.contains(value); })) {
+ continue;
+ }
+ auto clone = kernelBuilder.clone(*operandOp, map);
+ firstBlock.getArgument(i)->replaceAllUsesWith(clone->getResult(0));
firstBlock.eraseArgument(i);
}
if (newLaunchArgs.size() == launch.getNumKernelOperands())
auto launchFuncOp = builder.create<gpu::LaunchFuncOp>(
launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(),
launchOp.getBlockSizeOperandValues(), kernelOperandValues);
- inlineConstants(kernelFunc, launchFuncOp);
+ inlineBeneficiaryOps(kernelFunc, launchFuncOp);
launchOp.erase();
}
// CHECK: %[[CST:.*]] = constant 8 : index
%cst = constant 8 : index
%cst2 = constant 2 : index
- %cst3 = constant 3 : index
+ %cst3 = dim %arg0, 0 : memref<?xf32>
// CHECK: "gpu.launch_func"(%[[CST]], %[[CST]], %[[CST]], %[[CST]], %[[CST]], %[[CST]], %{{.*}}) {kernel = "extra_constants_kernel", kernel_module = @extra_constants_kernel} : (index, index, index, index, index, index, memref<?xf32>) -> ()
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst,
%grid_z = %cst)