Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / PropagateQParamBackwardPass.cpp
index e8fa2a4..18617e3 100644 (file)
 namespace
 {
 
+// Return true if node is a virtual node
+bool virtual_op(const luci::CircleOpcode opcode)
+{
+  switch (opcode)
+  {
+#define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) \
+  case luci::CircleOpcode::OPCODE:        \
+    return false;
+#define CIRCLE_VNODE(OPCODE, CIRCLE_CLASS) \
+  case luci::CircleOpcode::OPCODE:         \
+    return true;
+#include <luci/IR/CircleNodes.lst>
+#undef CIRCLE_NODE
+#undef CIRCLE_VNODE
+    default:
+      throw std::runtime_error("Unknown opcode detected");
+  }
+}
+
 void quant_const_values(luci::CircleConst *const_node, float scaling_factor, float zerop,
                         loco::DataType quant_type)
 {
@@ -448,6 +467,50 @@ struct PropagateQParamBackward final : public luci::CircleNodeMutableVisitor<voi
   void visit(luci::CirclePack *node) { propagate_pack_quantparam(node); }
 
   void visit(luci::CirclePadV2 *node) { propagate_pad_v2_quantparam(node); }
+
+  // Propagate qparam for non-value changing Ops
+  // (ex: Reshape, Transpose, etc.)
+  // TODO Add more Ops
+
+  void visit(luci::CircleReshape *node)
+  {
+    auto input_node = loco::must_cast<luci::CircleNode *>(node->tensor());
+
+    // Do not propagate qparam if input node has multiple users
+    if (loco::succs(input_node).size() > 1)
+      return;
+
+    const auto input_opcode = input_node->opcode();
+
+    // Do not propagate qparam if input node is virtual Op (except CIRCLEINPUT)
+    // Why? It is not safe to propagate qparam to some virtual nodes. For example,
+    // const node, multi-out nodes. Let's block them for now.
+    // TODO Revisit this condition
+    if (virtual_op(input_opcode) and input_opcode != luci::CircleOpcode::CIRCLEINPUT)
+      return;
+
+    overwrite_quantparam(node, input_node);
+  }
+
+  void visit(luci::CircleTranspose *node)
+  {
+    auto input_node = loco::must_cast<luci::CircleNode *>(node->a());
+
+    // Do not propagate qparam if input node has multiple users
+    if (loco::succs(input_node).size() > 1)
+      return;
+
+    const auto input_opcode = input_node->opcode();
+
+    // Do not propagate qparam if input node is virtual Op (except CIRCLEINPUT)
+    // Why? It is not safe to propagate qparam to some virtual nodes. For example,
+    // const node, multi-out nodes. Let's block them for now.
+    // TODO Revisit this condition
+    if (virtual_op(input_opcode) and input_opcode != luci::CircleOpcode::CIRCLEINPUT)
+      return;
+
+    overwrite_quantparam(node, input_node);
+  }
 };
 
 } // namespace