    if cnt < current_qconfig().skip_k_conv:
        _set_conv_counter(cnt + 1)
        return None
+
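+    # Leave this conv layer unquantized when its index is listed in skip_conv_layers.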
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt in leave_alone_indices:
+            _set_conv_counter(cnt + 1)
+            return None
+
    _set_conv_counter(cnt + 1)
    lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
    rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
    expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
+
    return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
    cnt = _conv_counter()
    if cnt < current_qconfig().skip_k_conv:
        return None
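+    # The conv counter has already been advanced past the most recent conv layer,
+    # so that layer's index is cnt - 1.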
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
+
    lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
    rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
@register_annotate_function("multiply")
def multiply_rewrite(ref_call, new_args, ctx):
"""Rewrite function for multiply."""
- if _conv_counter() <= current_qconfig().skip_k_conv:
+ cnt = _conv_counter()
+ if cnt <= current_qconfig().skip_k_conv:
return None
+ if current_qconfig().skip_conv_layers is not None:
+ leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+ if cnt - 1 in leave_alone_indices:
+ return None
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
@register_annotate_function("add")
def add_rewrite(ref_call, new_args, ctx):
"""Rewrite function for add."""
- if _conv_counter() <= current_qconfig().skip_k_conv:
+ cnt = _conv_counter()
+ if cnt <= current_qconfig().skip_k_conv:
return None
+ if current_qconfig().skip_conv_layers is not None:
+ leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+ if cnt - 1 in leave_alone_indices:
+ return None
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
def identity_rewrite(ref_call, new_args, ctx):
    """Simply forward the original operation"""
-    if _conv_counter() <= current_qconfig().skip_k_conv:
+    cnt = _conv_counter()
+    if cnt <= current_qconfig().skip_k_conv:
        return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
    x_expr, x_kind = _get_expr_kind(new_args[0])
    if x_kind is None:
def pool2d_rewrite(ref_call, new_args, ctx):
    """Rewrite function for max pool2d"""
-    if _conv_counter() <= current_qconfig().skip_k_conv:
+    cnt = _conv_counter()
+    if cnt <= current_qconfig().skip_k_conv:
        return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
+
    expr, x_kind = _get_expr_kind(new_args[0])
    if x_kind is None:
@register_annotate_function("concatenate")
def concatenate_rewrite(ref_call, new_args, ctx):
"""Rewrite function for concatenate"""
- if _conv_counter() <= current_qconfig().skip_k_conv:
+ cnt = _conv_counter()
+ if cnt <= current_qconfig().skip_k_conv:
return None
+ if current_qconfig().skip_conv_layers is not None:
+ leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+ if cnt - 1 in leave_alone_indices:
+ return None
input_tuple = new_args[0]
expr_list = [_get_expr_kind(x)[0] for x in input_tuple]
  DataType dtype_activation = Int(32);
  double global_scale = 8.0;
  int skip_k_conv = 1;
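+  // Null by default: no per-layer skip list is configured, which the Python-side
+  // annotate pass observes as skip_conv_layers being None.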
+ Array<Expr> skip_conv_layers = Array<Expr>(NodePtr<Node>(nullptr));
bool round_for_shift = true;
bool store_lowbit_output = true;
Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
v->Visit("dtype_activation", &dtype_activation);
v->Visit("global_scale", &global_scale);
v->Visit("skip_k_conv", &skip_k_conv);
+ v->Visit("skip_conv_layers", &skip_conv_layers);
v->Visit("round_for_shift", &round_for_shift);
v->Visit("store_lowbit_output", &store_lowbit_output);
v->Visit("debug_enabled_ops", &debug_enabled_ops);
        return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
                             pre_computed=False)
    if cfg.template_key == 'int8':
-        return conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
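+        # Only dispatch to the NCHWc int8 kernel when the input is already int8/uint8;
+        # otherwise fall through to the generic NCHW path below.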
+        if data.dtype in ('int8', 'uint8'):
+            return conv2d_NCHWc_int8(
+                cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
    if layout == 'NCHW':
        return nn.conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
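For reference, a minimal usage sketch of the new option from the Python side (a hypothetical example, assuming the qconfig/quantize entry points in tvm.relay.quantize accept the skip_conv_layers key wired up above; graph and params stand in for an existing Relay function and its parameters):

    # Hypothetical usage: leave the first conv layer (index 0) and the ops that
    # follow it unquantized, while everything else is annotated as usual.
    from tvm import relay

    with relay.quantize.qconfig(skip_conv_layers=[0]):
        qgraph = relay.quantize.quantize(graph, params=params)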