Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compiler / common-artifacts / src / TestDataGenerator.cpp
index 739300d..7a07dd8 100644 (file)
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+#include <arser/arser.h>
 #include <foder/FileLoader.h>
 #include <luci/Importer.h>
 #include <luci_interpreter/Interpreter.h>
@@ -62,10 +63,9 @@ template <typename T> void geneate_random_data(std::mt19937 &gen, void *data, ui
   }
 }
 
-void fill_random_data(void *data, uint32_t size, loco::DataType dtype)
+void fill_random_data(void *data, uint32_t size, loco::DataType dtype, uint32_t seed)
 {
-  std::random_device rd;  // used to obtain a seed for the random number engine
-  std::mt19937 gen(rd()); // standard mersenne_twister_engine seeded with rd()
+  std::mt19937 gen(seed); // standard mersenne_twister_engine seeded with rd()
 
   switch (dtype)
   {
@@ -90,7 +90,25 @@ void fill_random_data(void *data, uint32_t size, loco::DataType dtype)
 
 int entry(int argc, char **argv)
 {
-  std::string circle_file{argv[1]};
+  arser::Arser arser;
+  arser.add_argument("circle").type(arser::DataType::STR).help("Circle file you want to test");
+  arser.add_argument("--fixed_seed")
+      .required(false)
+      .nargs(0)
+      .help("Put a fixed seed into the random number generator");
+
+  try
+  {
+    arser.parse(argc, argv);
+  }
+  catch (const std::runtime_error &err)
+  {
+    std::cout << err.what() << std::endl;
+    std::cout << arser;
+    return 255;
+  }
+
+  std::string circle_file = arser.get<std::string>("circle");
   size_t last_dot_index = circle_file.find_last_of(".");
   std::string prefix = circle_file.substr(0, last_dot_index);
 
@@ -136,6 +154,7 @@ int entry(int argc, char **argv)
   std::unique_ptr<H5::Group> output_value_group =
       std::make_unique<H5::Group>(output_file.createGroup("value"));
 
+  std::random_device rd; // used to obtain a seed for the random number engine
   uint32_t input_index = 0;
   for (uint32_t g = 0; g < circle_model->subgraphs()->size(); g++)
   {
@@ -174,7 +193,10 @@ int entry(int argc, char **argv)
       std::vector<int8_t> data(byte_size);
 
       // generate random data
-      fill_random_data(data.data(), data_size, input_node->dtype());
+      if (arser["--fixed_seed"])
+        fill_random_data(data.data(), data_size, input_node->dtype(), 0);
+      else
+        fill_random_data(data.data(), data_size, input_node->dtype(), rd());
 
       dataset->write(data.data(), dtype);