Imported Upstream version 1.21.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / trix / DevContext.h
index 482932f..a7dbd7a 100644 (file)
@@ -32,28 +32,42 @@ public:
   DevContext()
   {
     auto device_count = getnumNPUdeviceByType(NPUCOND_TRIV2_CONN_SOCIP);
+    // TODO: x64 platform has 3 cores. We do not support more that 2 cores for now.
+    if (device_count > 2)
+    {
+      device_count = 2;
+    }
+
     if (device_count <= 0)
     {
-      throw std::runtime_error("Unable to find TRIV2 NPU device");
+      throw std::runtime_error("Unable to find TRIX NPU device");
     }
 
-    // Use NPU 0 device
-    if (getNPUdeviceByType(&_dev_handle, NPUCOND_TRIV2_CONN_SOCIP, 0) < 0)
+    for (int i = 0; i < device_count; i++)
     {
-      throw std::runtime_error("Failed to get TRIV2 NPU device handle");
+      npudev_h h;
+      if (getNPUdeviceByType(&h, NPUCOND_TRIV2_CONN_SOCIP, i) < 0)
+      {
+        throw std::runtime_error("Failed to get TRIX NPU device handle");
+      }
+      _dev_handles.push_back(h);
     }
   }
 
   ~DevContext()
   {
-    if (_dev_handle != nullptr)
+    for (auto h : _dev_handles)
     {
-      unregisterNPUmodel_all(_dev_handle);
-      putNPUdevice(_dev_handle);
+      if (h != nullptr)
+      {
+        unregisterNPUmodel_all(h);
+        putNPUdevice(h);
+      }
     }
   }
 
-  npudev_h getDev() { return _dev_handle; }
+  npudev_h getDev(int i) { return _dev_handles[i]; }
+  int getDevSize() { return _dev_handles.size(); }
 
   template <typename T> void setDataInfo(tensors_data_info *info, std::vector<T *> &tensors)
   {
@@ -66,14 +80,15 @@ public:
     }
   }
 
-  template <typename T> void setBuffer(generic_buffers *buf, std::vector<T *> &tensors)
+  template <typename T>
+  void setBuffer(generic_buffers *buf, std::vector<T *> &tensors, int batch_size, int batch_index)
   {
     buf->num_buffers = static_cast<uint32_t>(tensors.size());
 
     for (uint32_t idx = 0; idx < buf->num_buffers; ++idx)
     {
-      buf->bufs[idx].addr = tensors[idx]->buffer();
-      buf->bufs[idx].size = static_cast<uint64_t>(tensors[idx]->total_size());
+      buf->bufs[idx].size = static_cast<uint64_t>(tensors[idx]->total_size() / batch_size);
+      buf->bufs[idx].addr = tensors[idx]->buffer() + (batch_index * buf->bufs[idx].size);
       buf->bufs[idx].type = BUFFER_MAPPED;
     }
   }
@@ -106,9 +121,8 @@ private:
   }
 
 private:
-  // NPU device handle
-  // TODO Support multicore npu device
-  npudev_h _dev_handle;
+  // NPU device handles
+  std::vector<npudev_h> _dev_handles;
 };
 
 } // namespace trix