[MachineLearning.Train] Add SetDataset method to Model class
authorHyunil <hyunil46.park@samsung.com>
Thu, 30 Jun 2022 04:51:03 +0000 (13:51 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 23 Aug 2022 05:50:26 +0000 (14:50 +0900)
- Add SetDataset(Dataset dataset)
- Add ml_train_model_set_dataset() to interop
- Add NNTrainerDatasetMode to commons.cs

Signed-off-by: Hyunil <hyunil46.park@samsung.com>
src/Tizen.MachineLearning.Train/Interop/Interop.Model.cs
src/Tizen.MachineLearning.Train/Tizen.MachineLearning.Train/Commons.cs
src/Tizen.MachineLearning.Train/Tizen.MachineLearning.Train/Model.cs

index e383f91..0ef27a7 100644 (file)
@@ -66,6 +66,8 @@ internal static partial class Interop
         [DllImport(Libraries.Nntrainer, EntryPoint = "ml_train_model_set_optimizer")]
         public static extern NNTrainerError SetOptimizer(IntPtr modelHandle, IntPtr optimizerHandle);
 
-
+        /* int ml_train_model_set_dataset(ml_train_model_h model, ml_train_dataset_h dataset) */
+        [DllImport(Libraries.Nntrainer, EntryPoint = "ml_train_model_set_dataset")]
+        public static extern NNTrainerError SetDataset(IntPtr modelHandle, IntPtr datasetHandle);
     }
 }
index 966bafb..b5a7cf6 100644 (file)
@@ -227,6 +227,26 @@ namespace Tizen.MachineLearning.Train
         Unknown = 999
     }
 
+    /// <summary>
+    /// Enumeration for the dataset data type of NNTrainer.
+    /// </summary>
+    /// <since_tizen> 10 </since_tizen>
+    public enum NNTrainerDatasetMode
+    {
+        /// <summary>
+        /// The given data is for used when training
+        /// </summary>
+        Train = 0,
+        /// <summary>
+        /// The given data is for used when validating
+        /// </summary>
+        Valid = 1,
+        /// <summary>
+        /// The given data is for used when testing
+        /// </summary>
+        Test = 2,
+    }
+
     internal static class NNTrainer
     {
  
index fc66f3b..7e62522 100644 (file)
@@ -271,6 +271,7 @@ namespace Tizen.MachineLearning.Train
             NNTrainerError ret = Interop.Model.GetLayer(handle, layerName, out layerHandle);
             NNTrainer.CheckException(ret, "Failed to get layer");
         }
+
         /// <summary>
         /// Sets the optimizer for the neural network model.
         /// </summary>
@@ -288,5 +289,26 @@ namespace Tizen.MachineLearning.Train
             NNTrainerError ret = Interop.Model.SetOptimizer(handle, optimizer.GetHandle());
             NNTrainer.CheckException(ret, "Failed to set optimizer");
         }
+
+        /// <summary>
+        /// Sets the dataset (data provider) for the neural network model.
+        /// </summary>
+        /// <remarks>
+        /// Use this function to set dataset for running the model. The dataset
+        /// will provide training, validation and test data for the model. This transfers
+        /// the ownership of the dataset to the network. No need to destroy the dataset
+        /// once it is set to a model.
+        /// Unsets the previously set dataset, if any. The previously set
+        /// dataset must be freed using Dispose().
+        /// </remarks>
+        /// <param name="dataset"> The instance of Dataset class </param>
+        /// <since_tizen> 10 </since_tizen>
+        public void SetDataset(Dataset dataset)
+        {
+            if (dataset == null)
+                NNTrainer.CheckException(NNTrainerError.InvalidOperation, "dataset instance is null");
+            NNTrainerError ret = Interop.Model.SetDataset(handle, dataset.GetHandle());
+            NNTrainer.CheckException(ret, "Failed to set dataset");
+        }
     } 
 }