Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / RequantizePass.cpp
index a565362..77c5532 100644 (file)
@@ -32,37 +32,9 @@ namespace luci
 namespace
 {
 
-// Check if the node is the bias of Conv2D, DepthwiseConv2D, or FullyConnected layer
-bool is_bias(CircleConst *node)
-{
-  if (node == nullptr)
-    return false;
-
-  auto succs = loco::succs(node);
-  if (succs.size() != 1) // assume bias is used by only one node
-    return false;
-
-  for (auto out : succs)
-  {
-    auto conv = dynamic_cast<CircleConv2D *>(out);
-    if (conv != nullptr && conv->bias() == node)
-      return true;
-
-    auto dw_conv = dynamic_cast<CircleDepthwiseConv2D *>(out);
-    if (dw_conv != nullptr && dw_conv->bias() == node)
-      return true;
-
-    auto fc = dynamic_cast<CircleFullyConnected *>(out);
-    if (fc != nullptr && fc->bias() == node)
-      return true;
-
-    auto tconv = dynamic_cast<CircleTransposeConv *>(out);
-    if (tconv != nullptr && tconv->bias() == node)
-      return true;
-  }
-  return false;
-}
-
+// Requantize Non-const node from int8 to uint8
+// Original values: -128 ~ 127
+// After requantization: 0 ~ 255
 void requant_nonconst_int8_to_uint8(CircleNode *circle_node)
 {
   assert(circle_node->dtype() == loco::DataType::S8);
@@ -107,99 +79,48 @@ void requant_const_int8_to_uint8(CircleConst *node)
   }
 }
 
+#define RETURN_UNLESS(cond) \
+  if (not(cond))            \
+    return;
+
 /**
- * @brief RequantizeNonConst requantizes tensors for activations
+ * @brief Requantize int8 quantized tensors to uint8 tensors
  */
-struct RequantizeNonConst final : public luci::CircleNodeMutableVisitor<bool>
+struct RequantizeS8ToU8 final : public luci::CircleNodeMutableVisitor<void>
 {
-  RequantizeNonConst(loco::DataType input, loco::DataType output)
-    : _input_type(input), _output_type(output)
-  {
-  }
-
-  loco::DataType _input_type;
-  loco::DataType _output_type;
-
-  // Requantize input tensors of each node
-  bool visit(luci::CircleNode *node)
+  // Requantize non-const tensors
+  void visit(luci::CircleNode *node)
   {
     LOGGER(l);
-    INFO(l) << "RequantizeNonConst visit node: " << node->name() << std::endl;
-    auto arity = node->arity();
-    for (uint32_t i = 0; i < arity; i++)
-    {
-      auto input_node = node->arg(i);
-      auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
+    INFO(l) << "RequantizeS8ToU8 visit non-const node: " << node->name() << std::endl;
 
-      // Check if this was quantized (only quantized tensors are requantized)
-      if (circle_node->quantparam() == nullptr)
-        continue;
+    // Ignore non-quantized tensors
+    RETURN_UNLESS(node->quantparam() != nullptr);
 
-      // Check if this is already requantized
-      if (circle_node->dtype() == _output_type)
-        continue;
+    // Check dtype is int8
+    RETURN_UNLESS(node->dtype() == loco::DataType::S8);
 
-      // Check if this is not const (only non-const is requantized in this function)
-      auto circle_const = dynamic_cast<CircleConst *>(circle_node);
-      if (circle_const != nullptr)
-        continue;
-
-      if (_input_type == loco::DataType::S8 && _output_type == loco::DataType::U8)
-        requant_nonconst_int8_to_uint8(circle_node);
-    }
-    return false;
-  }
-};
-
-/**
- * @brief RequantizeConst requantizes tensors for weights
- */
-struct RequantizeConst final : public luci::CircleNodeMutableVisitor<bool>
-{
-  RequantizeConst(loco::DataType input, loco::DataType output)
-    : _input_type(input), _output_type(output)
-  {
+    requant_nonconst_int8_to_uint8(node);
   }
 
-  loco::DataType _input_type;
-  loco::DataType _output_type;
-
-  // Requantize input tensors of each node
-  bool visit(luci::CircleNode *node)
+  // Requantize const tensors
+  void visit(luci::CircleConst *node)
   {
     LOGGER(l);
-    INFO(l) << "RequantizeConst visit node: " << node->name() << std::endl;
-    auto arity = node->arity();
-    for (uint32_t i = 0; i < arity; i++)
-    {
-      auto input_node = node->arg(i);
-      auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
+    INFO(l) << "RequantizeS8ToU8 visit const node: " << node->name() << std::endl;
 
-      // Check if this was quantized (only quantized tensors are requantized)
-      if (circle_node->quantparam() == nullptr)
-        continue;
+    // Ignore non-quantized tensors
+    RETURN_UNLESS(node->quantparam() != nullptr);
 
-      // Check if this is already requantized
-      if (circle_node->dtype() == _output_type)
-        continue;
+    // Check dtype is int8
+    RETURN_UNLESS(node->dtype() == loco::DataType::S8);
 
-      // Check if this is const (only const is requantized in this function)
-      auto circle_const = dynamic_cast<CircleConst *>(circle_node);
-      if (circle_const == nullptr)
-        continue;
-
-      // Check if this is not bias
-      // bias is not requantized when int8 -> uint8
-      if (is_bias(circle_const))
-        continue;
-
-      if (_input_type == loco::DataType::S8 && _output_type == loco::DataType::U8)
-        requant_const_int8_to_uint8(circle_const);
-    }
-    return false;
+    requant_const_int8_to_uint8(node);
   }
 };
 
+#undef RETURN_UNLESS
+
 } // namespace
 
 bool RequantizePass::run(loco::Graph *g)
@@ -207,20 +128,21 @@ bool RequantizePass::run(loco::Graph *g)
   LOGGER(l);
   INFO(l) << "RequantizePass Start" << std::endl;
 
-  // Requantize non-const (activations)
-  for (auto node : loco::active_nodes(loco::output_nodes(g)))
+  // Input: int8 model
+  // Output: uint8 model
+  if (_input_dtype == loco::DataType::S8 and _output_dtype == loco::DataType::U8)
   {
-    RequantizeNonConst rqnc(_input_dtype, _output_dtype);
-    auto circle_node = loco::must_cast<luci::CircleNode *>(node);
-    circle_node->accept(&rqnc);
+    for (auto node : loco::active_nodes(loco::output_nodes(g)))
+    {
+      RequantizeS8ToU8 rq;
+      auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+      circle_node->accept(&rq);
+    }
   }
-
-  // Requantize const (including weights, constants)
-  for (auto node : loco::active_nodes(loco::output_nodes(g)))
+  else
   {
-    RequantizeConst rqc(_input_dtype, _output_dtype);
-    auto circle_node = loco::must_cast<luci::CircleNode *>(node);
-    circle_node->accept(&rqc);
+    // Ignore other cases
+    return false;
   }
 
   // Update output dtype
@@ -228,7 +150,8 @@ bool RequantizePass::run(loco::Graph *g)
   for (auto node : loco::output_nodes(g))
   {
     auto circle_node = loco::must_cast<luci::CircleOutput *>(node);
-    if (static_cast<luci::CircleNode *>(circle_node->from())->dtype() == _output_dtype)
+    auto from_node = loco::must_cast<luci::CircleNode *>(circle_node->from());
+    if (from_node->dtype() == _output_dtype)
     {
       circle_node->dtype(_output_dtype);
       auto graph_output = graph_outputs->at(circle_node->index());