namespace perftools {
namespace gputools {
+/* static */ mutex MultiPlatformManager::platforms_mutex_{LINKER_INITIALIZED};
+
+/* static */ port::StatusOr<Platform*> MultiPlatformManager::LookupByNameLocked(
+ const string& target) {
+ PlatformMap* platform_map = GetPlatformMap();
+ auto it = platform_map->find(port::Lowercase(target));
+ if (it == platform_map->end()) {
+ return port::Status(
+ port::error::NOT_FOUND,
+ "could not find registered platform with name: \"" + target + "\"");
+ }
+ return it->second;
+}
+
+/* static */ port::StatusOr<Platform*> MultiPlatformManager::LookupByIdLocked(
+ const Platform::Id& id) {
+ PlatformIdMap* platform_map = GetPlatformByIdMap();
+ auto it = platform_map->find(id);
+ if (it == platform_map->end()) {
+ return port::Status(
+ port::error::NOT_FOUND,
+ port::Printf("could not find registered platform with id: 0x%p", id));
+ }
+ return it->second;
+}
+
/* static */ port::Status MultiPlatformManager::RegisterPlatform(
std::unique_ptr<Platform> platform) {
CHECK(platform != nullptr);
string key = port::Lowercase(platform->Name());
- mutex_lock lock(GetPlatformsMutex());
+ mutex_lock lock(platforms_mutex_);
if (GetPlatformMap()->find(key) != GetPlatformMap()->end()) {
return port::Status(port::error::INTERNAL,
"platform is already registered with name: \"" +
/* static */ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithName(
const string& target) {
- tf_shared_lock lock(GetPlatformsMutex());
- auto it = GetPlatformMap()->find(port::Lowercase(target));
+ mutex_lock lock(platforms_mutex_);
- if (it == GetPlatformMap()->end()) {
- return port::Status(
- port::error::NOT_FOUND,
- "could not find registered platform with name: \"" + target + "\"");
+ SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
+ if (!platform->Initialized()) {
+ SE_RETURN_IF_ERROR(platform->Initialize({}));
}
- return it->second;
+ return platform;
}
/* static */ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithId(
const Platform::Id& id) {
- tf_shared_lock lock(GetPlatformsMutex());
- auto it = GetPlatformByIdMap()->find(id);
- if (it == GetPlatformByIdMap()->end()) {
+ mutex_lock lock(platforms_mutex_);
+
+ SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
+ if (!platform->Initialized()) {
+ SE_RETURN_IF_ERROR(platform->Initialize({}));
+ }
+
+ return platform;
+}
+
+/* static */ port::StatusOr<Platform*>
+MultiPlatformManager::InitializePlatformWithName(
+ const string& target, const std::map<string, string>& options) {
+ mutex_lock lock(platforms_mutex_);
+
+ SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
+ if (platform->Initialized()) {
+ return port::Status(port::error::FAILED_PRECONDITION,
+ "platform \"" + target + "\" is already initialized");
+ }
+
+ SE_RETURN_IF_ERROR(platform->Initialize(options));
+
+ return platform;
+}
+
+/* static */ port::StatusOr<Platform*>
+MultiPlatformManager::InitializePlatformWithId(
+ const Platform::Id& id, const std::map<string, string>& options) {
+ mutex_lock lock(platforms_mutex_);
+
+ SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
+ if (platform->Initialized()) {
return port::Status(
- port::error::NOT_FOUND,
- port::Printf("could not find registered platform with id: 0x%p", id));
+ port::error::FAILED_PRECONDITION,
+ port::Printf("platform with id 0x%p is already initialized", id));
}
- return it->second;
+ SE_RETURN_IF_ERROR(platform->Initialize(options));
+
+ return platform;
}
/* static */ void MultiPlatformManager::ClearPlatformRegistry() {
- mutex_lock lock(GetPlatformsMutex());
+ mutex_lock lock(platforms_mutex_);
GetPlatformMap()->clear();
GetPlatformByIdMap()->clear();
}
#include <functional>
#include <map>
#include <memory>
-#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
namespace perftools {
namespace gputools {
// already registered. The associated listener, if not null, will be used to
// trace events for ALL executors for that platform.
// Takes ownership of listener.
- static port::Status RegisterPlatform(std::unique_ptr<Platform> platform);
+ static port::Status RegisterPlatform(std::unique_ptr<Platform> platform)
+ LOCKS_EXCLUDED(platforms_mutex_);
- // Retrieves the platform registered with the given platform name; e.g.
- // "CUDA", "OpenCL", ...
+ // Retrieves the platform registered with the given platform name (e.g.
+ // "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the
+ // Platform's Id() method).
+ //
+ // If the platform has not already been initialized, it will be initialized
+ // with a default set of parameters.
//
// If the requested platform is not registered, an error status is returned.
// Ownership of the platform is NOT transferred to the caller --
// the MultiPlatformManager owns the platforms in a singleton-like fashion.
- static port::StatusOr<Platform*> PlatformWithName(const string& target);
-
- // Retrieves the platform registered with the given platform ID, which
- // is an opaque (but comparable) value.
+ static port::StatusOr<Platform*> PlatformWithName(const string& target)
+ LOCKS_EXCLUDED(platforms_mutex_);
+ static port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id)
+ LOCKS_EXCLUDED(platforms_mutex_);
+
+ // Retrieves the platform registered with the given platform name (e.g.
+ // "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the
+ // Platform's Id() method).
+ //
+ // The platform will be initialized with the given options. If the platform
+ // was already initialized, an error will be returned.
//
// If the requested platform is not registered, an error status is returned.
// Ownership of the platform is NOT transferred to the caller --
// the MultiPlatformManager owns the platforms in a singleton-like fashion.
- static port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id);
+ static port::StatusOr<Platform*> InitializePlatformWithName(
+ const string& target, const std::map<string, string>& options)
+ LOCKS_EXCLUDED(platforms_mutex_);
+ static port::StatusOr<Platform*> InitializePlatformWithId(
+ const Platform::Id& id, const std::map<string, string>& options)
+ LOCKS_EXCLUDED(platforms_mutex_);
// Clears the set of registered platforms, primarily used for testing.
- static void ClearPlatformRegistry();
+ static void ClearPlatformRegistry() LOCKS_EXCLUDED(platforms_mutex_);
// Although the MultiPlatformManager "owns" its platforms, it holds them as
// undecorated pointers to prevent races during program exit (between this
// Provides access to the available set of platforms under a lock.
static port::Status WithPlatforms(
- std::function<port::Status(PlatformMap*)> callback) {
- mutex_lock lock(GetPlatformsMutex());
+ std::function<port::Status(PlatformMap*)> callback)
+ LOCKS_EXCLUDED(platforms_mutex_) {
+ mutex_lock lock(platforms_mutex_);
return callback(GetPlatformMap());
}
private:
- // mutex that guards the platform map.
- static mutex& GetPlatformsMutex() {
- static mutex* platforms_mutex = new mutex;
- return *platforms_mutex;
- }
+ using PlatformIdMap = std::map<Platform::Id, Platform*>;
+
+ static mutex platforms_mutex_;
// TODO(b/22689637): Clean up these two maps; make sure they coexist nicely.
// TODO(b/22689637): Move this (whatever the final/"official" map is) to
// Holds a Platform::Id-to-object mapping.
// Unlike platforms_ above, this map does not own its contents.
- static std::map<Platform::Id, Platform*>* GetPlatformByIdMap() {
- using PlatformIdMap = std::map<Platform::Id, Platform*>;
+ static PlatformIdMap* GetPlatformByIdMap() {
static PlatformIdMap* instance = new PlatformIdMap;
return instance;
}
+ // Looks up the platform object with the given name. Assumes the Platforms
+ // mutex is held.
+ static port::StatusOr<Platform*> LookupByNameLocked(const string& target)
+ EXCLUSIVE_LOCKS_REQUIRED(platforms_mutex_);
+
+ // Looks up the platform object with the given id. Assumes the Platforms
+ // mutex is held.
+ static port::StatusOr<Platform*> LookupByIdLocked(const Platform::Id& id)
+ EXCLUSIVE_LOCKS_REQUIRED(platforms_mutex_);
+
SE_DISALLOW_COPY_AND_ASSIGN(MultiPlatformManager);
};
Platform::~Platform() {}
+bool Platform::Initialized() const { return true; }
+
+port::Status Platform::Initialize(
+ const std::map<string, string> &platform_options) {
+ if (!platform_options.empty()) {
+ return port::Status(port::error::UNIMPLEMENTED,
+ "this platform does not support custom initialization");
+ }
+ return port::Status::OK();
+}
+
port::Status Platform::ForceExecutorShutdown() {
return port::Status(port::error::UNIMPLEMENTED,
"executor shutdown is not supported on this platform");
// Returns a key uniquely identifying this platform.
virtual Id id() const = 0;
+ // Name of this platform.
+ virtual const string& Name() const = 0;
+
// Returns the number of devices accessible on this platform.
//
// Note that, though these devices are visible, if there is only one userspace
// device, a call to ExecutorForDevice may return an error status.
virtual int VisibleDeviceCount() const = 0;
- // Name of this platform.
- virtual const string& Name() const = 0;
+ // Returns true iff the platform has been initialized.
+ virtual bool Initialized() const;
+
+ // Initializes the platform with a custom set of options. The platform must be
+ // initialized before obtaining StreamExecutor objects. The interpretation of
+ // the platform_options argument is implementation specific. This method may
+ // return an error if unrecognized options are provided. If using
+ // MultiPlatformManager, this method will be called automatically by
+ // InitializePlatformWithId/InitializePlatformWithName.
+ virtual port::Status Initialize(
+ const std::map<string, string>& platform_options);
// Returns a device with the given ordinal on this platform with a default
// plugin configuration or, if none can be found with the given ordinal or
// This is only useful on platforms which bind a device to a single process
// that has obtained the device context. May return UNIMPLEMENTED on platforms
// that have no reason to destroy device contexts.
+ //
+ // The platform must be reinitialized after this is called.
virtual port::Status ForceExecutorShutdown();
// Registers a TraceListener to listen to all StreamExecutors for this