From dc2dff6c685de87abbca035370d691e0bc0da15d Mon Sep 17 00:00:00 2001 From: Jason Henline Date: Fri, 2 Sep 2016 00:22:05 +0000 Subject: [PATCH] [SE] Make Kernel movable Summary: Kernel is basically just a smart pointer to the underlying implementation, so making it movable prevents having to store a std::unique_ptr to it. Reviewers: jlebar Subscribers: jprice, parallel_libs-commits Differential Revision: https://reviews.llvm.org/D24150 llvm-svn: 280437 --- parallel-libs/streamexecutor/examples/Example.cpp | 5 +- .../streamexecutor/include/streamexecutor/Device.h | 10 +--- .../streamexecutor/include/streamexecutor/Kernel.h | 69 +++------------------- 3 files changed, 12 insertions(+), 72 deletions(-) diff --git a/parallel-libs/streamexecutor/examples/Example.cpp b/parallel-libs/streamexecutor/examples/Example.cpp index 8f42ffa..76027a8 100644 --- a/parallel-libs/streamexecutor/examples/Example.cpp +++ b/parallel-libs/streamexecutor/examples/Example.cpp @@ -18,7 +18,6 @@ #include #include #include -#include #include #include "streamexecutor/StreamExecutor.h" @@ -111,7 +110,7 @@ int main() { se::Device *Device = getOrDie(Platform->getDevice(0)); // Load the kernel onto the device. - std::unique_ptr Kernel = + cg::SaxpyKernel Kernel = getOrDie(Device->createKernel(cg::SaxpyLoaderSpec)); // Allocate memory on the device. @@ -124,7 +123,7 @@ int main() { se::Stream Stream = getOrDie(Device->createStream()); Stream.thenCopyH2D(HostX, X) .thenCopyH2D(HostY, Y) - .thenLaunch(ArraySize, 1, *Kernel, A, X, Y) + .thenLaunch(ArraySize, 1, Kernel, A, X, Y) .thenCopyD2H(X, HostX); // Wait for the stream to complete. se::dieIfError(Stream.blockHostUntilDone()); diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Device.h b/parallel-libs/streamexecutor/include/streamexecutor/Device.h index 2493781..3de9910 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/Device.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/Device.h @@ -32,22 +32,18 @@ public: /// Creates a kernel object for this device. /// - /// If the return value is not an error, the returned pointer will never be - /// null. - /// /// See \ref CompilerGeneratedKernelExample "Kernel.h" for an example of how /// this method is used. template - Expected::value, KernelT>::type>> + Expected::value, + KernelT>::type> createKernel(const MultiKernelLoaderSpec &Spec) { Expected> MaybeKernelHandle = PDevice->createKernel(Spec); if (!MaybeKernelHandle) { return MaybeKernelHandle.takeError(); } - return llvm::make_unique(Spec.getKernelName(), - std::move(*MaybeKernelHandle)); + return KernelT(Spec.getKernelName(), std::move(*MaybeKernelHandle)); } /// Creates a stream object for this device. diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h b/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h index eaf3db3..c9b4180 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h @@ -11,68 +11,10 @@ /// Types to represent device kernels (code compiled to run on GPU or other /// accelerator). /// -/// With the kernel parameter types recorded in the Kernel template parameters, -/// type-safe kernel launch functions can be written with signatures like the -/// following: -/// \code -/// template -/// void Launch( -/// const Kernel &Kernel, ParamterTs... Arguments); -/// \endcode -/// and the compiler will check that the user passes in arguments with types -/// matching the corresponding kernel parameters. -/// -/// A problem is that a Kernel template specialization with the right parameter -/// types must be passed as the first argument to the Launch function, and it's -/// just as hard to get the types right in that template specialization as it is -/// to get them right for the kernel arguments. -/// -/// With this problem in mind, it is not recommended for users to specialize the -/// Kernel template class themselves, but instead to let the compiler do it for -/// them. When the compiler encounters a device kernel function, it can create a -/// Kernel template specialization in the host code that has the right parameter -/// types for that kernel and which has a type name based on the name of the -/// kernel function. -/// -/// \anchor CompilerGeneratedKernelExample -/// For example, if a CUDA device kernel function with the following signature -/// has been defined: -/// \code -/// void Saxpy(float A, float *X, float *Y); -/// \endcode -/// the compiler can insert the following declaration in the host code: -/// \code -/// namespace compiler_cuda_namespace { -/// namespace se = streamexecutor; -/// using SaxpyKernel = -/// se::Kernel< -/// float, -/// se::GlobalDeviceMemory, -/// se::GlobalDeviceMemory>; -/// } // namespace compiler_cuda_namespace -/// \endcode -/// and then the user can launch the kernel by calling the StreamExecutor launch -/// function as follows: -/// \code -/// namespace ccn = compiler_cuda_namespace; -/// using KernelPtr = std::unique_ptr; -/// // Assumes Device is a pointer to the Device on which to launch the -/// // kernel. -/// // -/// // See KernelSpec.h for details on how the compiler can create a -/// // MultiKernelLoaderSpec instance like SaxpyKernelLoaderSpec below. -/// Expected MaybeKernel = -/// Device->createKernel(ccn::SaxpyKernelLoaderSpec); -/// if (!MaybeKernel) { /* Handle error */ } -/// KernelPtr SaxpyKernel = std::move(*MaybeKernel); -/// Launch(*SaxpyKernel, A, X, Y); -/// \endcode -/// -/// With the compiler's help in specializing Kernel for each device kernel -/// function (and generating a MultiKernelLoaderSpec instance for each kernel), -/// the user can safely launch the device kernel from the host and get an error -/// message at compile time if the argument types don't match the kernel -/// parameter types. +/// See the \ref index "main page" for an example of how a compiler-generated +/// specialization of the Kernel class template can be used along with the +/// streamexecutor::Stream::thenLaunch method to create a typesafe interface for +/// kernel launches. /// //===----------------------------------------------------------------------===// @@ -112,6 +54,9 @@ public: Kernel(llvm::StringRef Name, std::unique_ptr PHandle) : KernelBase(Name), PHandle(std::move(PHandle)) {} + Kernel(Kernel &&Other) = default; + Kernel &operator=(Kernel &&Other) = default; + /// Gets the underlying platform-specific handle for this kernel. PlatformKernelHandle *getPlatformHandle() const { return PHandle.get(); } -- 2.7.4