namespace
{
-// Check if the node is the bias of Conv2D, DepthwiseConv2D, or FullyConnected layer
-bool is_bias(CircleConst *node)
-{
- if (node == nullptr)
- return false;
-
- auto succs = loco::succs(node);
- if (succs.size() != 1) // assume bias is used by only one node
- return false;
-
- for (auto out : succs)
- {
- auto conv = dynamic_cast<CircleConv2D *>(out);
- if (conv != nullptr && conv->bias() == node)
- return true;
-
- auto dw_conv = dynamic_cast<CircleDepthwiseConv2D *>(out);
- if (dw_conv != nullptr && dw_conv->bias() == node)
- return true;
-
- auto fc = dynamic_cast<CircleFullyConnected *>(out);
- if (fc != nullptr && fc->bias() == node)
- return true;
-
- auto tconv = dynamic_cast<CircleTransposeConv *>(out);
- if (tconv != nullptr && tconv->bias() == node)
- return true;
- }
- return false;
-}
-
+// Requantize Non-const node from int8 to uint8
+// Original values: -128 ~ 127
+// After requantization: 0 ~ 255
void requant_nonconst_int8_to_uint8(CircleNode *circle_node)
{
assert(circle_node->dtype() == loco::DataType::S8);
}
}
+// Return from the enclosing void function unless `cond` holds.
+// The body is wrapped in do { ... } while (0) so that `RETURN_UNLESS(x);`
+// expands to exactly one statement and remains safe inside an unbraced
+// if/else (no stray empty statement, no dangling-else capture).
+#define RETURN_UNLESS(cond) \
+  do                        \
+  {                         \
+    if (not(cond))          \
+      return;               \
+  } while (0)
+
/**
- * @brief RequantizeNonConst requantizes tensors for activations
+ * @brief Requantize int8 quantized tensors to uint8 tensors
*/
-struct RequantizeNonConst final : public luci::CircleNodeMutableVisitor<bool>
+struct RequantizeS8ToU8 final : public luci::CircleNodeMutableVisitor<void>
{
- RequantizeNonConst(loco::DataType input, loco::DataType output)
- : _input_type(input), _output_type(output)
- {
- }
-
- loco::DataType _input_type;
- loco::DataType _output_type;
-
- // Requantize input tensors of each node
- bool visit(luci::CircleNode *node)
+ // Requantize non-const tensors
+ void visit(luci::CircleNode *node)
{
LOGGER(l);
- INFO(l) << "RequantizeNonConst visit node: " << node->name() << std::endl;
- auto arity = node->arity();
- for (uint32_t i = 0; i < arity; i++)
- {
- auto input_node = node->arg(i);
- auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
+ INFO(l) << "RequantizeS8ToU8 visit non-const node: " << node->name() << std::endl;
- // Check if this was quantized (only quantized tensors are requantized)
- if (circle_node->quantparam() == nullptr)
- continue;
+ // Ignore non-quantized tensors
+ RETURN_UNLESS(node->quantparam() != nullptr);
- // Check if this is already requantized
- if (circle_node->dtype() == _output_type)
- continue;
+ // Check dtype is int8
+ RETURN_UNLESS(node->dtype() == loco::DataType::S8);
- // Check if this is not const (only non-const is requantized in this function)
- auto circle_const = dynamic_cast<CircleConst *>(circle_node);
- if (circle_const != nullptr)
- continue;
-
- if (_input_type == loco::DataType::S8 && _output_type == loco::DataType::U8)
- requant_nonconst_int8_to_uint8(circle_node);
- }
- return false;
- }
-};
-
-/**
- * @brief RequantizeConst requantizes tensors for weights
- */
-struct RequantizeConst final : public luci::CircleNodeMutableVisitor<bool>
-{
- RequantizeConst(loco::DataType input, loco::DataType output)
- : _input_type(input), _output_type(output)
- {
+ requant_nonconst_int8_to_uint8(node);
}
- loco::DataType _input_type;
- loco::DataType _output_type;
-
- // Requantize input tensors of each node
- bool visit(luci::CircleNode *node)
+ // Requantize const tensors
+ void visit(luci::CircleConst *node)
{
LOGGER(l);
- INFO(l) << "RequantizeConst visit node: " << node->name() << std::endl;
- auto arity = node->arity();
- for (uint32_t i = 0; i < arity; i++)
- {
- auto input_node = node->arg(i);
- auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
+ INFO(l) << "RequantizeS8ToU8 visit const node: " << node->name() << std::endl;
- // Check if this was quantized (only quantized tensors are requantized)
- if (circle_node->quantparam() == nullptr)
- continue;
+ // Ignore non-quantized tensors
+ RETURN_UNLESS(node->quantparam() != nullptr);
- // Check if this is already requantized
- if (circle_node->dtype() == _output_type)
- continue;
+ // Check dtype is int8
+ RETURN_UNLESS(node->dtype() == loco::DataType::S8);
- // Check if this is const (only const is requantized in this function)
- auto circle_const = dynamic_cast<CircleConst *>(circle_node);
- if (circle_const == nullptr)
- continue;
-
- // Check if this is not bias
- // bias is not requantized when int8 -> uint8
- if (is_bias(circle_const))
- continue;
-
- if (_input_type == loco::DataType::S8 && _output_type == loco::DataType::U8)
- requant_const_int8_to_uint8(circle_const);
- }
- return false;
+ requant_const_int8_to_uint8(node);
}
};
+#undef RETURN_UNLESS
+
} // namespace
bool RequantizePass::run(loco::Graph *g)
LOGGER(l);
INFO(l) << "RequantizePass Start" << std::endl;
- // Requantize non-const (activations)
- for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ // Input: int8 model
+ // Output: uint8 model
+ if (_input_dtype == loco::DataType::S8 and _output_dtype == loco::DataType::U8)
{
- RequantizeNonConst rqnc(_input_dtype, _output_dtype);
- auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- circle_node->accept(&rqnc);
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ RequantizeS8ToU8 rq;
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ circle_node->accept(&rq);
+ }
}
-
- // Requantize const (including weights, constants)
- for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ else
{
- RequantizeConst rqc(_input_dtype, _output_dtype);
- auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- circle_node->accept(&rqc);
+ // Ignore other cases
+ return false;
}
// Update output dtype
for (auto node : loco::output_nodes(g))
{
auto circle_node = loco::must_cast<luci::CircleOutput *>(node);
- if (static_cast<luci::CircleNode *>(circle_node->from())->dtype() == _output_dtype)
+ auto from_node = loco::must_cast<luci::CircleNode *>(circle_node->from());
+ if (from_node->dtype() == _output_dtype)
{
circle_node->dtype(_output_dtype);
auto graph_output = graph_outputs->at(circle_node->index());