DevContext()
{
auto device_count = getnumNPUdeviceByType(NPUCOND_TRIV2_CONN_SOCIP);
+ // TODO: x64 platform has 3 cores. We do not support more that 2 cores for now.
+ if (device_count > 2)
+ {
+ device_count = 2;
+ }
+
if (device_count <= 0)
{
- throw std::runtime_error("Unable to find TRIV2 NPU device");
+ throw std::runtime_error("Unable to find TRIX NPU device");
}
- // Use NPU 0 device
- if (getNPUdeviceByType(&_dev_handle, NPUCOND_TRIV2_CONN_SOCIP, 0) < 0)
+ for (int i = 0; i < device_count; i++)
{
- throw std::runtime_error("Failed to get TRIV2 NPU device handle");
+ npudev_h h;
+ if (getNPUdeviceByType(&h, NPUCOND_TRIV2_CONN_SOCIP, i) < 0)
+ {
+ throw std::runtime_error("Failed to get TRIX NPU device handle");
+ }
+ _dev_handles.push_back(h);
}
}
~DevContext()
{
- if (_dev_handle != nullptr)
+ for (auto h : _dev_handles)
{
- unregisterNPUmodel_all(_dev_handle);
- putNPUdevice(_dev_handle);
+ if (h != nullptr)
+ {
+ unregisterNPUmodel_all(h);
+ putNPUdevice(h);
+ }
}
}
- npudev_h getDev() { return _dev_handle; }
+ npudev_h getDev(int i) { return _dev_handles[i]; }
+ int getDevSize() { return _dev_handles.size(); }
template <typename T> void setDataInfo(tensors_data_info *info, std::vector<T *> &tensors)
{
}
}
- template <typename T> void setBuffer(generic_buffers *buf, std::vector<T *> &tensors)
+ template <typename T>
+ void setBuffer(generic_buffers *buf, std::vector<T *> &tensors, int batch_size, int batch_index)
{
buf->num_buffers = static_cast<uint32_t>(tensors.size());
for (uint32_t idx = 0; idx < buf->num_buffers; ++idx)
{
- buf->bufs[idx].addr = tensors[idx]->buffer();
- buf->bufs[idx].size = static_cast<uint64_t>(tensors[idx]->total_size());
+ buf->bufs[idx].size = static_cast<uint64_t>(tensors[idx]->total_size() / batch_size);
+ buf->bufs[idx].addr = tensors[idx]->buffer() + (batch_index * buf->bufs[idx].size);
buf->bufs[idx].type = BUFFER_MAPPED;
}
}
}
private:
- // NPU device handle
- // TODO Support multicore npu device
- npudev_h _dev_handle;
+ // NPU device handles
+ std::vector<npudev_h> _dev_handles;
};
} // namespace trix