[MachineLearning.Inference] Update SingleShot and its related classes (#1154)
[platform/core/csapi/tizenfx.git] / src / Tizen.MachineLearning.Inference / Tizen.MachineLearning.Inference / TensorsData.cs
index ddabbb9..7ab6542 100755 (executable)
@@ -15,7 +15,7 @@
 */
 
 using System;
-using System.IO;
+using System.Collections;
 
 namespace Tizen.MachineLearning.Inference
 {
@@ -27,16 +27,57 @@ namespace Tizen.MachineLearning.Inference
     {
         private IntPtr _handle = IntPtr.Zero;
         private bool _disposed = false;
-        private int _count = Tensor.InvalidCount;
+        private TensorsInfo _tensorsInfo = null;
+        private ArrayList _dataList = null;
 
         /// <summary>
-        /// Creates a TensorsInfo instance with handle which is given by TensorsInfo.
+        /// Creates a TensorsData instance with handle which is given by TensorsInfo.
         /// </summary>
         /// <param name="handle">The handle of tensors data.</param>
+        /// <param name="info">The handle of tensors info. (Default: null)</param>
+        /// <param name="isFetch">The boolean value for fetching the data (Default: false)</param>
         /// <since_tizen> 6 </since_tizen>
-        private TensorsData(IntPtr handle)
+        private TensorsData(IntPtr handle, TensorsInfo info, bool isFetch)
         {
+            NNStreamer.CheckNNStreamerSupport();
+            NNStreamerError ret = NNStreamerError.None;
+
+            /* Set internal object */
             _handle = handle;
+            _tensorsInfo = info;
+
+            /* Set count */
+            int count = 0;
+            ret = Interop.Util.GetTensorsCount(_handle, out count);
+            NNStreamer.CheckException(ret, "unable to get the count of TensorsData");
+
+            _dataList = new ArrayList(count);
+
+            if (isFetch)
+            {
+                for (int i = 0; i < count; ++i)
+                {
+                    IntPtr raw_data;
+                    byte[] bufData = null;
+                    int size;
+
+                    ret = Interop.Util.GetTensorData(_handle, i, out raw_data, out size);
+                    NNStreamer.CheckException(ret, "unable to get the buffer of TensorsData: " + i.ToString());
+
+                    bufData = Interop.Util.IntPtrToByteArray(raw_data, size);
+                    _dataList.Add(bufData);
+                }
+            }
+            else
+            {
+                for (int i = 0; i < count; ++i)
+                {
+                    int size = info.GetTensorSize(i);
+                    byte[] bufData = new byte[size];
+
+                    _dataList.Add(bufData);
+                }
+            }
         }
 
         /// <summary>
@@ -48,36 +89,54 @@ namespace Tizen.MachineLearning.Inference
             Dispose(false);
         }
 
-        internal static TensorsData CreateFromNativeHandle(IntPtr handle)
+        /// <summary>
+        /// Gets the number of Tensor in TensorsData class
+        /// </summary>
+        /// <feature>http://tizen.org/feature/machine_learning.inference</feature>
+        /// <exception cref="NotSupportedException">Thrown when the feature is not supported.</exception>
+        /// <since_tizen> 6 </since_tizen>
+        public int Count
         {
-            TensorsData retTensorsData = new TensorsData(handle);
+            get {
+                NNStreamer.CheckNNStreamerSupport();
 
-            return retTensorsData;
+                return _dataList.Count;
+            }
         }
 
         /// <summary>
-        /// Gets the number of Tensor in TensorsData class
+        /// Gets the tensors information.
         /// </summary>
+        /// <returns>The TensorsInfo instance</returns>
         /// <feature>http://tizen.org/feature/machine_learning.inference</feature>
         /// <exception cref="NotSupportedException">Thrown when the feature is not supported.</exception>
-        /// <since_tizen> 6 </since_tizen>
-        public int Count
+        /// <since_tizen> 8 </since_tizen>
+        public TensorsInfo TensorsInfo
         {
             get {
                 NNStreamer.CheckNNStreamerSupport();
 
-                if (_count != Tensor.InvalidCount)
-                    return _count;
+                return _tensorsInfo;
+            }
+        }
 
-                NNStreamerError ret = NNStreamerError.None;
-                int count = 0;
+        /// <summary>
+        /// Allocates a new TensorsData instance with the given tensors information.
+        /// </summary>
+        /// <param name="info">TensorsInfo object which has Tensor information</param>
+        /// <returns>The TensorsInfo instance</returns>
+        /// <exception cref="ArgumentException">Thrown when the method failed due to an invalid parameter.</exception>
+        /// <exception cref="NotSupportedException">Thrown when the feature is not supported.</exception>
+        /// <since_tizen> 8 </since_tizen>
+        public static TensorsData Allocate(TensorsInfo info)
+        {
+            NNStreamer.CheckNNStreamerSupport();
 
-                ret = Interop.Util.GetTensorsCount(_handle, out count);
-                NNStreamer.CheckException(ret, "unable to get the count of TensorsData");
+            if (info == null)
+                throw NNStreamerExceptionFactory.CreateException(NNStreamerError.InvalidParameter, "TensorsInfo is null");
 
-                _count = count;
-                return _count;
-            }
+            TensorsData retData = info.GetTensorsData();
+            return retData;
         }
 
         /// <summary>
@@ -86,23 +145,17 @@ namespace Tizen.MachineLearning.Inference
         /// <param name="index">The index of the tensor.</param>
         /// <param name="buffer">Raw tensor data to be set.</param>
         /// <feature>http://tizen.org/feature/machine_learning.inference</feature>
-        /// <exception cref="ArgumentException">Thrown when the method failed due to an invalid parameter.</exception>
         /// <exception cref="NotSupportedException">Thrown when the feature is not supported.</exception>
+        /// <exception cref="ArgumentException">Thrown when the data is not valid.</exception>
         /// <since_tizen> 6 </since_tizen>
         public void SetTensorData(int index, byte[] buffer)
         {
-            NNStreamerError ret = NNStreamerError.None;
-
             NNStreamer.CheckNNStreamerSupport();
 
-            if (buffer == null)
-            {
-                string msg = "buffer is null";
-                throw NNStreamerExceptionFactory.CreateException(NNStreamerError.InvalidParameter, msg);
-            }
+            CheckIndex(index);
+            CheckDataBuffer(index, buffer);
 
-            ret = Interop.Util.SetTensorData(_handle, index, buffer, buffer.Length);
-            NNStreamer.CheckException(ret, "unable to set the buffer of TensorsData: " + index.ToString());
+            _dataList[index] = buffer;
         }
 
         /// <summary>
@@ -116,19 +169,11 @@ namespace Tizen.MachineLearning.Inference
         /// <since_tizen> 6 </since_tizen>
         public byte[] GetTensorData(int index)
         {
-            byte[] retBuffer = null;
-            IntPtr raw_data;
-            int size;
-            NNStreamerError ret = NNStreamerError.None;
-
             NNStreamer.CheckNNStreamerSupport();
 
-            ret = Interop.Util.GetTensorData(_handle, index, out raw_data, out size);
-            NNStreamer.CheckException(ret, "unable to get the buffer of TensorsData: " + index.ToString());
+            CheckIndex(index);
 
-            retBuffer = Interop.Util.IntPtrToByteArray(raw_data, size);
-
-            return retBuffer;
+            return (byte[])_dataList[index];
         }
 
         /// <summary>
@@ -168,9 +213,79 @@ namespace Tizen.MachineLearning.Inference
             _disposed = true;
         }
 
-        internal IntPtr Handle
+        internal IntPtr GetHandle()
+        {
+            return _handle;
+        }
+
+        internal void PrepareInvoke()
+        {
+            NNStreamerError ret = NNStreamerError.None;
+            int count = _dataList.Count;
+
+            for (int i = 0; i < count; ++i)
+            {
+                byte[] data = (byte[])_dataList[i];
+                ret = Interop.Util.SetTensorData(_handle, i, data, data.Length);
+                NNStreamer.CheckException(ret, "unable to set the buffer of TensorsData: " + i.ToString());
+            }
+        }
+
+        internal static TensorsData CreateFromNativeHandle(IntPtr dataHandle, IntPtr infoHandle, bool isFetch)
+        {
+            TensorsData retTensorsData = null;
+
+            if (infoHandle == IntPtr.Zero)
+            {
+                retTensorsData = new TensorsData(dataHandle, null, isFetch);
+            }
+            else
+            {
+                TensorsInfo info = TensorsInfo.ConvertTensorsInfoFromHandle(infoHandle);
+                retTensorsData = new TensorsData(dataHandle, info, isFetch);
+            }
+
+            return retTensorsData;
+        }
+
+        private void CheckIndex(int index)
         {
-            get { return _handle; }
+            if (index < 0 || index >= _dataList.Count)
+            {
+                string msg = "Invalid index [" + index + "] of the tensors";
+                throw NNStreamerExceptionFactory.CreateException(NNStreamerError.InvalidParameter, msg);
+            }
+        }
+
+        private void CheckDataBuffer(int index, byte[] data)
+        {
+            if (data == null)
+            {
+                string msg = "data is not valid";
+                throw NNStreamerExceptionFactory.CreateException(NNStreamerError.InvalidParameter, msg);
+            }
+
+            if (index >= Tensor.SizeLimit)
+            {
+                string msg = "Max size of the tensors is " + Tensor.SizeLimit;
+                throw NNStreamerExceptionFactory.CreateException(NNStreamerError.QuotaExceeded, msg);
+            }
+
+            if (_tensorsInfo != null)
+            {
+                if (index >= _tensorsInfo.Count)
+                {
+                    string msg = "Current information has " + _tensorsInfo.Count + " tensors";
+                    throw NNStreamerExceptionFactory.CreateException(NNStreamerError.QuotaExceeded, msg);
+                }
+
+                int size = _tensorsInfo.GetTensorSize(index);
+                if (data.Length != size)
+                {
+                    string msg = "Invalid buffer size, required size is " + size.ToString();
+                    throw NNStreamerExceptionFactory.CreateException(NNStreamerError.InvalidParameter, msg);
+                }
+            }
         }
     }
 }