[Scheduler] Implement the request stop feature
authorDongju Chae <dongju.chae@samsung.com>
Wed, 17 Jun 2020 06:39:49 +0000 (15:39 +0900)
committer송욱/On-Device Lab(SR)/Staff Engineer/삼성전자 <wook16.song@samsung.com>
Thu, 18 Jun 2020 01:55:15 +0000 (10:55 +0900)
This patch implements the request stop feature to terminate
all submitted requests. If a request is already running, it's
expected for runInput() to return -ECANCELED within a short time.

Signed-off-by: Dongju Chae <dongju.chae@samsung.com>
src/core/ne-host-input-service.cc
src/core/ne-scheduler.cc
src/core/ne-scheduler.h
src/core/npu/NPUdrvAPI.h
src/core/npu/NPUdrvAPI_emul.cc
src/core/utils/ne-utils.h

index bc0a4e3..4e24b1b 100644 (file)
@@ -132,7 +132,7 @@ HostInputService::invoke_buffer (const DriverAPI *api, const Model *model,
 
   /** run the inference with the input */
   ret = api->runInput (&input_config);
-  if (ret != 0)
+  if (ret != 0 && ret != -ECANCELED)
     logerr (TAG, "Failed to run the NPU inference: %d\n", ret);
 
 handle_callback:
@@ -200,7 +200,7 @@ HostInputService::invoke_segt (const DriverAPI *api, const Model *model,
 
   /** run the inference with the input */
   ret = api->runInput (&input_config);
-  if (ret != 0)
+  if (ret != 0 && ret != -ECANCELED)
     logerr (TAG, "Failed to run the NPU inference: %d\n", ret);
 
 handle_callback:
index fb7454c..648cac7 100644 (file)
@@ -21,8 +21,9 @@ std::atomic<uint32_t> Request::global_request_id_ (1);
 
 /** @brief constructor of request class */
 Request::Request (npu_input_opmode opmode)
-  : opmode_ (opmode), force_stop_ (false), model_ (nullptr),
-    buffer_ (nullptr), segt_ (nullptr), cb_ (nullptr)
+  : opmode_ (opmode), force_stop_ (false), stopped_ (false),
+    model_ (nullptr), buffer_ (nullptr), segt_ (nullptr),
+    cb_ (nullptr)
 {
   request_id_ = Request::global_request_id_.fetch_add(1);
 }
@@ -56,7 +57,32 @@ Scheduler::submitRequest (Request *req)
   int status = 0;
 
   if (opmode == NPUINPUT_STOP) {
-    /** TODO: stop all requests from this instance */
+    if (req->getForceStop ()) {
+      std::function <bool (Request *req)> functor =
+        [] (Request *req) -> bool {
+          bool can_remove = true;
+
+          /* remove a request if it's not scheduled */
+          if (InferenceEngine::stopRequest (req->getOpmode (), req->getID ()) != 0) {
+            /* In case of already-served requests, let's mark it as stopped */
+            req->setStopped ();
+            can_remove = false;
+          }
+
+          return can_remove;
+        };
+      request_map_.for_each (functor);
+
+      /* send the stop signal to the device driver */
+      status = api_->stop ();
+      if (status != 0)
+        return status;
+    }
+
+    /* wait until all requests are handled */
+    request_map_.wait_empty ();
+
+    delete req;
   } else {
     status = request_map_.insert (req->getID(), req);
     assert (status == 0); /** request ID is atomic value. So, should be successful */
@@ -90,7 +116,7 @@ Scheduler::handleCallback (Request *req)
 {
   outputCallback callback = req->getCallback();
 
-  if (callback != nullptr)
+  if (!req->isStopped () && callback != nullptr)
     callback ();
 
   /** the request instance is also deleted here */
index 8b2e872..ae970ea 100644 (file)
@@ -50,12 +50,16 @@ public:
   npu_input_opmode getOpmode () { return opmode_; }
   uint32_t getID () { return request_id_; }
 
+  void setStopped () { stopped_ = true; }
+  bool isStopped () { return stopped_; }
+
 private:
   static std::atomic<uint32_t> global_request_id_;
   uint32_t request_id_;     /**< request id */
 
   npu_input_opmode opmode_; /**< opmode of the request */
   bool force_stop_;         /**< indicates force stop */
+  bool stopped_;            /**< stopped request */
 
   const Model *model_;      /**< model of the request */
   Buffer *buffer_;          /**< buffer of the request */
index 0ab60b4..cfe4948 100644 (file)
@@ -89,6 +89,9 @@ class DriverAPI {
 
     /** @brief run inference with the input config */
     virtual int runInput (input_config_t *input) const { return -EPERM; }
+    /** @brief stop all requests. The stopped requests should be notified */
+    virtual int stop () const { return 0; }
+
     /** @brief register model config to the driver */
     virtual int registerModel (model_config_t *model) const { return -EPERM; }
     virtual int deregisterModel (unsigned long long id) const { return -EPERM; }
@@ -199,6 +202,8 @@ class TrinityEmulAPI : public DriverAPI {
     int munmap (void *addr, size_t size) const;
 
     int runInput (input_config_t *input) const;
+    int stop () const;
+
     int registerModel (model_config_t *model) const;
     int deregisterModel (unsigned long long id) const;
 
index d37904b..e922546 100644 (file)
@@ -352,6 +352,7 @@ int
 TrinityEmulAPI::runInput (input_config_t *input_config) const
 {
   int dbuf_fd;
+  int status = -EPERM;
 
   if (!initialized())
     return -EPERM;
@@ -380,7 +381,7 @@ TrinityEmulAPI::runInput (input_config_t *input_config) const
    * call NPU C-emulation codes (AIP/NPU_SystemService_Emulator)
    */
   if ((dev_type_ & DEVICETYPE_MASK) == DEVICETYPE_TRIV) {
-    run_triv_emul (addr_model + model->program_offset_addr, model->program_size,
+    status = run_triv_emul (addr_model + model->program_offset_addr, model->program_size,
         addr_input);
   } else if ((dev_type_ & DEVICETYPE_MASK) == DEVICETYPE_TRIV2) {
     if (input_config->num_segments <= 0)
@@ -403,10 +404,26 @@ TrinityEmulAPI::runInput (input_config_t *input_config) const
       segment_table[i] = static_cast<char *>(elem->getAddr ()) + offset;
     }
 
-    run_triv2_emul (addr_model + model->program_offset_addr, model->program_size,
+    status = run_triv2_emul (addr_model + model->program_offset_addr, model->program_size,
         segment_table, num_segs);
     delete [] segment_table;
   }
 
-  return 0;
+  return status;
+}
+
+/**
+ * @brief stop all inferences in this device
+ * @return 0 if no error. otherwise a negative errno
+ */
+int
+TrinityEmulAPI::stop () const
+{
+  if ((dev_type_ & DEVICETYPE_MASK) == DEVICETYPE_TRIV) {
+    return stop_triv_emul ();
+  } else if ((dev_type_ & DEVICETYPE_MASK) == DEVICETYPE_TRIV2) {
+    return stop_triv2_emul ();
+  }
+
+  return -EPERM;
 }
index c3036ed..7db934b 100644 (file)
@@ -19,6 +19,7 @@
 #include <mutex>
 #include <memory>
 #include <condition_variable>
+#include <functional>
 
 /********************************************************************
  * Logging utilities                                                *
@@ -99,7 +100,7 @@ class ThreadSafeMap
 {
   public:
     ThreadSafeMap () : num_entries_ (0) {}
-    ~ThreadSafeMap () {}
+    ~ThreadSafeMap () { clear(); }
 
     /** @brief find the target element */
     V * find (K key) {
@@ -151,6 +152,8 @@ class ThreadSafeMap
       it = map_.begin ();
       while (it != map_.end())
         it = map_.erase (it);
+
+      num_entries_ = 0;
     }
 
     /** @brief wait until all elements are removed */
@@ -159,6 +162,28 @@ class ThreadSafeMap
       cv_.wait (lock, [this]() { return num_entries_ == 0; });
     }
 
+    /**
+     * @brief apply the passed function for each entry.
+     * @note if the func returns true, it removes the entry
+     */
+    void for_each (std::function<bool (V *)> & func) {
+      typename std::map<K, std::unique_ptr<V>>::iterator it;
+      std::unique_lock<std::mutex> lock(m_);
+
+      it = map_.begin ();
+      while (it != map_.end()) {
+        if (func (it->second.get ())) {
+          it = map_.erase (it);
+          num_entries_--;
+        } else {
+          ++it;
+        }
+      }
+
+      if (num_entries_ == 0)
+        cv_.notify_all ();
+    }
+
   private:
     uint32_t num_entries_;                  /**< number of entries */
     std::map<K, std::unique_ptr<V>> map_;   /**< map internal instance */