[base_loader] base_loader can load a quantization parameter (#9274)
author이상규/On-Device Lab(SR)/Principal Engineer/삼성전자 <sg5.lee@samsung.com>
Thu, 28 Nov 2019 09:45:24 +0000 (18:45 +0900)
committer이춘석/On-Device Lab(SR)/Staff Engineer/삼성전자 <chunseok.lee@samsung.com>
Thu, 28 Nov 2019 09:45:24 +0000 (18:45 +0900)
It enables to load quantization parameter where element size is 1.
That is, a pair of (scale, zero_point) is applied to a whole tensor values.

Signed-off-by: Sanggyu Lee <sg5.lee@samsung.com>
runtime/neurun/frontend/base_loader/base_loader.h

index d4ccc2c..4384824 100644 (file)
@@ -23,6 +23,7 @@
 #include <map>
 #include <cpp14/memory.h>
 #include <fstream>
+#include <limits>
 
 namespace neurun
 {
@@ -184,7 +185,7 @@ BaseLoader<LoaderDomain, SpecificLoader>::BaseLoader::tensorTypeToDataType(const
       return model::DataType::INT32;
     case TensorType::TensorType_BOOL:
       return model::DataType::BOOL8;
-    case TensorType::TensorType_INT8:
+    case TensorType::TensorType_UINT8:
       return model::DataType::QUANT8_ASYMM;
     default:
       throw std::runtime_error(
@@ -204,8 +205,38 @@ model::OperandIndex BaseLoader<LoaderDomain, SpecificLoader>::loadOperand(const
   }
   // Type
   model::DataType data_type = tensorTypeToDataType(tensor->type());
+  // Quantization
+  auto q_params = tensor->quantization();
+  float scale = 0.0;
+  long zero_point = 0;
+  if (q_params != nullptr)
+  {
+    if (q_params->scale())
+    {
+      if (q_params->scale()->size() != 1)
+      {
+        throw std::runtime_error("Only 1 scale for a tensor is supported.");
+      }
+      scale = q_params->scale()->Get(0);
+    }
+
+    if (q_params->zero_point())
+    {
+      if (q_params->zero_point()->size() != 1)
+      {
+        throw std::runtime_error("Only 1 zero_point value for a tensor is supported.");
+      }
+      zero_point = q_params->zero_point()->Get(0);
+      // zero_point is long while TypeInfo.zero_point is defined as int32_t.
+      assert(zero_point >= std::numeric_limits<int32_t>::min());
+      assert(zero_point <= std::numeric_limits<int32_t>::max());
+    }
+    auto details = q_params->details_as_CustomQuantization();
+    if (details != nullptr)
+      throw std::runtime_error("Custom Quantization is not supported");
+  }
   // Create TypeInfo
-  model::TypeInfo type_info(data_type);
+  model::TypeInfo type_info(data_type, scale, zero_point);
   // Create operand
   const auto operand_index = _graph.addOperand(shape, type_info);
 
@@ -219,19 +250,6 @@ model::OperandIndex BaseLoader<LoaderDomain, SpecificLoader>::loadOperand(const
 
   // Name unused
   // auto name = tensor->name();
-  // Quantization
-  auto quantization = tensor->quantization();
-  if (quantization != nullptr)
-  {
-    auto scale = quantization->scale();
-    auto zero_point = quantization->zero_point();
-    if (scale != nullptr || zero_point != nullptr)
-      throw std::runtime_error("Quantization is not supported!");
-
-    auto details = quantization->details_as_CustomQuantization();
-    if (details != nullptr)
-      throw std::runtime_error("Custom Quantization is not supported");
-  }
   // Variablie
   if (tensor->is_variable())
     throw std::runtime_error("Variable tensor not supported!");