[Converage] Increase code coverage for input services
[platform/adaptation/npu/trix-engine.git] / src / core / ne-host-input-service.cc
index ddc2178..d43a0b4 100644 (file)
 std::unique_ptr<HostInputService> HostInputService::instance_;
 std::once_flag HostInputService::once_flag_;
 
+/**
+ * @brief get singleton instance
+ */
 HostInputService &
-HostInputService::getInstance ()
-{
-  call_once (once_flag_, []() {
-    instance_.reset (new HostInputService);
-  });
+HostInputService::getInstance () {
+  call_once (once_flag_, []() { instance_.reset (new HostInputService); });
   return *(instance_.get ());
 }
 
+/**
+ * @brief submit the request to the thread pool
+ * @param[in] api device driver api
+ * @param[in] req request instance
+ * @param[in] callback output callback
+ * @return 0 if no error, otherwise a negative errno
+ */
 int
-HostInputService::submit (const DriverAPI *api, int id,
-    const Model *model, HWmem *data, outputCallback callback)
-{
-  if (api == nullptr)
+HostInputService::submit (const DriverAPI *api, const Request *req,
+                          outputCallback callback) {
+  if (api == nullptr || req == nullptr) {
+    logerr (TAG, "Invalid arguments\n");
     return -EINVAL;
+  }
 
-  if (dynamic_cast<Buffer *> (data)) {
-    /* empty model is possible */
-    return submit_buffer (api, id, model, dynamic_cast<Buffer *> (data), callback);
-  } else if (dynamic_cast<SegmentTable *> (data)) {
-    if (model == nullptr)
-      return -EINVAL;
-    return submit_segt (api, id, model, dynamic_cast<SegmentTable *> (data), callback);
-  } else {
+  if (req->getOpmode () != NPUINPUT_HOST) {
+    logerr (TAG, "Unmatched opmode\n");
     return -EINVAL;
   }
-}
 
-/**
- * @brief submit the request to the thread pool
- * @param[in] api the driver API
- * @param[in] id the request id
- * @param[in] model the target model
- * @param[in] buffer the target buffer
- * @param[in] callback output callback
- * @return task id if no error, otherwise a negative errno.
- */
-int
-HostInputService::submit_buffer (const DriverAPI *api, int id,
-    const Model *model, Buffer *buffer, outputCallback callback)
-{
-  taskFunc func = std::bind (&HostInputService::invoke_buffer, this,
-      api, model, buffer, callback, id);
-  ThreadTask *task = new ThreadTask (id, func);
-
-  return ThreadPool::getInstance().enqueueTask (task);
-}
+  if (req->getInferData () == nullptr) {
+    logerr (TAG, "inference data is not set\n");
+    return -EINVAL;
+  }
 
-/**
- * @brief submit the request to the thread pool
- * @param[in] api the driver API
- * @param[in] id the request id
- * @param[in] model the target model
- * @param[in] segt the target segment table
- * @param[in] callback output callback
- * @return task id if no error, otherwise a negative errno.
- */
-int
-HostInputService::submit_segt (const DriverAPI *api, int id,
-    const Model *model, SegmentTable *segt, outputCallback callback)
-{
-  taskFunc func = std::bind (&HostInputService::invoke_segt, this,
-      api, model, segt, callback, id);
-  ThreadTask *task = new ThreadTask (id, func);
-
-  return ThreadPool::getInstance().enqueueTask (task);
+  taskFunc func =
+      std::bind (&HostInputService::invoke, this, api, req, callback);
+  ThreadTask *task = new ThreadTask (req->getID (), func);
+  return ThreadPool::getInstance ().enqueueTask (task);
 }
 
 /**
  * @brief remove the submitted request (if possible)
+ * @param[in] api device driver api
  * @param[in] id the request id to be removed
  * @return 0 if no erorr. otherwise a negative errno
  */
 int
-HostInputService::remove (int id)
-{
-  return ThreadPool::getInstance().removeTask (id);
+HostInputService::remove (const DriverAPI *api, int id) {
+  return ThreadPool::getInstance ().removeTask (id);
 }
 
 /**
- * @brief invoke the request using APIs
- * @param[in] api the driver API
- * @param[in] model the target model
- * @param[in] buffer the target buffer
- * @param[in] callback output callback
- * @return 0 if no error, otherwise a negative errno
- * @note this function should be used with TRIV driver!
+ * @brief invoke inference with the segment table (TRIV-2)
+ * @param[in] api device driver api
+ * @param[in] req request instance
+ * @return 0 if no erorr. otherwise a negative errno
  */
-int
-HostInputService::invoke_buffer (const DriverAPI *api, const Model *model,
-    Buffer *buffer, outputCallback callback, int task_id)
-{
-  input_config_t input_config;
-  device_state_t state;
-  int ret = -EINVAL;
+static int
+invoke_segt (const DriverAPI *api, const Request *req) {
+  /** internal logic error */
+  assert (api != nullptr);
+  assert (req != nullptr);
 
-  state = api->isReady();
-  if (state != device_state_t::STATE_READY) {
-    logerr (TAG, "device is not available to run inference %d\n", state);
-    goto handle_callback;
-  }
+  SegmentTable *segt = dynamic_cast<SegmentTable *> (req->getInferData ());
+  assert (segt != nullptr);
 
-  /** internal logic error */
-  assert (buffer != nullptr);
+  const Model *model = req->getModel ();
+  assert (model != nullptr);
 
-  if (model != nullptr) {
-    /** consider NOP cases */
-    if (model->getProgramData() == nullptr) {
-      ret = 0;
-      goto handle_callback;
+  input_config_t input_config;
+  input_config.model_id = model->getInternalID ();
+  input_config.dbuf_fd = segt->getDmabuf ();
+  input_config.req_id = req->getID ();
+  input_config.num_segments = segt->getNumTotalSegments ();
+  input_config.task_handle = UINT32_MAX;
+  input_config.subtask_idx = UINT32_MAX;
+
+  /** FIXME: update input_config fields */
+  if (req->getScheduler () == NPU_SCHEDULER_VD) {
+    if (req->getSchedulerParam ()) {
+      memcpy (&input_config.task_handle, req->getSchedulerParam (),
+              sizeof (uint32_t) * 2);
+    } else {
+      input_config.task_handle = 0;
+      input_config.subtask_idx = 0;
     }
-
-    input_config.model_id = model->getInternalID();
-  } else {
-    input_config.model_id = 0;
   }
 
-  input_config.dbuf_fd = buffer->getDmabuf ();
-  input_config.activation_offset_addr0 = buffer->getOffset ();
-  input_config.activation_offset_addr1 = buffer->getOffset ();
-  input_config.task_id = task_id;
+  /** set constraints */
+  npu_constraint constraint = model->getConstraint ();
+  input_config.timeout_ms = constraint.timeout_ms;
+  input_config.priority = constraint.priority;
+  /** input handling by CPU. host inputservice only supports CPU mode */
+  input_config.input_mode = TRINITY_INPUT_CPU;
+  /** output handling by CPU, host inputservice only supports either interrupt or polling */
+  if (constraint.notimode == NPU_POLLING) {
+    input_config.output_mode = TRINITY_OUTPUT_CPU_POLL;
+  } else { /** default mode is interrupt */
+    input_config.output_mode = TRINITY_OUTPUT_CPU_INTR;
+  }
 
   /** run the inference with the input */
-  ret = api->runInput (&input_config);
+  int ret = api->runInput (&input_config);
   if (ret < 0 && ret != -ECANCELED)
     logerr (TAG, "Failed to run the NPU inference: %d\n", ret);
 
-handle_callback:
-  /** should call the callback regardless of failure, to avoid deadlock */
-  if (callback != nullptr)
-    callback ();
-
   return ret;
 }
 
 /**
- * @brief invoke the request using APIs
- * @param[in] api the driver API
- * @param[in] model the target model
- * @param[in] segt the target segment table
+ * @brief invoke the given request using the driver API
+ * @param[in] api device driver api
+ * @param[in] req request instance
  * @param[in] callback output callback
  * @return 0 if no error, otherwise a negative errno
- * @note this function should be used with TRIV2 driver!
  */
 int
-HostInputService::invoke_segt (const DriverAPI *api, const Model *model,
-    SegmentTable *segt, outputCallback callback, int task_id)
-{
-  input_config_t input_config;
+HostInputService::invoke (const DriverAPI *api, const Request *req,
+                          outputCallback callback) {
   device_state_t state;
-  npuConstraint constraint;
+  const Model *model;
+  HWmem *data;
   int ret = -EINVAL;
 
-  state = api->isReady();
-  if (state != device_state_t::STATE_READY) {
-    logerr (TAG, "device is not available to run inference %d\n", state);
-    goto handle_callback;
-  }
-
   /** internal logic error */
-  assert (model != nullptr);
-  assert (segt != nullptr);
+  assert (api != nullptr);
+  assert (req != nullptr);
 
-  /** consider NOP cases */
-  if (model->getProgramData() == nullptr) {
-    ret = 0;
+  model = req->getModel ();
+  if (model == nullptr) {
+    logerr (TAG, "unable to find the target model\n");
     goto handle_callback;
   }
 
-  input_config.model_id = model->getInternalID();
-  input_config.dbuf_fd = segt->getDmabuf ();
-  input_config.num_segments = segt->getNumTotalSegments ();
-
-  /** set constraints */
-  constraint = model->getConstraint ();
-  input_config.timeout_ms = constraint.timeout_ms;
-  input_config.priority = constraint.priority;
-
-  /** input handling by CPU. host inputservice only supports CPU mode */
-  input_config.input_mode = INPUT_CPU;
+  state = api->isReady ();
+  if (state != device_state_t::TRINITY_STATE_READY) {
+    logerr (TAG, "device is not available to run inference %d\n", state);
+    goto handle_callback;
+  }
 
-  /** output handling by CPU, host inputservice only supports either interrupt or polling */
-  if (constraint.notimode == NPU_POLLING) {
-    input_config.output_mode = OUTPUT_CPU_POLL;
-  } else { /** default mode is interrupt */
-    input_config.output_mode = OUTPUT_CPU_INTR;
+  if (model->getProgramData () == nullptr) {
+    ret = 0;
+    goto handle_callback;
   }
 
-  input_config.task_id = task_id;
-  /** run the inference with the input */
-  ret = api->runInput (&input_config);
-  if (ret < 0 && ret != -ECANCELED)
-    logerr (TAG, "Failed to run the NPU inference: %d\n", ret);
+  data = req->getInferData ();
+  if (dynamic_cast<SegmentTable *> (data))
+    ret = invoke_segt (api, req);
+  else /* no inference data; skip */
+    ret = 0;
 
 handle_callback:
-  /** should call the callback regardless of failure, to avoid deadlock */
   if (callback != nullptr)
     callback ();