Move explicit calls to Module::getNamedFunction outside of the operations that...
authorRiver Riddle <riverriddle@google.com>
Fri, 24 May 2019 00:01:16 +0000 (17:01 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 02:59:44 +0000 (19:59 -0700)
--

PiperOrigin-RevId: 249743109

mlir/include/mlir/GPU/GPUDialect.h
mlir/include/mlir/StandardOps/Ops.td
mlir/lib/GPU/IR/GPUDialect.cpp
mlir/lib/StandardOps/Ops.cpp

index ea7762c..8ace6ff 100644 (file)
@@ -125,7 +125,7 @@ public:
                     KernelDim3 blockSize, ArrayRef<Value *> kernelOperands);
 
   /// The kernel function specified by the operation's `kernel` attribute.
-  Function *kernel();
+  StringRef kernel();
   /// The number of operands passed to the kernel function.
   unsigned getNumKernelOperands();
   /// The i-th operand passed to the kernel function.
index 47a3c60..058750f 100644 (file)
@@ -228,7 +228,8 @@ def CallOp : Std_Op<"call"> {
   }]>];
 
   let extraClassDeclaration = [{
-    Function *getCallee();
+    StringRef getCallee() { return callee(); }
+    FunctionType getCalleeType();
 
     /// Get the argument operands to the called function.
     operand_range getArgOperands() {
index ff4c493..755a2c2 100644 (file)
@@ -315,10 +315,8 @@ void LaunchFuncOp::build(Builder *builder, OperationState *result,
         blockSize.x, blockSize.y, blockSize.z, kernelOperands);
 }
 
-Function *LaunchFuncOp::kernel() {
-  auto kernelAttr = getAttr(getKernelAttrName()).cast<FunctionAttr>();
-  return getOperation()->getFunction()->getModule()->getNamedFunction(
-      kernelAttr.getValue());
+StringRef LaunchFuncOp::kernel() {
+  return getAttrOfType<FunctionAttr>(getKernelAttrName()).getValue();
 }
 
 unsigned LaunchFuncOp::getNumKernelOperands() {
@@ -337,7 +335,8 @@ LogicalResult LaunchFuncOp::verify() {
     return emitOpError("attribute 'kernel' must be a function");
   }
 
-  Function *kernelFunc = this->kernel();
+  auto *module = getOperation()->getFunction()->getModule();
+  Function *kernelFunc = module->getNamedFunction(kernel());
   if (!kernelFunc)
     return emitError() << "kernel function '" << kernelAttr << "' is undefined";
 
index 3614e15..f05b0cf 100644 (file)
@@ -436,13 +436,12 @@ static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
 }
 
 static void print(OpAsmPrinter *p, CallOp op) {
-  *p << "call ";
-  p->printFunctionReference(op.getCallee());
-  *p << '(';
+  *p << "call " << op.getAttr("callee") << '(';
   p->printOperands(op.getOperands());
   *p << ')';
   p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
-  *p << " : " << op.getCallee()->getType();
+  *p << " : ";
+  p->printType(op.getCalleeType());
 }
 
 static LogicalResult verify(CallOp op) {
@@ -475,9 +474,13 @@ static LogicalResult verify(CallOp op) {
   return success();
 }
 
-Function *CallOp::getCallee() {
-  auto name = getAttrOfType<FunctionAttr>("callee").getValue();
-  return getOperation()->getFunction()->getModule()->getNamedFunction(name);
+FunctionType CallOp::getCalleeType() {
+  SmallVector<Type, 4> resultTypes(getOperation()->getResultTypes());
+  SmallVector<Type, 8> argTypes;
+  argTypes.reserve(getNumOperands());
+  for (auto *operand : getArgOperands())
+    argTypes.push_back(operand->getType());
+  return FunctionType::get(argTypes, resultTypes, getContext());
 }
 
 //===----------------------------------------------------------------------===//