Summary: Add the Stream class and a few of the operations it supports.
Reviewers: jlebar, tra
Subscribers: jprice, parallel_libs-commits
Differential Revision: https://reviews.llvm.org/D23333
llvm-svn: 278829
include_directories(${LLVM_INCLUDE_DIRS})
add_definitions(${LLVM_DEFINITIONS})
+ # Get the LLVM cxxflags by using llvm-config.
+ #
+ # This is necessary to get -fno-rtti if LLVM is compiled that way.
+ execute_process(
+ COMMAND
+ "${LLVM_BINARY_DIR}/bin/llvm-config"
+ --cxxflags
+ OUTPUT_VARIABLE
+ LLVM_CXXFLAGS
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${LLVM_CXXFLAGS}")
+
# Find the libraries that correspond to the LLVM components
# that we wish to use
llvm_map_components_to_libnames(llvm_libs support symbolize)
+++ /dev/null
-//===-- Interfaces.h - Interfaces to platform-specific impls ----*- C++ -*-===//
-//
-// The LLVM Compiler Infrastructure
-//
-// This file is distributed under the University of Illinois Open Source
-// License. See LICENSE.TXT for details.
-//
-//===----------------------------------------------------------------------===//
-///
-/// \file
-/// Interfaces to platform-specific StreamExecutor type implementations.
-///
-//===----------------------------------------------------------------------===//
-
-#ifndef STREAMEXECUTOR_INTERFACES_H
-#define STREAMEXECUTOR_INTERFACES_H
-
-namespace streamexecutor {
-
-/// Methods supported by device kernel function objects on all platforms.
-class KernelInterface {
- // TODO(jhen): Add methods.
-};
-
-// TODO(jhen): Add other interfaces such as Stream.
-
-} // namespace streamexecutor
-
-#endif // STREAMEXECUTOR_INTERFACES_H
--- /dev/null
+//===-- LaunchDimensions.h - Kernel block and grid sizes --------*- C++ -*-===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Structures to hold sizes for blocks and grids which are used as parameters
+/// for kernel launches.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef STREAMEXECUTOR_LAUNCHDIMENSIONS_H
+#define STREAMEXECUTOR_LAUNCHDIMENSIONS_H
+
+namespace streamexecutor {
+
+/// The dimensions of a device block of execution.
+///
+/// A block is made up of an array of X by Y by Z threads.
+struct BlockDimensions {
+ BlockDimensions(unsigned X = 1, unsigned Y = 1, unsigned Z = 1)
+ : X(X), Y(Y), Z(Z) {}
+
+ unsigned X;
+ unsigned Y;
+ unsigned Z;
+};
+
+/// The dimensions of a device grid of execution.
+///
+/// A grid is made up of an array of X by Y by Z blocks.
+struct GridDimensions {
+ GridDimensions(unsigned X = 1, unsigned Y = 1, unsigned Z = 1)
+ : X(X), Y(Y), Z(Z) {}
+
+ unsigned X;
+ unsigned Y;
+ unsigned Z;
+};
+
+} // namespace streamexecutor
+
+#endif // STREAMEXECUTOR_LAUNCHDIMENSIONS_H
/// efficiently, although it is probably more information than is needed for any
/// specific platform.
///
+/// The PackedKernelArgumentArrayBase class has no template parameters, so it
+/// does not benefit from compile-time type checking. However, since it has no
+/// template parameters, it can be passed as an argument to virtual functions,
+/// and this allows it to be passed to functions that use virtual function
+/// overloading to handle platform-specific kernel launching.
+///
//===----------------------------------------------------------------------===//
#ifndef STREAMEXECUTOR_PACKEDKERNELARGUMENTARRAY_H
SHARED_DEVICE_MEMORY /// Shared device memory argument.
};
-/// An array of packed kernel arguments.
-template <typename... ParameterTs> class PackedKernelArgumentArray {
+/// An array of packed kernel arguments without compile-time type information.
+///
+/// This un-templated base class is useful because packed kernel arguments must
+/// at some point be passed to a virtual function that performs
+/// platform-specific kernel launches. Such a virtual function cannot be
+/// templated to handle all specializations of the
+/// PackedKernelArgumentArray<...> class template, so, instead, references to
+/// PackedKernelArgumentArray<...> are passed as references to this base class.
+class PackedKernelArgumentArrayBase {
public:
- /// Constructs an instance by packing the specified arguments.
- PackedKernelArgumentArray(const ParameterTs &... Arguments)
- : SharedCount(0u) {
- PackArguments(0, Arguments...);
- }
+ virtual ~PackedKernelArgumentArrayBase();
/// Gets the number of packed arguments.
- size_t getArgumentCount() const { return sizeof...(ParameterTs); }
+ size_t getArgumentCount() const { return ArgumentCount; }
/// Gets the address of the argument at the given index.
- const void *getAddress(size_t Index) const { return Addresses[Index]; }
+ const void *getAddress(size_t Index) const { return AddressesData[Index]; }
/// Gets the size of the argument at the given index.
- size_t getSize(size_t Index) const { return Sizes[Index]; }
+ size_t getSize(size_t Index) const { return SizesData[Index]; }
/// Gets the type of the argument at the given index.
- KernelArgumentType getType(size_t Index) const { return Types[Index]; }
+ KernelArgumentType getType(size_t Index) const { return TypesData[Index]; }
/// Gets a pointer to the address array.
- const void *const *getAddresses() const { return Addresses.data(); }
+ const void *const *getAddresses() const { return AddressesData; }
/// Gets a pointer to the sizes array.
- const size_t *getSizes() const { return Sizes.data(); }
+ const size_t *getSizes() const { return SizesData; }
/// Gets a pointer to the types array.
- const KernelArgumentType *getTypes() const { return Types.data(); }
+ const KernelArgumentType *getTypes() const { return TypesData; }
/// Gets the number of shared device memory arguments.
size_t getSharedCount() const { return SharedCount; }
+protected:
+ PackedKernelArgumentArrayBase(size_t ArgumentCount)
+ : ArgumentCount(ArgumentCount), SharedCount(0u) {}
+
+ size_t ArgumentCount;
+ size_t SharedCount;
+ const void *const *AddressesData;
+ size_t *SizesData;
+ KernelArgumentType *TypesData;
+};
+
+/// An array of packed kernel arguments with compile-time type information.
+///
+/// This is used by the platform-independent StreamExecutor code to pack
+/// arguments in a compile-time type-safe way. In order to actually launch a
+/// kernel on a specific platform, however, a reference to this class will have
+/// to be passed to a virtual, platform-specific kernel launch function. Such a
+/// reference will be passed as a reference to the base class rather than a
+/// reference to this subclass itself because a virtual function cannot be
+/// templated in such a way to maintain the template parameter types of the
+/// subclass.
+template <typename... ParameterTs>
+class PackedKernelArgumentArray : public PackedKernelArgumentArrayBase {
+public:
+ /// Constructs an instance by packing the specified arguments.
+ ///
+ /// Rather than using this constructor directly, consider using the
+ /// make_kernel_argument_pack function instead, to get the compiler to infer
+ /// the parameter types for you.
+ PackedKernelArgumentArray(const ParameterTs &... Arguments)
+ : PackedKernelArgumentArrayBase(sizeof...(ParameterTs)) {
+ AddressesData = Addresses.data();
+ SizesData = Sizes.data();
+ TypesData = Types.data();
+ PackArguments(0, Arguments...);
+ }
+
+ ~PackedKernelArgumentArray() override = default;
+
private:
// Base case for PackArguments when there are no arguments to pack.
void PackArguments(size_t) {}
std::array<const void *, sizeof...(ParameterTs)> Addresses;
std::array<size_t, sizeof...(ParameterTs)> Sizes;
std::array<KernelArgumentType, sizeof...(ParameterTs)> Types;
- size_t SharedCount;
};
// Utility template function to call the PackedKernelArgumentArray constructor
--- /dev/null
+//===-- PlatformInterfaces.h - Interfaces to platform impls -----*- C++ -*-===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Interfaces to platform-specific implementations.
+///
+/// The general pattern is that the functions in these interfaces take raw
+/// handle types as parameters. This means that these types and functions are
+/// not intended for public use. Instead, corresponding methods in public types
+/// like Stream, StreamExecutor, and Kernel use C++ templates to create
+/// type-safe public interfaces. Those public functions do the type-unsafe work
+/// of extracting raw handles from their arguments and forwarding those handles
+/// to the methods defined in this file in the proper format.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef STREAMEXECUTOR_PLATFORMINTERFACES_H
+#define STREAMEXECUTOR_PLATFORMINTERFACES_H
+
+#include "streamexecutor/DeviceMemory.h"
+#include "streamexecutor/Kernel.h"
+#include "streamexecutor/LaunchDimensions.h"
+#include "streamexecutor/PackedKernelArgumentArray.h"
+#include "streamexecutor/Utils/Error.h"
+
+namespace streamexecutor {
+
+class PlatformStreamExecutor;
+
+/// Methods supported by device kernel function objects on all platforms.
+class KernelInterface {
+ // TODO(jhen): Add methods.
+};
+
+/// Platform-specific stream handle.
+class PlatformStreamHandle {
+public:
+ explicit PlatformStreamHandle(PlatformStreamExecutor *Executor)
+ : Executor(Executor) {}
+
+ virtual ~PlatformStreamHandle();
+
+ PlatformStreamExecutor *getExecutor() { return Executor; }
+
+private:
+ PlatformStreamExecutor *Executor;
+};
+
+/// Raw executor methods that must be implemented by each platform.
+///
+/// This class defines the platform interface that supports executing work on a
+/// device.
+///
+/// The public StreamExecutor and Stream classes have the type-safe versions of
+/// the functions in this interface.
+class PlatformStreamExecutor {
+public:
+ virtual ~PlatformStreamExecutor();
+
+ virtual std::string getName() const = 0;
+
+ /// Creates a platform-specific stream.
+ virtual Expected<std::unique_ptr<PlatformStreamHandle>> createStream() = 0;
+
+ /// Launches a kernel on the given stream.
+ virtual Error launch(PlatformStreamHandle *S, BlockDimensions BlockSize,
+ GridDimensions GridSize, const KernelBase &Kernel,
+ const PackedKernelArgumentArrayBase &ArgumentArray) {
+ return make_error("launch not implemented for platform " + getName());
+ }
+
+ /// Copies data from the device to the host.
+ virtual Error memcpyD2H(PlatformStreamHandle *S,
+ const GlobalDeviceMemoryBase &DeviceSrc,
+ void *HostDst, size_t ByteCount) {
+ return make_error("memcpyD2H not implemented for platform " + getName());
+ }
+
+ /// Copies data from the host to the device.
+ virtual Error memcpyH2D(PlatformStreamHandle *S, const void *HostSrc,
+ GlobalDeviceMemoryBase *DeviceDst, size_t ByteCount) {
+ return make_error("memcpyH2D not implemented for platform " + getName());
+ }
+
+ /// Copies data from one device location to another.
+ virtual Error memcpyD2D(PlatformStreamHandle *S,
+ const GlobalDeviceMemoryBase &DeviceSrc,
+ GlobalDeviceMemoryBase *DeviceDst, size_t ByteCount) {
+ return make_error("memcpyD2D not implemented for platform " + getName());
+ }
+
+ /// Blocks the host until the given stream completes all the work enqueued up
+ /// to the point this function is called.
+ virtual Error blockHostUntilDone(PlatformStreamHandle *S) {
+ return make_error("blockHostUntilDone not implemented for platform " +
+ getName());
+ }
+};
+
+} // namespace streamexecutor
+
+#endif // STREAMEXECUTOR_PLATFORMINTERFACES_H
--- /dev/null
+//===-- Stream.h - A stream of execution ------------------------*- C++ -*-===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+///
+/// A Stream instance represents a queue of sequential, host-asynchronous work
+/// to be performed on a device.
+///
+/// To enqueue work on a device, first create a StreamExecutor instance for a
+/// given device and then use that StreamExecutor to create a Stream instance.
+/// The Stream instance will perform its work on the device managed by the
+/// StreamExecutor that created it.
+///
+/// The various "then" methods of the Stream object, such as thenMemcpyH2D and
+/// thenLaunch, may be used to enqueue work on the Stream, and the
+/// blockHostUntilDone() method may be used to block the host code until the
+/// Stream has completed all its work.
+///
+/// Multiple Stream instances can be created for the same StreamExecutor. This
+/// allows several independent streams of computation to be performed
+/// simultaneously on a single device.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef STREAMEXECUTOR_STREAM_H
+#define STREAMEXECUTOR_STREAM_H
+
+#include <cassert>
+#include <memory>
+#include <string>
+
+#include "streamexecutor/DeviceMemory.h"
+#include "streamexecutor/Kernel.h"
+#include "streamexecutor/LaunchDimensions.h"
+#include "streamexecutor/PackedKernelArgumentArray.h"
+#include "streamexecutor/PlatformInterfaces.h"
+#include "streamexecutor/Utils/Error.h"
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/RWMutex.h"
+
+namespace streamexecutor {
+
+/// Represents a stream of dependent computations on a device.
+///
+/// The operations within a stream execute sequentially and asynchronously until
+/// blockHostUntilDone() is invoked, which synchronously joins host code with
+/// the execution of the stream.
+///
+/// If any given operation fails when entraining work for the stream, isOK()
+/// will indicate that an error has occurred and getStatus() will get the first
+/// error that occurred on the stream. There is no way to clear the error state
+/// of a stream once it is in an error state.
+class Stream {
+public:
+ explicit Stream(std::unique_ptr<PlatformStreamHandle> PStream);
+
+ ~Stream();
+
+ /// Returns whether any error has occurred while entraining work on this
+ /// stream.
+ bool isOK() const {
+ llvm::sys::ScopedReader ReaderLock(ErrorMessageMutex);
+ return !ErrorMessage;
+ }
+
+ /// Returns the status created by the first error that occurred while
+ /// entraining work on this stream.
+ Error getStatus() const {
+ llvm::sys::ScopedReader ReaderLock(ErrorMessageMutex);
+ if (ErrorMessage)
+ return make_error(*ErrorMessage);
+ else
+ return Error::success();
+ };
+
+ /// Entrains onto the stream of operations a kernel launch with the given
+ /// arguments.
+ ///
+ /// These arguments can be device memory types like GlobalDeviceMemory<T> and
+ /// SharedDeviceMemory<T>, or they can be primitive types such as int. The
+ /// allowable argument types are determined by the template parameters to the
+ /// TypedKernel argument.
+ template <typename... ParameterTs>
+ Stream &thenLaunch(BlockDimensions BlockSize, GridDimensions GridSize,
+ const TypedKernel<ParameterTs...> &Kernel,
+ const ParameterTs &... Arguments) {
+ auto ArgumentArray =
+ make_kernel_argument_pack<ParameterTs...>(Arguments...);
+ setError(PlatformExecutor->launch(ThePlatformStream.get(), BlockSize,
+ GridSize, Kernel, ArgumentArray));
+ return *this;
+ }
+
+ /// Entrain onto the stream a memcpy of a given number of elements from a
+ /// device source to a host destination.
+ ///
+ /// HostDst must be a pointer to host memory allocated by
+ /// StreamExecutor::allocateHostMemory or otherwise allocated and then
+ /// registered with StreamExecutor::registerHostMemory.
+ template <typename T>
+ Stream &thenMemcpyD2H(const GlobalDeviceMemory<T> &DeviceSrc,
+ llvm::MutableArrayRef<T> HostDst, size_t ElementCount) {
+ if (ElementCount > DeviceSrc.getElementCount())
+ setError("copying too many elements, " + llvm::Twine(ElementCount) +
+ ", from device memory array of size " +
+ llvm::Twine(DeviceSrc.getElementCount()));
+ else if (ElementCount > HostDst.size())
+ setError("copying too many elements, " + llvm::Twine(ElementCount) +
+ ", to host array of size " + llvm::Twine(HostDst.size()));
+ else
+ setError(PlatformExecutor->memcpyD2H(ThePlatformStream.get(), DeviceSrc,
+ HostDst.data(),
+ ElementCount * sizeof(T)));
+ return *this;
+ }
+
+ /// Same as thenMemcpyD2H above, but copies the entire source to the
+ /// destination.
+ template <typename T>
+ Stream &thenMemcpyD2H(const GlobalDeviceMemory<T> &DeviceSrc,
+ llvm::MutableArrayRef<T> HostDst) {
+ return thenMemcpyD2H(DeviceSrc, HostDst, DeviceSrc.getElementCount());
+ }
+
+ /// Entrain onto the stream a memcpy of a given number of elements from a host
+ /// source to a device destination.
+ ///
+ /// HostSrc must be a pointer to host memory allocated by
+ /// StreamExecutor::allocateHostMemory or otherwise allocated and then
+ /// registered with StreamExecutor::registerHostMemory.
+ template <typename T>
+ Stream &thenMemcpyH2D(llvm::ArrayRef<T> HostSrc,
+ GlobalDeviceMemory<T> *DeviceDst, size_t ElementCount) {
+ if (ElementCount > HostSrc.size())
+ setError("copying too many elements, " + llvm::Twine(ElementCount) +
+ ", from host array of size " + llvm::Twine(HostSrc.size()));
+ else if (ElementCount > DeviceDst->getElementCount())
+ setError("copying too many elements, " + llvm::Twine(ElementCount) +
+ ", to device memory array of size " +
+ llvm::Twine(DeviceDst->getElementCount()));
+ else
+ setError(PlatformExecutor->memcpyH2D(ThePlatformStream.get(),
+ HostSrc.data(), DeviceDst,
+ ElementCount * sizeof(T)));
+ return *this;
+ }
+
+ /// Same as thenMemcpyH2D above, but copies the entire source to the
+ /// destination.
+ template <typename T>
+ Stream &thenMemcpyH2D(llvm::ArrayRef<T> HostSrc,
+ GlobalDeviceMemory<T> *DeviceDst) {
+ return thenMemcpyH2D(HostSrc, DeviceDst, HostSrc.size());
+ }
+
+ /// Entrain onto the stream a memcpy of a given number of elements from a
+ /// device source to a device destination.
+ template <typename T>
+ Stream &thenMemcpyD2D(const GlobalDeviceMemory<T> &DeviceSrc,
+ GlobalDeviceMemory<T> *DeviceDst, size_t ElementCount) {
+ if (ElementCount > DeviceSrc.getElementCount())
+ setError("copying too many elements, " + llvm::Twine(ElementCount) +
+ ", from device memory array of size " +
+ llvm::Twine(DeviceSrc.getElementCount()));
+ else if (ElementCount > DeviceDst->getElementCount())
+ setError("copying too many elements, " + llvm::Twine(ElementCount) +
+ ", to device memory array of size " +
+ llvm::Twine(DeviceDst->getElementCount()));
+ else
+ setError(PlatformExecutor->memcpyD2D(ThePlatformStream.get(), DeviceSrc,
+ DeviceDst,
+ ElementCount * sizeof(T)));
+ return *this;
+ }
+
+ /// Same as thenMemcpyD2D above, but copies the entire source to the
+ /// destination.
+ template <typename T>
+ Stream &thenMemcpyD2D(const GlobalDeviceMemory<T> &DeviceSrc,
+ GlobalDeviceMemory<T> *DeviceDst) {
+ return thenMemcpyD2D(DeviceSrc, DeviceDst, DeviceSrc.getElementCount());
+ }
+
+ /// Blocks the host code, waiting for the operations entrained on the stream
+ /// (enqueued up to this point in program execution) to complete.
+ ///
+ /// Returns true if there are no errors on the stream.
+ bool blockHostUntilDone() {
+ Error E = PlatformExecutor->blockHostUntilDone(ThePlatformStream.get());
+ bool returnValue = static_cast<bool>(E);
+ setError(std::move(E));
+ return returnValue;
+ }
+
+private:
+ /// Sets the error state from an Error object.
+ ///
+ /// Does not overwrite the error if it is already set.
+ void setError(Error &&E) {
+ if (E) {
+ llvm::sys::ScopedWriter WriterLock(ErrorMessageMutex);
+ if (!ErrorMessage)
+ ErrorMessage = consumeAndGetMessage(std::move(E));
+ }
+ }
+
+ /// Sets the error state from an error message.
+ ///
+ /// Does not overwrite the error if it is already set.
+ void setError(llvm::Twine Message) {
+ llvm::sys::ScopedWriter WriterLock(ErrorMessageMutex);
+ if (!ErrorMessage)
+ ErrorMessage = Message.str();
+ }
+
+ /// The PlatformStreamExecutor that supports the operations of this stream.
+ PlatformStreamExecutor *PlatformExecutor;
+
+ /// The platform-specific stream handle for this instance.
+ std::unique_ptr<PlatformStreamHandle> ThePlatformStream;
+
+ /// Mutex that guards the error state flags.
+ ///
+ /// Mutable so that it can be obtained via const reader lock.
+ mutable llvm::sys::RWMutex ErrorMessageMutex;
+
+ /// First error message for an operation in this stream or empty if there have
+ /// been no errors.
+ llvm::Optional<std::string> ErrorMessage;
+
+ Stream(const Stream &) = delete;
+ void operator=(const Stream &) = delete;
+};
+
+} // namespace streamexecutor
+
+#endif // STREAMEXECUTOR_STREAM_H
#ifndef STREAMEXECUTOR_STREAMEXECUTOR_H
#define STREAMEXECUTOR_STREAMEXECUTOR_H
+#include "streamexecutor/KernelSpec.h"
#include "streamexecutor/Utils/Error.h"
namespace streamexecutor {
class KernelInterface;
+class PlatformStreamExecutor;
+class Stream;
class StreamExecutor {
public:
+ explicit StreamExecutor(PlatformStreamExecutor *PlatformExecutor);
+ virtual ~StreamExecutor();
+
/// Gets the kernel implementation for the underlying platform.
virtual Expected<std::unique_ptr<KernelInterface>>
getKernelImplementation(const MultiKernelLoaderSpec &Spec) {
return nullptr;
}
- // TODO(jhen): Add other methods.
+ Expected<std::unique_ptr<Stream>> createStream();
+
+private:
+ PlatformStreamExecutor *PlatformExecutor;
};
} // namespace streamexecutor
streamexecutor
$<TARGET_OBJECTS:utils>
Kernel.cpp
- KernelSpec.cpp)
+ KernelSpec.cpp
+ PackedKernelArgumentArray.cpp
+ PlatformInterfaces.cpp
+ Stream.cpp
+ StreamExecutor.cpp)
target_link_libraries(streamexecutor ${llvm_libs})
if(STREAM_EXECUTOR_UNIT_TESTS)
//===----------------------------------------------------------------------===//
#include "streamexecutor/Kernel.h"
-#include "streamexecutor/Interfaces.h"
+#include "streamexecutor/PlatformInterfaces.h"
#include "streamexecutor/StreamExecutor.h"
#include "llvm/DebugInfo/Symbolize/Symbolize.h"
--- /dev/null
+//===-- PackedKernelArgumentArray.cpp - Packed argument array impl --------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Implementation details for classes from PackedKernelArgumentArray.h.
+///
+//===----------------------------------------------------------------------===//
+
+#include "streamexecutor/PackedKernelArgumentArray.h"
+
+namespace streamexecutor {
+
+PackedKernelArgumentArrayBase::~PackedKernelArgumentArrayBase() = default;
+
+} // namespace streamexecutor
--- /dev/null
+//===-- PlatformInterfaces.cpp - Platform interface implementations -------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Implementation file for PlatformInterfaces.h.
+///
+//===----------------------------------------------------------------------===//
+
+#include "streamexecutor/PlatformInterfaces.h"
+
+namespace streamexecutor {
+
+PlatformStreamHandle::~PlatformStreamHandle() = default;
+
+PlatformStreamExecutor::~PlatformStreamExecutor() = default;
+
+} // namespace streamexecutor
--- /dev/null
+//===-- Stream.cpp - General stream implementation ------------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file contains the implementation details for a general stream object.
+///
+//===----------------------------------------------------------------------===//
+
+#include "streamexecutor/Stream.h"
+
+namespace streamexecutor {
+
+Stream::Stream(std::unique_ptr<PlatformStreamHandle> PStream)
+ : PlatformExecutor(PStream->getExecutor()),
+ ThePlatformStream(std::move(PStream)) {}
+
+Stream::~Stream() = default;
+
+} // namespace streamexecutor
--- /dev/null
+//===-- StreamExecutor.cpp - StreamExecutor implementation ----------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Implementation of StreamExecutor class internals.
+///
+//===----------------------------------------------------------------------===//
+
+#include "streamexecutor/StreamExecutor.h"
+
+#include <cassert>
+
+#include "streamexecutor/PlatformInterfaces.h"
+#include "streamexecutor/Stream.h"
+
+#include "llvm/ADT/STLExtras.h"
+
+namespace streamexecutor {
+
+StreamExecutor::StreamExecutor(PlatformStreamExecutor *PlatformExecutor)
+ : PlatformExecutor(PlatformExecutor) {}
+
+StreamExecutor::~StreamExecutor() = default;
+
+Expected<std::unique_ptr<Stream>> StreamExecutor::createStream() {
+ Expected<std::unique_ptr<PlatformStreamHandle>> MaybePlatformStream =
+ PlatformExecutor->createStream();
+ if (!MaybePlatformStream) {
+ return MaybePlatformStream.takeError();
+ }
+ assert((*MaybePlatformStream)->getExecutor() == PlatformExecutor &&
+ "an executor created a stream with a different stored executor");
+ return llvm::make_unique<Stream>(std::move(*MaybePlatformStream));
+}
+
+} // namespace streamexecutor
PackedKernelArgumentArrayTest.cpp)
target_link_libraries(
packed_kernel_argument_array_test
+ streamexecutor
${llvm_libs}
${GTEST_BOTH_LIBRARIES}
${CMAKE_THREAD_LIBS_INIT})
add_test(PackedKernelArgumentArrayTest packed_kernel_argument_array_test)
+
+add_executable(
+ stream_test
+ StreamTest.cpp)
+target_link_libraries(
+ stream_test
+ streamexecutor
+ ${llvm_libs}
+ ${GTEST_BOTH_LIBRARIES}
+ ${CMAKE_THREAD_LIBS_INIT})
+add_test(StreamTest stream_test)
#include <cassert>
-#include "streamexecutor/Interfaces.h"
#include "streamexecutor/Kernel.h"
#include "streamexecutor/KernelSpec.h"
+#include "streamexecutor/PlatformInterfaces.h"
#include "streamexecutor/StreamExecutor.h"
#include "llvm/ADT/STLExtras.h"
class MockStreamExecutor : public se::StreamExecutor {
public:
MockStreamExecutor()
- : Unique(llvm::make_unique<se::KernelInterface>()), Raw(Unique.get()) {}
+ : se::StreamExecutor(nullptr),
+ Unique(llvm::make_unique<se::KernelInterface>()), Raw(Unique.get()) {}
// Moves the unique pointer into the returned se::Expected instance.
//
--- /dev/null
+//===-- StreamTest.cpp - Tests for Stream ---------------------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file contains the unit tests for Stream code.
+///
+//===----------------------------------------------------------------------===//
+
+#include <cstring>
+
+#include "streamexecutor/Kernel.h"
+#include "streamexecutor/KernelSpec.h"
+#include "streamexecutor/PlatformInterfaces.h"
+#include "streamexecutor/Stream.h"
+#include "streamexecutor/StreamExecutor.h"
+
+#include "gtest/gtest.h"
+
+namespace {
+
+namespace se = ::streamexecutor;
+
+/// Mock PlatformStreamExecutor that performs asynchronous memcpy operations by
+/// ignoring the stream argument and calling std::memcpy on device memory
+/// handles.
+class MockPlatformStreamExecutor : public se::PlatformStreamExecutor {
+public:
+ ~MockPlatformStreamExecutor() override {}
+
+ std::string getName() const override { return "MockPlatformStreamExecutor"; }
+
+ se::Expected<std::unique_ptr<se::PlatformStreamHandle>>
+ createStream() override {
+ return nullptr;
+ }
+
+ se::Error memcpyD2H(se::PlatformStreamHandle *,
+ const se::GlobalDeviceMemoryBase &DeviceSrc,
+ void *HostDst, size_t ByteCount) override {
+ std::memcpy(HostDst, DeviceSrc.getHandle(), ByteCount);
+ return se::Error::success();
+ }
+
+ se::Error memcpyH2D(se::PlatformStreamHandle *, const void *HostSrc,
+ se::GlobalDeviceMemoryBase *DeviceDst,
+ size_t ByteCount) override {
+ std::memcpy(const_cast<void *>(DeviceDst->getHandle()), HostSrc, ByteCount);
+ return se::Error::success();
+ }
+
+ se::Error memcpyD2D(se::PlatformStreamHandle *,
+ const se::GlobalDeviceMemoryBase &DeviceSrc,
+ se::GlobalDeviceMemoryBase *DeviceDst,
+ size_t ByteCount) override {
+ std::memcpy(const_cast<void *>(DeviceDst->getHandle()),
+ DeviceSrc.getHandle(), ByteCount);
+ return se::Error::success();
+ }
+};
+
+/// Test fixture to hold objects used by tests.
+class StreamTest : public ::testing::Test {
+public:
+ StreamTest()
+ : DeviceA(se::GlobalDeviceMemory<int>::makeFromElementCount(HostA, 10)),
+ DeviceB(se::GlobalDeviceMemory<int>::makeFromElementCount(HostB, 10)),
+ Stream(llvm::make_unique<se::PlatformStreamHandle>(&PlatformExecutor)) {
+ }
+
+protected:
+ // Device memory is backed by host arrays.
+ int HostA[10];
+ se::GlobalDeviceMemory<int> DeviceA;
+ int HostB[10];
+ se::GlobalDeviceMemory<int> DeviceB;
+
+ // Host memory to be used as actual host memory.
+ int Host[10];
+
+ MockPlatformStreamExecutor PlatformExecutor;
+ se::Stream Stream;
+};
+
+TEST_F(StreamTest, MemcpyCorrectSize) {
+ Stream.thenMemcpyH2D(llvm::ArrayRef<int>(Host), &DeviceA);
+ EXPECT_TRUE(Stream.isOK());
+
+ Stream.thenMemcpyD2H(DeviceA, llvm::MutableArrayRef<int>(Host));
+ EXPECT_TRUE(Stream.isOK());
+
+ Stream.thenMemcpyD2D(DeviceA, &DeviceB);
+ EXPECT_TRUE(Stream.isOK());
+}
+
+TEST_F(StreamTest, MemcpyH2DTooManyElements) {
+ Stream.thenMemcpyH2D(llvm::ArrayRef<int>(Host), &DeviceA, 20);
+ EXPECT_FALSE(Stream.isOK());
+}
+
+TEST_F(StreamTest, MemcpyD2HTooManyElements) {
+ Stream.thenMemcpyD2H(DeviceA, llvm::MutableArrayRef<int>(Host), 20);
+ EXPECT_FALSE(Stream.isOK());
+}
+
+TEST_F(StreamTest, MemcpyD2DTooManyElements) {
+ Stream.thenMemcpyD2D(DeviceA, &DeviceB, 20);
+ EXPECT_FALSE(Stream.isOK());
+}
+
+} // namespace