[nnpkg_run] support qint8 in h5 load and save (#9373)
author이상규/On-Device Lab(SR)/Principal Engineer/삼성전자 <sg5.lee@samsung.com>
Wed, 4 Dec 2019 08:05:25 +0000 (17:05 +0900)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Wed, 4 Dec 2019 08:05:25 +0000 (17:05 +0900)
nnpackage_run can now load and save uint8 (asymmetric quantized, "qint8") tensors in h5 files, in addition to float32.

Signed-off-by: Sanggyu Lee <sg5.lee@samsung.com>
tests/tools/nnpackage_run/src/h5formatter.cc

index f8869ca..b35f675 100644 (file)
@@ -43,24 +43,28 @@ void H5Formatter::loadInputs(const std::string &filename, std::vector<Allocation
     {
       nnfw_tensorinfo ti;
       NNPR_ENSURE_STATUS(nnfw_input_tensorinfo(session, i, &ti));
+      // allocate memory for data
+      auto bufsz = bufsize_for(&ti);
+      inputs[i].alloc(bufsz);
 
       H5::DataSet data_set = value_group.openDataSet(std::to_string(i));
-
-      // check type
       H5::DataType type = data_set.getDataType();
-      if (!(type == H5::PredType::IEEE_F32BE || type == H5::PredType::IEEE_F32LE))
+      if (type == H5::PredType::IEEE_F32BE || type == H5::PredType::IEEE_F32LE)
       {
-        throw std::runtime_error("h5 input has non-float32 type. nnpkg_run supports float32 only.");
+        data_set.read(inputs[i].data(), H5::PredType::NATIVE_FLOAT);
+        NNPR_ENSURE_STATUS(
+            nnfw_set_input(session, i, NNFW_TYPE_TENSOR_FLOAT32, inputs[i].data(), bufsz));
+      }
+      else if (type == H5::PredType::STD_U8BE || type == H5::PredType::STD_U8LE)
+      {
+        data_set.read(inputs[i].data(), H5::PredType::NATIVE_UINT8);
+        NNPR_ENSURE_STATUS(
+            nnfw_set_input(session, i, NNFW_TYPE_TENSOR_QUANT8_ASYMM, inputs[i].data(), bufsz));
+      }
+      else
+      {
+        throw std::runtime_error("nnpkg_run can load float32 and uint8 only.");
       }
-
-      // allocate memory for data
-      auto sz = num_elems(&ti);
-      inputs[i].alloc(sz * sizeof(float));
-      // read data
-      data_set.read(inputs[i].data(), H5::PredType::NATIVE_FLOAT);
-
-      NNPR_ENSURE_STATUS(nnfw_set_input(session, i, NNFW_TYPE_TENSOR_FLOAT32, inputs[i].data(),
-                                        sizeof(float) * num_elems(&ti)));
       NNPR_ENSURE_STATUS(nnfw_set_input_layout(session, i, NNFW_LAYOUT_CHANNELS_LAST));
     }
   }
@@ -103,9 +107,22 @@ void H5Formatter::dumpOutputs(const std::string &filename, std::vector<Allocatio
         }
       }
       H5::DataSpace data_space(ti.rank, dims.data());
-      H5::DataSet data_set =
-          value_group.createDataSet(std::to_string(i), H5::PredType::IEEE_F32BE, data_space);
-      data_set.write(outputs[i].data(), H5::PredType::NATIVE_FLOAT);
+      if (ti.dtype == NNFW_TYPE_TENSOR_FLOAT32)
+      {
+        H5::DataSet data_set =
+            value_group.createDataSet(std::to_string(i), H5::PredType::IEEE_F32BE, data_space);
+        data_set.write(outputs[i].data(), H5::PredType::NATIVE_FLOAT);
+      }
+      else if (ti.dtype == NNFW_TYPE_TENSOR_QUANT8_ASYMM)
+      {
+        H5::DataSet data_set =
+            value_group.createDataSet(std::to_string(i), H5::PredType::STD_U8BE, data_space);
+        data_set.write(outputs[i].data(), H5::PredType::NATIVE_UINT8);
+      }
+      else
+      {
+        throw std::runtime_error("nnpkg_run can dump float32 and uint8 only.");
+      }
     }
   }
   catch (const H5::Exception &e)