From 22f1af4400dacc03d1af21e52289078e270e16e0 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 23 May 2019 17:01:16 -0700 Subject: [PATCH] Move explicit calls to Module::getNamedFunction outside of the operations that contain FunctionAttr. It is up to the users of operations to query the module for a specific function referenced by a FunctionAttr. -- PiperOrigin-RevId: 249743109 --- mlir/include/mlir/GPU/GPUDialect.h | 2 +- mlir/include/mlir/StandardOps/Ops.td | 3 ++- mlir/lib/GPU/IR/GPUDialect.cpp | 9 ++++----- mlir/lib/StandardOps/Ops.cpp | 17 ++++++++++------- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/GPU/GPUDialect.h b/mlir/include/mlir/GPU/GPUDialect.h index ea7762c..8ace6ff 100644 --- a/mlir/include/mlir/GPU/GPUDialect.h +++ b/mlir/include/mlir/GPU/GPUDialect.h @@ -125,7 +125,7 @@ public: KernelDim3 blockSize, ArrayRef kernelOperands); /// The kernel function specified by the operation's `kernel` attribute. - Function *kernel(); + StringRef kernel(); /// The number of operands passed to the kernel function. unsigned getNumKernelOperands(); /// The i-th operand passed to the kernel function. diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td index 47a3c60..058750f 100644 --- a/mlir/include/mlir/StandardOps/Ops.td +++ b/mlir/include/mlir/StandardOps/Ops.td @@ -228,7 +228,8 @@ def CallOp : Std_Op<"call"> { }]>]; let extraClassDeclaration = [{ - Function *getCallee(); + StringRef getCallee() { return callee(); } + FunctionType getCalleeType(); /// Get the argument operands to the called function. operand_range getArgOperands() { diff --git a/mlir/lib/GPU/IR/GPUDialect.cpp b/mlir/lib/GPU/IR/GPUDialect.cpp index ff4c493..755a2c2 100644 --- a/mlir/lib/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/GPU/IR/GPUDialect.cpp @@ -315,10 +315,8 @@ void LaunchFuncOp::build(Builder *builder, OperationState *result, blockSize.x, blockSize.y, blockSize.z, kernelOperands); } -Function *LaunchFuncOp::kernel() { - auto kernelAttr = getAttr(getKernelAttrName()).cast(); - return getOperation()->getFunction()->getModule()->getNamedFunction( - kernelAttr.getValue()); +StringRef LaunchFuncOp::kernel() { + return getAttrOfType(getKernelAttrName()).getValue(); } unsigned LaunchFuncOp::getNumKernelOperands() { @@ -337,7 +335,8 @@ LogicalResult LaunchFuncOp::verify() { return emitOpError("attribute 'kernel' must be a function"); } - Function *kernelFunc = this->kernel(); + auto *module = getOperation()->getFunction()->getModule(); + Function *kernelFunc = module->getNamedFunction(kernel()); if (!kernelFunc) return emitError() << "kernel function '" << kernelAttr << "' is undefined"; diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 3614e15..f05b0cf 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -436,13 +436,12 @@ static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) { } static void print(OpAsmPrinter *p, CallOp op) { - *p << "call "; - p->printFunctionReference(op.getCallee()); - *p << '('; + *p << "call " << op.getAttr("callee") << '('; p->printOperands(op.getOperands()); *p << ')'; p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); - *p << " : " << op.getCallee()->getType(); + *p << " : "; + p->printType(op.getCalleeType()); } static LogicalResult verify(CallOp op) { @@ -475,9 +474,13 @@ static LogicalResult verify(CallOp op) { return success(); } -Function *CallOp::getCallee() { - auto name = getAttrOfType("callee").getValue(); - return getOperation()->getFunction()->getModule()->getNamedFunction(name); +FunctionType CallOp::getCalleeType() { + SmallVector resultTypes(getOperation()->getResultTypes()); + SmallVector argTypes; + argTypes.reserve(getNumOperands()); + for (auto *operand : getArgOperands()) + argTypes.push_back(operand->getType()); + return FunctionType::get(argTypes, resultTypes, getContext()); } //===----------------------------------------------------------------------===// -- 2.7.4