From 474c70d7cb0ed4455aad9c04c2e6dd02f635c1c5 Mon Sep 17 00:00:00 2001 From: jmorrill Date: Sat, 29 Feb 2020 13:52:22 -0800 Subject: [PATCH] Added CopyFromBytes and CopyToBytes convenience methods to NDArray. Fixed typos. (#4970) * Added CopyFromBytes and CopyToBytes convenience methods. Fixed typos. * Removed unneed argument check * Use TVMArrayCopyFrom/ToBytes methods * Moved CopyFrom/ToBytes to ndarray.cc * CopyToBytes impl was using CopyFromBytes. Fixed * changed inline to TVM_DLL * Used impl from TVMArrayCopyTo/FromBytes into NDArray CopyTo/FromBytes * Move implementation of all CopyFrom/ToBytes into a common impls * make arg const * simplify method impl --- include/tvm/runtime/ndarray.h | 26 ++++++++++++++++--- src/runtime/ndarray.cc | 60 ++++++++++++++++++++++++++++--------------- 2 files changed, 62 insertions(+), 24 deletions(-) diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 090cacf..2441ab6 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -68,20 +68,38 @@ class NDArray : public ObjectRef { /*! * \brief Copy data content from another array. * \param other The source array to be copied from. - * \note The copy may happen asynchrously if it involves a GPU context. + * \note The copy may happen asynchronously if it involves a GPU context. * TVMSynchronize is necessary. */ inline void CopyFrom(const DLTensor* other); inline void CopyFrom(const NDArray& other); /*! + * \brief Copy data content from a byte buffer. + * \param data The source bytes to be copied from. + * \param nbytes The size of the buffer in bytes + * Must be equal to the size of the NDArray. + * \note The copy may happen asynchronously if it involves a GPU context. + * TVMSynchronize is necessary. + */ + TVM_DLL void CopyFromBytes(const void* data, size_t nbytes); + /*! * \brief Copy data content into another array. * \param other The source array to be copied from. - * \note The copy may happen asynchrously if it involves a GPU context. + * \note The copy may happen asynchronously if it involves a GPU context. * TVMSynchronize is necessary. */ inline void CopyTo(DLTensor* other) const; inline void CopyTo(const NDArray& other) const; /*! + * \brief Copy data content into another array. + * \param data The source bytes to be copied from. + * \param nbytes The size of the data buffer. + * Must be equal to the size of the NDArray. + * \note The copy may happen asynchronously if it involves a GPU context. + * TVMSynchronize is necessary. + */ + TVM_DLL void CopyToBytes(void* data, size_t nbytes) const; + /*! * \brief Copy the data to another context. * \param ctx The target context. * \return The array under another context. @@ -182,7 +200,7 @@ class NDArray : public ObjectRef { /*! * \brief Save a DLTensor to stream - * \param strm The outpu stream + * \param strm The output stream * \param tensor The tensor to be saved. */ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor); @@ -205,7 +223,7 @@ class NDArray::ContainerBase { DLTensor dl_tensor; /*! - * \brief addtional context, reserved for recycling + * \brief additional context, reserved for recycling * \note We can attach additional content here * which the current container depend on * (e.g. reference to original memory when creating views). diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 91002c9..ff2f34e 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -60,6 +60,32 @@ inline size_t GetDataAlignment(const DLTensor& arr) { return align; } +void ArrayCopyFromBytes(DLTensor* handle, const void* data, size_t nbytes) { + TVMContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + size_t arr_size = GetDataSize(*handle); + CHECK_EQ(arr_size, nbytes) + << "ArrayCopyFromBytes: size mismatch"; + DeviceAPI::Get(handle->ctx)->CopyDataFromTo( + data, 0, + handle->data, static_cast(handle->byte_offset), + nbytes, cpu_ctx, handle->ctx, handle->dtype, nullptr); +} + +void ArrayCopyToBytes(const DLTensor* handle, void* data, size_t nbytes) { + TVMContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + size_t arr_size = GetDataSize(*handle); + CHECK_EQ(arr_size, nbytes) + << "ArrayCopyToBytes: size mismatch"; + DeviceAPI::Get(handle->ctx)->CopyDataFromTo( + handle->data, static_cast(handle->byte_offset), + data, 0, + nbytes, handle->ctx, cpu_ctx, handle->dtype, nullptr); +} + struct NDArray::Internal { // Default deleter for the container static void DefaultDeleter(Object* ptr_obj) { @@ -185,6 +211,18 @@ NDArray NDArray::FromDLPack(DLManagedTensor* tensor) { return NDArray(GetObjectPtr(data)); } +void NDArray::CopyToBytes(void* data, size_t nbytes) const { + CHECK(data != nullptr); + CHECK(data_ != nullptr); + ArrayCopyToBytes(&get_mutable()->dl_tensor, data, nbytes); +} + +void NDArray::CopyFromBytes(const void* data, size_t nbytes) { + CHECK(data != nullptr); + CHECK(data_ != nullptr); + ArrayCopyFromBytes(&get_mutable()->dl_tensor, data, nbytes); +} + void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle stream) { @@ -286,16 +324,7 @@ int TVMArrayCopyFromBytes(TVMArrayHandle handle, void* data, size_t nbytes) { API_BEGIN(); - TVMContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; - size_t arr_size = GetDataSize(*handle); - CHECK_EQ(arr_size, nbytes) - << "TVMArrayCopyFromBytes: size mismatch"; - DeviceAPI::Get(handle->ctx)->CopyDataFromTo( - data, 0, - handle->data, static_cast(handle->byte_offset), - nbytes, cpu_ctx, handle->ctx, handle->dtype, nullptr); + ArrayCopyFromBytes(handle, data, nbytes); API_END(); } @@ -303,15 +332,6 @@ int TVMArrayCopyToBytes(TVMArrayHandle handle, void* data, size_t nbytes) { API_BEGIN(); - TVMContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; - size_t arr_size = GetDataSize(*handle); - CHECK_EQ(arr_size, nbytes) - << "TVMArrayCopyToBytes: size mismatch"; - DeviceAPI::Get(handle->ctx)->CopyDataFromTo( - handle->data, static_cast(handle->byte_offset), - data, 0, - nbytes, handle->ctx, cpu_ctx, handle->dtype, nullptr); + ArrayCopyToBytes(handle, data, nbytes); API_END(); } -- 2.7.4