[API] Add API to get model id of request and handle empty weight
authorDongju Chae <dongju.chae@samsung.com>
Wed, 23 Jun 2021 06:57:29 +0000 (15:57 +0900)
committer채동주/On-Device Lab(SR)/Staff Engineer/삼성전자 <dongju.chae@samsung.com>
Wed, 23 Jun 2021 09:36:08 +0000 (18:36 +0900)
This patch adds APi to get model id of request and handle the case
that model contains empty weight data.

Signed-off-by: Dongju Chae <dongju.chae@samsung.com>
include/host/libnpuhost.h
src/core/ne-handler.cc
src/core/ne-handler.h
src/core/ne-segment-table.cc
src/host/ne-host.cc

index 98acb0e..a807f46 100644 (file)
@@ -510,6 +510,15 @@ int createNPU_request (npudev_h dev, uint32_t model_id, int *req_id);
 int removeNPU_request (npudev_h dev, int req_id);
 
 /**
+ * @brief Get the request's model id
+ * @param[in] dev The NPU device handle
+ * @param[in] req_id The request's ID
+ * @param[out] model_id The request's model ID
+ * @return 0 if no error. Otherwise a negative errno
+ */
+int getNPU_requestModel (npudev_h dev, int req_id, uint32_t *model_id);
+
+/**
  * @brief Set request's input/output data
  * @param[in] dev The NPU device handle
  * @param[in] req_id The request ID
index 64c244a..602cdd2 100644 (file)
@@ -326,6 +326,17 @@ HostHandler::removeRequest (int req_id) {
 }
 
 /**
+ * @brief Get the request's model id
+ * @param[in] req_id The request's ID
+ * @param[out] model_id The request's model ID
+ * @return 0 if no error. Otherwise a negative errno
+ */
+int
+HostHandler::getRequestModel (int req_id, uint32_t *model_id) {
+  return device_->getRequestModel (req_id, model_id);
+}
+
+/**
  * @brief Set request's input/output data
  * @param[in] req_id The request ID
  * @param[in] input The input data buffers
@@ -1257,6 +1268,29 @@ TrinityVision2::removeRequest (int req_id) {
 }
 
 int
+TrinityVision2::getRequestModel (int req_id, uint32_t *model_id) {
+  if (model_id == nullptr) {
+    logerr (TAG, "Invalid argument detected\n");
+    return -EINVAL;
+  }
+
+  Request *req = scheduler_->findRequest (req_id);
+  if (req == nullptr) {
+    logerr (TAG, "Unable to find the request with ID (%d)\n", req_id);
+    return -ENOENT;
+  }
+
+  const Model *model = req->getModel ();
+  if (model == nullptr) {
+    logerr (TAG, "Unable to find the request's model\n");
+    return -ENOENT;
+  }
+
+  *model_id = model->getID ();
+  return 0;
+}
+
+int
 TrinityVision2::setRequestData (int req_id, input_buffers *input,
                                 tensors_data_info *in_info,
                                 output_buffers *output,
@@ -1273,14 +1307,17 @@ TrinityVision2::setRequestData (int req_id, input_buffers *input,
     return -EINVAL;
   }
 
-  if (input == nullptr || in_info == nullptr || output == nullptr ||
-      out_info == nullptr) {
+  if (input == nullptr || output == nullptr) {
     logerr (TAG, "Invalid arguments detected\n");
     return -EINVAL;
   }
 
-  /* FIXME: should be per request, not per model */
-  const_cast<Model *> (model)->setDataInfo (in_info, out_info);
+  if (in_info != nullptr && out_info != nullptr) {
+    /* FIXME: should be per request, not per model */
+    const_cast<Model *> (model)->setDataInfo (in_info, out_info);
+  }
+
+  const_cast<Model *> (model)->updateDataInfo ();
 
   /** this device uses segment table */
   SegmentTable *segt = prepareSegmentTable (model, input, output);
index 9a43912..6ce7a3f 100644 (file)
@@ -65,6 +65,7 @@ class HostHandler {
 
   int createRequest (uint32_t model_id, int *req_id);
   int removeRequest (int req_id);
+  int getRequestModel (int req_id, uint32_t *model_id);
 
   int setRequestData (int req_id, input_buffers *input,
                       tensors_data_info *in_info, output_buffers *output,
@@ -154,6 +155,7 @@ class Device {
 
   virtual int createRequest (const Model *model, int *req_id) = 0;
   virtual int removeRequest (int req_id) = 0;
+  virtual int getRequestModel (int req_id, uint32_t *model_id) = 0;
 
   virtual int setRequestData (int req_id, input_buffers *input,
                               tensors_data_info *in_info,
@@ -219,6 +221,8 @@ class TrinityVision2 : public Device {
 
   int createRequest (const Model *model, int *req_id);
   int removeRequest (int req_id);
+  int getRequestModel (int req_id, uint32_t *model_id);
+
   int setRequestData (int req_id, input_buffers *input,
                       tensors_data_info *in_info, output_buffers *output,
                       tensors_data_info *out_info);
index edf7c45..e5569bb 100644 (file)
@@ -288,18 +288,23 @@ SegmentTable::createSegments (const Model *model, const input_buffers *input,
 
   /** segment index validity is already checked in Metadata's checkSanity () */
   num_total_segments_ = meta->getSegmentsNum ();
-  num_weight_segments_ = 1;
+  num_weight_segments_ = meta->getWeightSize () != 0 ? 1 : 0;
   num_input_segments_ = meta->getInputNum ();
   num_output_segments_ = meta->getOutputNum ();
 
-  weight_seg_idx_ = new uint32_t[num_weight_segments_];
+  if (num_weight_segments_ > 0) {
+    weight_seg_idx_ = new uint32_t[num_weight_segments_];
+    weight_seg_idx_[0] = meta->getWeightSegmentIndex ();
+  } else {
+    weight_seg_idx_ = nullptr;
+  }
+
   input_seg_idx_ = new uint32_t[num_input_segments_];
   output_seg_idx_ = new uint32_t[num_output_segments_];
 
   input_seg_off_ = new uint32_t[num_input_segments_];
   output_seg_off_ = new uint32_t[num_output_segments_];
 
-  weight_seg_idx_[0] = meta->getWeightSegmentIndex ();
   for (uint32_t i = 0; i < num_input_segments_; i++) {
     input_seg_idx_[i] = meta->getInputSegmentIndex (i);
     input_seg_off_[i] = meta->getInputSegmentOffset (i);
@@ -325,11 +330,8 @@ SegmentTable::restoreSegments () {
   HWmem *hwmem;
 
   hwmem = getWeightSegment ();
-  if (hwmem == nullptr) {
-    logerr (TAG, "Unable to find weight segment\n");
-    return -EINVAL;
-  }
-  updateSegmentSlot (hwmem, weight_seg_idx_[0]);
+  if (hwmem != nullptr)
+    updateSegmentSlot (hwmem, weight_seg_idx_[0]);
 
   for (uint32_t i = 0; i < num_input_segments_; i++) {
     hwmem = getInputSegment (i);
@@ -359,10 +361,8 @@ SegmentTable::restoreSegments () {
  */
 HWmem *
 SegmentTable::getWeightSegment (uint32_t idx) {
-  if (weight_seg_idx_ == nullptr) {
-    logerr (TAG, "No valid segments in this table, maybe uninitialized?\n");
+  if (weight_seg_idx_ == nullptr)
     return nullptr;
-  }
 
   if (idx >= num_weight_segments_) {
     logerr (TAG, "Invalid weight segment index (%u). Should be less than %u\n",
index b87cb07..942d2f2 100644 (file)
@@ -755,6 +755,20 @@ removeNPU_request (npudev_h dev, int req_id) {
 }
 
 /**
+ * @brief Get the request's model id
+ * @param[in] dev The NPU device handle
+ * @param[in] req_id The request's ID
+ * @param[out] model_id The request's model ID
+ * @return 0 if no error. Otherwise a negative errno
+ */
+int
+getNPU_requestModel (npudev_h dev, int req_id, uint32_t *model_id) {
+  INIT_HOST_HANDLER (host_handler, dev);
+
+  return host_handler->getRequestModel (req_id, model_id);
+}
+
+/**
  * @brief Set request's input/output data
  * @param[in] dev The NPU device handle
  * @param[in] req_id The request ID