Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / QuantizeWithMinMaxPass.cpp
index f8abee7..b335a53 100644 (file)
@@ -138,7 +138,8 @@ bool is_quantized(const CircleNode *node)
          node->dtype() == loco::DataType::S32;  // bias
 }
 
-void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_factor)
+void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_factor,
+                            int32_t &channel_dim_index)
 {
   assert(node->dtype() == loco::DataType::FLOAT32);
 
@@ -153,7 +154,6 @@ void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_facto
   uint32_t indices[4] = {
       0,
   };
-  int channel_dim_index{0};
 
   if (!get_channel_dim_index(node, dimension, channel_dim_index))
   {
@@ -189,7 +189,7 @@ void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_facto
 }
 
 void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min,
-                             std::vector<float> &scaling_factor)
+                             std::vector<float> &scaling_factor, int32_t &channel_dim_index)
 {
   assert(node->dtype() == loco::DataType::FLOAT32);
 
@@ -204,7 +204,6 @@ void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min,
   uint32_t indices[4] = {
       0,
   };
-  int channel_dim_index{0};
 
   if (!get_channel_dim_index(node, dimension, channel_dim_index))
   {
@@ -282,6 +281,10 @@ bool is_weights(CircleNode *node)
     if (dw_conv != nullptr && dw_conv->filter() == circle_const)
       return true;
 
+    auto t_conv = dynamic_cast<CircleTransposeConv *>(out);
+    if (t_conv != nullptr && t_conv->filter() == circle_const && circle_const->rank() == 4)
+      return true;
+
     auto fc = dynamic_cast<CircleFullyConnected *>(out);
     if (fc != nullptr && fc->weights() == circle_const)
       return true;
@@ -350,8 +353,8 @@ struct QuantizeActivation final : public luci::CircleNodeMutableVisitor<bool>
           circle_node->dtype(loco::DataType::S16);
         }
 
-        circle_node->quantparam()->max[0] = nudged_max;
-        circle_node->quantparam()->min[0] = nudged_min;
+        circle_node->quantparam()->min.clear();
+        circle_node->quantparam()->max.clear();
         circle_node->quantparam()->scale.push_back(scaling_factor);
         circle_node->quantparam()->zerop.push_back(zp);
       }
@@ -472,15 +475,19 @@ struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool>
           assert(quantparam != nullptr);
           auto min = quantparam->min;
           auto scaling_factor = quantparam->scale;
+          int32_t channel_dim_index = 0;
 
           if (output_type == loco::DataType::U8)
           {
-            asym_wquant_per_channel(circle_const, min, scaling_factor);
+            asym_wquant_per_channel(circle_const, min, scaling_factor, channel_dim_index);
           }
           else
           {
-            sym_wquant_per_channel(circle_const, scaling_factor);
+            sym_wquant_per_channel(circle_const, scaling_factor, channel_dim_index);
           }
+          quantparam->min.clear();
+          quantparam->max.clear();
+          quantparam->quantized_dimension = channel_dim_index;
         }
         // Find min/max per layer-wise
         else
@@ -493,6 +500,8 @@ struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool>
           auto min = quantparam->min[0];
           auto scaling_factor = quantparam->scale[0];
           asym_wquant_per_layer(circle_const, min, scaling_factor);
+          quantparam->min.clear();
+          quantparam->max.clear();
         }
       }
     }