KernelDim3 blockSize, ArrayRef<Value *> kernelOperands);
/// The kernel function specified by the operation's `kernel` attribute.
- Function *kernel();
+ StringRef kernel();
/// The number of operands passed to the kernel function.
unsigned getNumKernelOperands();
/// The i-th operand passed to the kernel function.
}]>];
let extraClassDeclaration = [{
- Function *getCallee();
+ StringRef getCallee() { return callee(); }
+ FunctionType getCalleeType();
/// Get the argument operands to the called function.
operand_range getArgOperands() {
blockSize.x, blockSize.y, blockSize.z, kernelOperands);
}
-Function *LaunchFuncOp::kernel() {
- auto kernelAttr = getAttr(getKernelAttrName()).cast<FunctionAttr>();
- return getOperation()->getFunction()->getModule()->getNamedFunction(
- kernelAttr.getValue());
+StringRef LaunchFuncOp::kernel() {
+ return getAttrOfType<FunctionAttr>(getKernelAttrName()).getValue();
}
unsigned LaunchFuncOp::getNumKernelOperands() {
return emitOpError("attribute 'kernel' must be a function");
}
- Function *kernelFunc = this->kernel();
+ auto *module = getOperation()->getFunction()->getModule();
+ Function *kernelFunc = module->getNamedFunction(kernel());
if (!kernelFunc)
return emitError() << "kernel function '" << kernelAttr << "' is undefined";
}
static void print(OpAsmPrinter *p, CallOp op) {
- *p << "call ";
- p->printFunctionReference(op.getCallee());
- *p << '(';
+ *p << "call " << op.getAttr("callee") << '(';
p->printOperands(op.getOperands());
*p << ')';
p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
- *p << " : " << op.getCallee()->getType();
+ *p << " : ";
+ p->printType(op.getCalleeType());
}
static LogicalResult verify(CallOp op) {
return success();
}
-Function *CallOp::getCallee() {
- auto name = getAttrOfType<FunctionAttr>("callee").getValue();
- return getOperation()->getFunction()->getModule()->getNamedFunction(name);
+FunctionType CallOp::getCalleeType() {
+ SmallVector<Type, 4> resultTypes(getOperation()->getResultTypes());
+ SmallVector<Type, 8> argTypes;
+ argTypes.reserve(getNumOperands());
+ for (auto *operand : getArgOperands())
+ argTypes.push_back(operand->getType());
+ return FunctionType::get(argTypes, resultTypes, getContext());
}
//===----------------------------------------------------------------------===//