[Refactoring] return 'task_id' when calling runNPU_sync/async
[platform/adaptation/npu/trix-engine.git] / src / core / ne-handler.cc
index 9e545b3..1550b79 100644 (file)
@@ -210,6 +210,21 @@ HostHandler::getTops (uint32_t *tops)
 }
 
 /**
+ * @brief Get the DSP DSPM size of the opened NPU device
+ * @param[in] dev the NPU device handle
+ * @param[out] dspm dspm size
+ * @return 0 if no error, otherwise a negative errno
+ * @note this does not support for emulated devices
+ */
+int
+HostHandler::getDspmSize (uint32_t *dspm)
+{
+  const DriverAPI * api = device_->getDriverAPI ();
+  assert (api != nullptr);
+
+  return api->getDspmSize (dspm);
+}
+/**
  * @brief Set the data layout for input/output tensors
  * @param[in] modelid The ID of model whose layouts are set
  * @param[in] in the layout/type info for input tensors
@@ -269,7 +284,7 @@ class callbackSync {
     }
 
     void callback (output_buffers *output, uint64_t sequence) {
-      if (output_ != nullptr) {
+      if (output_ != nullptr && output != nullptr) {
         /** just copy internal variables of output buffers */
         memcpy (output_, output, sizeof (output_buffers));
       }
@@ -294,7 +309,7 @@ class callbackSync {
  * @param[in] modelid The model to be inferred.
  * @param[in] input The input data to be inferred.
  * @param[out] output The output result.
- * @return @c 0 if no error. otherwise a negative error value
+ * @return @c positive id if no error. otherwise a negative error value
  */
 int
 HostHandler::runSync (uint32_t modelid, const input_buffers *input,
@@ -303,7 +318,7 @@ HostHandler::runSync (uint32_t modelid, const input_buffers *input,
   callbackSync sync (output);
   int status = runAsync (modelid, input, callbackSync::callback,
       static_cast <void*> (&sync), NPUASYNC_DROP_OLD, nullptr);
-  if (status == 0) {
+  if (status > 0) {
     /** sync needs to wait callback */
     sync.wait ();
   }
@@ -318,7 +333,7 @@ HostHandler::runSync (uint32_t modelid, const input_buffers *input,
  * @param[in] cb_data The data given as a parameter to the runNPU_async call.
  * @param[in] mode Configures how this operation works.
  * @param[out] sequence The sequence number returned with runNPU_async.
- * @return @c 0 if no error. otherwise a negative error value
+ * @return @c positive id if no error. otherwise a negative error value
  */
 int
 HostHandler::runAsync (uint32_t modelid, const input_buffers *input,
@@ -816,6 +831,7 @@ TrinityVision2::setModel (const generic_buffer *model_buf, Model ** model_ptr)
   if (model->getMetadata()->getProgramSize() > 0) {
     HWmem * hwmem_prog = new HWmem (new HWmemDevice);
     hwmem_prog->setDriverAPI (api_.get());
+    hwmem_prog->setContiguous (true);
 
     model->setProgramData (hwmem_prog);
 
@@ -864,6 +880,9 @@ TrinityVision2::setModel (const generic_buffer *model_buf, Model ** model_ptr)
         logerr (TAG, "Failed to extract generic buffer: %d\n", status);
         goto delete_exit;
       }
+    } else {
+      config.metadata_ext_dbuf_fd = -1;
+      config.metadata_ext_size = 0;
     }
 
     status = api_->registerModel (&config, model->getMetadata()->getNPUVersion());
@@ -972,11 +991,12 @@ TrinityVision2::run (npu_input_opmode opmode, const Model *model,
 
   Request *req = new Request (opmode);
   req->setModel (model);
-  req->setSegmentTable (segt);
+  req->setInferData (segt);
   req->setCallback (std::bind (&TrinityVision2::callback, this, req, cb, cb_data));
 
-  if (sequence)
-    *sequence = req->getID();
+  if (sequence && req->getID () > 0) {
+    *sequence = (uint32_t) req->getID ();
+  }
 
   return scheduler_->submitRequest (req);
 }
@@ -1003,7 +1023,7 @@ TrinityVision2::runInternal (npu_input_opmode opmode, const Model *model,
 
   Request *req = new Request (opmode);
   req->setModel (model);
-  req->setSegmentTable (segt);
+  req->setInferData (segt);
   req->setHwDevice (hw_dev);
 
   return scheduler_->submitRequest (req);
@@ -1017,7 +1037,10 @@ TrinityVision2::callback (Request *req, npuOutputNotify cb, void *cb_data)
     return;
 
   const Model *model = req->getModel ();
-  SegmentTable *segt = req->getSegmentTable ();
+  SegmentTable *segt = dynamic_cast<SegmentTable *> (req->getInferData ());
+  /** internal logic error */
+  assert (segt != nullptr);
+
   output_buffers output = {
     .num_buffers = segt->getNumOutputSegments ()
   };