Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] runtime/onert/frontend/base_loader/include/base_loader.h
index 878a594..a6b1fb4 100644
@@ -513,11 +513,12 @@ void BaseLoader<LoaderDomain>::loadSparsity(const Tensor *tensor, ir::TypeInfo &
       if (src_metadata->array_segments() == nullptr || src_metadata->array_indices() == nullptr)
         return false;
       bool status = true;
+      /* `onert` internally uses the uint16 type regardless of the
+         array_segments_type and array_indices_type values */
       switch (src_metadata->array_segments_type())
       {
         case SparseIndexVector::SparseIndexVector_Int32Vector:
-          status = Copy(src_metadata->array_segments_as_Int32Vector(), w1_segments);
-          break;
+          throw std::runtime_error("sparse tensor with int32 segment type is not supported");
         case SparseIndexVector::SparseIndexVector_Uint16Vector:
           status = Copy(src_metadata->array_segments_as_Uint16Vector(), w1_segments);
           break;
@@ -532,7 +533,7 @@ void BaseLoader<LoaderDomain>::loadSparsity(const Tensor *tensor, ir::TypeInfo &
       switch (src_metadata->array_indices_type())
       {
         case SparseIndexVector::SparseIndexVector_Int32Vector:
-          return Copy(src_metadata->array_indices_as_Int32Vector(), w1_indices);
+          throw std::runtime_error("sparse tensor with int32 indices type is not supported");
         case SparseIndexVector::SparseIndexVector_Uint16Vector:
           return Copy(src_metadata->array_indices_as_Uint16Vector(), w1_indices);
         case SparseIndexVector::SparseIndexVector_Uint8Vector:
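
For reference, the `Copy` helper called above is expected to move the FlatBuffers segment/index values into the `std::vector<uint16_t>` buffers that onert keeps internally. A minimal sketch of such a helper follows; the signature and the free-function form are hypothetical (the actual helper in base_loader.h may be structured differently), and it assumes the generated `Uint16Vector`/`Uint8Vector` tables expose a `values()` accessor as in the TFLite schema:

#include <cstdint>
#include <limits>
#include <vector>

// Sketch only: copy a FlatBuffers index vector into onert's internal uint16 buffer,
// rejecting values that do not fit into 16 bits.
template <typename FbVector> bool Copy(const FbVector *src, std::vector<uint16_t> &dst)
{
  if (src == nullptr || src->values() == nullptr)
    return false;
  dst.clear();
  dst.reserve(src->values()->size());
  for (const auto v : *src->values())
  {
    if (static_cast<uint64_t>(v) > std::numeric_limits<uint16_t>::max())
      return false;
    dst.emplace_back(static_cast<uint16_t>(v));
  }
  return true;
}

With the int32 cases now throwing before `Copy` is reached, only the uint16 and uint8 source vectors have to be narrowed into the internal uint16 representation.
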
@@ -650,7 +651,19 @@ void BaseLoader<LoaderDomain>::loadConv2D(const Operator *op, ir::Graph &subg)
   param.dilation.width_factor = options->dilation_w_factor();
   param.dilation.height_factor = options->dilation_h_factor();
 
-  loadOperationTo<ir::operation::Conv2D>(op, subg, param);
+  const auto conv = loadOperationTo<ir::operation::Conv2D>(op, subg, param);
+
+  // TFLite supports old-style hybrid quantization (float input/output, uint8 kernel)
+  // but interprets the weight type as int8 internally
+  const auto &input_operand =
+    subg.operands().at(conv->getInputs().at(ir::operation::Conv2D::INPUT));
+  auto &weights_operand = subg.operands().at(conv->getInputs().at(ir::operation::Conv2D::KERNEL));
+  if (input_operand.typeInfo().type() == ir::DataType::FLOAT32 &&
+      ((weights_operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM) ||
+       weights_operand.typeInfo().type() == ir::DataType::QUANT_INT8_ASYMM))
+  {
+    weights_operand.type(ir::DataType::QUANT_INT8_SYMM);
+  }
 }
 
 template <typename LoaderDomain>
@@ -665,7 +678,21 @@ void BaseLoader<LoaderDomain>::loadDepthwiseConv2D(const Operator *op, ir::Graph
   param.dilation.width_factor = options->dilation_w_factor();
   param.dilation.height_factor = options->dilation_h_factor();
 
-  loadOperationTo<ir::operation::DepthwiseConv2D>(op, subg, param);
+  const auto dconv = loadOperationTo<ir::operation::DepthwiseConv2D>(op, subg, param);
+
+  // TFLite does not support old-style hybrid quantization (float input/output, uint8 kernel)
+  // for depthwise convolution, but for consistency with Conv2D and FC we interpret
+  // the weight type as int8 internally
+  const auto &input_operand =
+    subg.operands().at(dconv->getInputs().at(ir::operation::DepthwiseConv2D::INPUT));
+  auto &weights_operand =
+    subg.operands().at(dconv->getInputs().at(ir::operation::DepthwiseConv2D::KERNEL));
+  if (input_operand.typeInfo().type() == ir::DataType::FLOAT32 &&
+      ((weights_operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM) ||
+       weights_operand.typeInfo().type() == ir::DataType::QUANT_INT8_ASYMM))
+  {
+    weights_operand.type(ir::DataType::QUANT_INT8_SYMM);
+  }
 }
 
 template <typename LoaderDomain>
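
The same float-activation / asymmetric-kernel check now appears in loadConv2D, loadDepthwiseConv2D, and loadFC (just below). The rule they share could be expressed once as a helper along these lines; this is a sketch only, the name reinterpretHybridKernel and the free-function form are hypothetical and not part of this patch, while the operand and type accessors are the ones already used in the hunks above:

// Sketch only: if the activations are float while the kernel is asymmetrically
// quantized (the old TFLite hybrid-quantization layout), retag the kernel operand
// as the symmetric int8 type that onert uses internally.
static void reinterpretHybridKernel(ir::Graph &subg, const ir::OperandIndex &input_idx,
                                    const ir::OperandIndex &kernel_idx)
{
  const auto &input_operand = subg.operands().at(input_idx);
  auto &weights_operand = subg.operands().at(kernel_idx);
  if (input_operand.typeInfo().type() == ir::DataType::FLOAT32 &&
      (weights_operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM ||
       weights_operand.typeInfo().type() == ir::DataType::QUANT_INT8_ASYMM))
  {
    weights_operand.type(ir::DataType::QUANT_INT8_SYMM);
  }
}

loadConv2D, for example, would then pass the operand indices it already looks up, e.g.
reinterpretHybridKernel(subg, conv->getInputs().at(ir::operation::Conv2D::INPUT), conv->getInputs().at(ir::operation::Conv2D::KERNEL));
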
@@ -745,6 +772,8 @@ void BaseLoader<LoaderDomain>::loadFC(const Operator *op, ir::Graph &subg)
 
   const auto fc = loadOperationTo<ir::operation::FullyConnected>(op, subg, param);
 
+  // TFLite supports old-style hybrid quantization (float input/output, uint8 kernel)
+  // but interprets the weight type as int8 internally
   const auto &input_operand =
     subg.operands().at(fc->getInputs().at(ir::operation::FullyConnected::INPUT));
   auto &weights_operand =