[SE] Initial perftools::gputools::Platform initialization support
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 8 Mar 2018 19:38:46 +0000 (11:38 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 8 Mar 2018 19:43:00 +0000 (11:43 -0800)
Adds initialization methods to Platform.  Some platforms require initialization.
Those that do not have trivial implementations of these methods.

PiperOrigin-RevId: 188363315

tensorflow/stream_executor/multi_platform_manager.cc
tensorflow/stream_executor/multi_platform_manager.h
tensorflow/stream_executor/platform.cc
tensorflow/stream_executor/platform.h

index f23224a..f9f3737 100644 (file)
@@ -23,11 +23,37 @@ limitations under the License.
 namespace perftools {
 namespace gputools {
 
+/* static */ mutex MultiPlatformManager::platforms_mutex_{LINKER_INITIALIZED};
+
+/* static */ port::StatusOr<Platform*> MultiPlatformManager::LookupByNameLocked(
+    const string& target) {
+  PlatformMap* platform_map = GetPlatformMap();
+  auto it = platform_map->find(port::Lowercase(target));
+  if (it == platform_map->end()) {
+    return port::Status(
+        port::error::NOT_FOUND,
+        "could not find registered platform with name: \"" + target + "\"");
+  }
+  return it->second;
+}
+
+/* static */ port::StatusOr<Platform*> MultiPlatformManager::LookupByIdLocked(
+    const Platform::Id& id) {
+  PlatformIdMap* platform_map = GetPlatformByIdMap();
+  auto it = platform_map->find(id);
+  if (it == platform_map->end()) {
+    return port::Status(
+        port::error::NOT_FOUND,
+        port::Printf("could not find registered platform with id: 0x%p", id));
+  }
+  return it->second;
+}
+
 /* static */ port::Status MultiPlatformManager::RegisterPlatform(
     std::unique_ptr<Platform> platform) {
   CHECK(platform != nullptr);
   string key = port::Lowercase(platform->Name());
-  mutex_lock lock(GetPlatformsMutex());
+  mutex_lock lock(platforms_mutex_);
   if (GetPlatformMap()->find(key) != GetPlatformMap()->end()) {
     return port::Status(port::error::INTERNAL,
                         "platform is already registered with name: \"" +
@@ -45,33 +71,63 @@ namespace gputools {
 
 /* static */ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithName(
     const string& target) {
-  tf_shared_lock lock(GetPlatformsMutex());
-  auto it = GetPlatformMap()->find(port::Lowercase(target));
+  mutex_lock lock(platforms_mutex_);
 
-  if (it == GetPlatformMap()->end()) {
-    return port::Status(
-        port::error::NOT_FOUND,
-        "could not find registered platform with name: \"" + target + "\"");
+  SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
+  if (!platform->Initialized()) {
+    SE_RETURN_IF_ERROR(platform->Initialize({}));
   }
 
-  return it->second;
+  return platform;
 }
 
 /* static */ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithId(
     const Platform::Id& id) {
-  tf_shared_lock lock(GetPlatformsMutex());
-  auto it = GetPlatformByIdMap()->find(id);
-  if (it == GetPlatformByIdMap()->end()) {
+  mutex_lock lock(platforms_mutex_);
+
+  SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
+  if (!platform->Initialized()) {
+    SE_RETURN_IF_ERROR(platform->Initialize({}));
+  }
+
+  return platform;
+}
+
+/* static */ port::StatusOr<Platform*>
+MultiPlatformManager::InitializePlatformWithName(
+    const string& target, const std::map<string, string>& options) {
+  mutex_lock lock(platforms_mutex_);
+
+  SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
+  if (platform->Initialized()) {
+    return port::Status(port::error::FAILED_PRECONDITION,
+                        "platform \"" + target + "\" is already initialized");
+  }
+
+  SE_RETURN_IF_ERROR(platform->Initialize(options));
+
+  return platform;
+}
+
+/* static */ port::StatusOr<Platform*>
+MultiPlatformManager::InitializePlatformWithId(
+    const Platform::Id& id, const std::map<string, string>& options) {
+  mutex_lock lock(platforms_mutex_);
+
+  SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
+  if (platform->Initialized()) {
     return port::Status(
-        port::error::NOT_FOUND,
-        port::Printf("could not find registered platform with id: 0x%p", id));
+        port::error::FAILED_PRECONDITION,
+        port::Printf("platform with id 0x%p is already initialized", id));
   }
 
-  return it->second;
+  SE_RETURN_IF_ERROR(platform->Initialize(options));
+
+  return platform;
 }
 
 /* static */ void MultiPlatformManager::ClearPlatformRegistry() {
-  mutex_lock lock(GetPlatformsMutex());
+  mutex_lock lock(platforms_mutex_);
   GetPlatformMap()->clear();
   GetPlatformByIdMap()->clear();
 }
index ea6155b..438653e 100644 (file)
@@ -67,13 +67,13 @@ limitations under the License.
 #include <functional>
 #include <map>
 #include <memory>
-#include "tensorflow/stream_executor/platform/port.h"
 
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/platform.h"
 #include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
 
 namespace perftools {
 namespace gputools {
@@ -85,26 +85,43 @@ class MultiPlatformManager {
   // already registered. The associated listener, if not null, will be used to
   // trace events for ALL executors for that platform.
   // Takes ownership of listener.
-  static port::Status RegisterPlatform(std::unique_ptr<Platform> platform);
+  static port::Status RegisterPlatform(std::unique_ptr<Platform> platform)
+      LOCKS_EXCLUDED(platforms_mutex_);
 
-  // Retrieves the platform registered with the given platform name; e.g.
-  // "CUDA", "OpenCL", ...
+  // Retrieves the platform registered with the given platform name (e.g.
+  // "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the
+  // Platform's Id() method).
+  //
+  // If the platform has not already been initialized, it will be initialized
+  // with a default set of parameters.
   //
   // If the requested platform is not registered, an error status is returned.
   // Ownership of the platform is NOT transferred to the caller --
   // the MultiPlatformManager owns the platforms in a singleton-like fashion.
-  static port::StatusOr<Platform*> PlatformWithName(const string& target);
-
-  // Retrieves the platform registered with the given platform ID, which
-  // is an opaque (but comparable) value.
+  static port::StatusOr<Platform*> PlatformWithName(const string& target)
+      LOCKS_EXCLUDED(platforms_mutex_);
+  static port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id)
+      LOCKS_EXCLUDED(platforms_mutex_);
+
+  // Retrieves the platform registered with the given platform name (e.g.
+  // "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the
+  // Platform's Id() method).
+  //
+  // The platform will be initialized with the given options. If the platform
+  // was already initialized, an error will be returned.
   //
   // If the requested platform is not registered, an error status is returned.
   // Ownership of the platform is NOT transferred to the caller --
   // the MultiPlatformManager owns the platforms in a singleton-like fashion.
-  static port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id);
+  static port::StatusOr<Platform*> InitializePlatformWithName(
+      const string& target, const std::map<string, string>& options)
+      LOCKS_EXCLUDED(platforms_mutex_);
+  static port::StatusOr<Platform*> InitializePlatformWithId(
+      const Platform::Id& id, const std::map<string, string>& options)
+      LOCKS_EXCLUDED(platforms_mutex_);
 
   // Clears the set of registered platforms, primarily used for testing.
-  static void ClearPlatformRegistry();
+  static void ClearPlatformRegistry() LOCKS_EXCLUDED(platforms_mutex_);
 
   // Although the MultiPlatformManager "owns" its platforms, it holds them as
   // undecorated pointers to prevent races during program exit (between this
@@ -122,17 +139,16 @@ class MultiPlatformManager {
 
   // Provides access to the available set of platforms under a lock.
   static port::Status WithPlatforms(
-      std::function<port::Status(PlatformMap*)> callback) {
-    mutex_lock lock(GetPlatformsMutex());
+      std::function<port::Status(PlatformMap*)> callback)
+      LOCKS_EXCLUDED(platforms_mutex_) {
+    mutex_lock lock(platforms_mutex_);
     return callback(GetPlatformMap());
   }
 
  private:
-  // mutex that guards the platform map.
-  static mutex& GetPlatformsMutex() {
-    static mutex* platforms_mutex = new mutex;
-    return *platforms_mutex;
-  }
+  using PlatformIdMap = std::map<Platform::Id, Platform*>;
+
+  static mutex platforms_mutex_;
 
   // TODO(b/22689637): Clean up these two maps; make sure they coexist nicely.
   // TODO(b/22689637): Move this (whatever the final/"official" map is) to
@@ -147,12 +163,21 @@ class MultiPlatformManager {
 
   // Holds a Platform::Id-to-object mapping.
   // Unlike platforms_ above, this map does not own its contents.
-  static std::map<Platform::Id, Platform*>* GetPlatformByIdMap() {
-    using PlatformIdMap = std::map<Platform::Id, Platform*>;
+  static PlatformIdMap* GetPlatformByIdMap() {
     static PlatformIdMap* instance = new PlatformIdMap;
     return instance;
   }
 
+  // Looks up the platform object with the given name.  Assumes the Platforms
+  // mutex is held.
+  static port::StatusOr<Platform*> LookupByNameLocked(const string& target)
+      EXCLUSIVE_LOCKS_REQUIRED(platforms_mutex_);
+
+  // Looks up the platform object with the given id.  Assumes the Platforms
+  // mutex is held.
+  static port::StatusOr<Platform*> LookupByIdLocked(const Platform::Id& id)
+      EXCLUSIVE_LOCKS_REQUIRED(platforms_mutex_);
+
   SE_DISALLOW_COPY_AND_ASSIGN(MultiPlatformManager);
 };
 
index 93f08d0..4cdc22b 100644 (file)
@@ -85,6 +85,17 @@ StreamExecutorConfig::StreamExecutorConfig(int ordinal_in)
 
 Platform::~Platform() {}
 
+bool Platform::Initialized() const { return true; }
+
+port::Status Platform::Initialize(
+    const std::map<string, string> &platform_options) {
+  if (!platform_options.empty()) {
+    return port::Status(port::error::UNIMPLEMENTED,
+                        "this platform does not support custom initialization");
+  }
+  return port::Status::OK();
+}
+
 port::Status Platform::ForceExecutorShutdown() {
   return port::Status(port::error::UNIMPLEMENTED,
                       "executor shutdown is not supported on this platform");
index f0a0e60..54f8aa8 100644 (file)
@@ -111,6 +111,9 @@ class Platform {
   // Returns a key uniquely identifying this platform.
   virtual Id id() const = 0;
 
+  // Name of this platform.
+  virtual const string& Name() const = 0;
+
   // Returns the number of devices accessible on this platform.
   //
   // Note that, though these devices are visible, if there is only one userspace
@@ -118,8 +121,17 @@ class Platform {
   // device, a call to ExecutorForDevice may return an error status.
   virtual int VisibleDeviceCount() const = 0;
 
-  // Name of this platform.
-  virtual const string& Name() const = 0;
+  // Returns true iff the platform has been initialized.
+  virtual bool Initialized() const;
+
+  // Initializes the platform with a custom set of options. The platform must be
+  // initialized before obtaining StreamExecutor objects.  The interpretation of
+  // the platform_options argument is implementation specific.  This method may
+  // return an error if unrecognized options are provided.  If using
+  // MultiPlatformManager, this method will be called automatically by
+  // InitializePlatformWithId/InitializePlatformWithName.
+  virtual port::Status Initialize(
+      const std::map<string, string>& platform_options);
 
   // Returns a device with the given ordinal on this platform with a default
   // plugin configuration or, if none can be found with the given ordinal or
@@ -156,6 +168,8 @@ class Platform {
   // This is only useful on platforms which bind a device to a single process
   // that has obtained the device context. May return UNIMPLEMENTED on platforms
   // that have no reason to destroy device contexts.
+  //
+  // The platform must be reinitialized after this is called.
   virtual port::Status ForceExecutorShutdown();
 
   // Registers a TraceListener to listen to all StreamExecutors for this