[PYTHON][FFI] Cythonize NDArray.copyto (#4549)
authorTianqi Chen <tqchen@users.noreply.github.com>
Fri, 20 Dec 2019 22:21:09 +0000 (14:21 -0800)
committerziheng <ziheng@apache.org>
Fri, 20 Dec 2019 22:21:09 +0000 (14:21 -0800)
* [PYTHON][FFI] Cythonize NDArray.copyto

* Cythonize the shape property

python/tvm/_ffi/_ctypes/ndarray.py
python/tvm/_ffi/_cython/ndarray.pxi
python/tvm/_ffi/ndarray.py

index 9367160..af59de6 100644 (file)
@@ -85,6 +85,16 @@ class NDArrayBase(object):
     def _tvm_handle(self):
         return ctypes.cast(self.handle, ctypes.c_void_p).value
 
+    def _copyto(self, target_nd):
+        """Internal function that implements copy to target ndarray."""
+        check_call(_LIB.TVMArrayCopyFromTo(self.handle, target_nd.handle, None))
+        return target_nd
+
+    @property
+    def shape(self):
+        """Shape of this array"""
+        return tuple(self.handle.contents.shape[i] for i in range(self.handle.contents.ndim))
+
     def to_dlpack(self):
         """Produce an array from a DLPack Tensor without copying memory
 
index 402c9de..5682ae6 100644 (file)
@@ -68,6 +68,11 @@ cdef class NDArrayBase:
         def __set__(self, value):
             self._set_handle(value)
 
+    @property
+    def shape(self):
+        """Shape of this array"""
+        return tuple(self.chandle.shape[i] for i in range(self.chandle.ndim))
+
     def __init__(self, handle, is_view):
         self._set_handle(handle)
         self.c_is_view = is_view
@@ -76,6 +81,11 @@ cdef class NDArrayBase:
         if self.c_is_view == 0:
             CALL(TVMArrayFree(self.chandle))
 
+    def _copyto(self, target_nd):
+        """Internal function that implements copy to target ndarray."""
+        CALL(TVMArrayCopyFromTo(self.chandle, (<NDArrayBase>target_nd).chandle, NULL))
+        return target_nd
+
     def to_dlpack(self):
         """Produce an array from a DLPack Tensor without copying memory
 
index da0783e..56bf4a0 100644 (file)
@@ -157,10 +157,6 @@ def from_dlpack(dltensor):
 
 class NDArrayBase(_NDArrayBase):
     """A simple Device/CPU Array object in runtime."""
-    @property
-    def shape(self):
-        """Shape of this array"""
-        return tuple(self.handle.contents.shape[i] for i in range(self.handle.contents.ndim))
 
     @property
     def dtype(self):
@@ -240,6 +236,7 @@ class NDArrayBase(_NDArrayBase):
             except:
                 raise TypeError('array must be an array_like data,' +
                                 'type %s is not supported' % str(type(source_array)))
+
         t = TVMType(self.dtype)
         shape, dtype = self.shape, self.dtype
         if t.lanes > 1:
@@ -294,14 +291,12 @@ class NDArrayBase(_NDArrayBase):
         target : NDArray
             The target array to be copied, must have same shape as this array.
         """
-        if isinstance(target, TVMContext):
-            target = empty(self.shape, self.dtype, target)
         if isinstance(target, NDArrayBase):
-            check_call(_LIB.TVMArrayCopyFromTo(
-                self.handle, target.handle, None))
-        else:
-            raise ValueError("Unsupported target type %s" % str(type(target)))
-        return target
+            return self._copyto(target)
+        elif isinstance(target, TVMContext):
+            res = empty(self.shape, self.dtype, target)
+            return self._copyto(res)
+        raise ValueError("Unsupported target type %s" % str(type(target)))
 
 
 def free_extension_handle(handle, type_code):