From 530439075abddbcbd85e7461584b32586a9226f0 Mon Sep 17 00:00:00 2001 From: Thomas Joerg Date: Wed, 8 May 2019 23:37:57 -0700 Subject: [PATCH] Add gpu.launch_func builder. -- PiperOrigin-RevId: 247364893 --- mlir/include/mlir/GPU/GPUDialect.h | 5 +++++ mlir/lib/GPU/IR/GPUDialect.cpp | 13 +++++++++++++ 2 files changed, 18 insertions(+) diff --git a/mlir/include/mlir/GPU/GPUDialect.h b/mlir/include/mlir/GPU/GPUDialect.h index 28d86b5..50736ec 100644 --- a/mlir/include/mlir/GPU/GPUDialect.h +++ b/mlir/include/mlir/GPU/GPUDialect.h @@ -104,6 +104,11 @@ class LaunchFuncOp : public Op::Impl, public: using Op::Op; + static void build(Builder *builder, OperationState *result, + Function *kernelFunc, Value *gridSizeX, Value *gridSizeY, + Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY, + Value *blockSizeZ, ArrayRef kernelOperands); + /// The kernel function specified by the operation's `kernel` attribute. Function *kernel(); /// The number of operands passed to the kernel function. diff --git a/mlir/lib/GPU/IR/GPUDialect.cpp b/mlir/lib/GPU/IR/GPUDialect.cpp index 9d8b748..87488de 100644 --- a/mlir/lib/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/GPU/IR/GPUDialect.cpp @@ -268,6 +268,19 @@ ParseResult LaunchOp::parse(OpAsmParser *parser, OperationState *result) { //===----------------------------------------------------------------------===// // LaunchFuncOp //===----------------------------------------------------------------------===// + +void LaunchFuncOp::build(Builder *builder, OperationState *result, + Function *kernelFunc, Value *gridSizeX, + Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX, + Value *blockSizeY, Value *blockSizeZ, + ArrayRef kernelOperands) { + // Add grid and block sizes as op operands, followed by the data operands. + result->addOperands( + {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ}); + result->addOperands(kernelOperands); + result->addAttribute("kernel", builder->getFunctionAttr(kernelFunc)); +} + Function *LaunchFuncOp::kernel() { return this->getAttr("kernel").dyn_cast().getValue(); } -- 2.7.4