}
/*
- * \brief Fallback to simpler lowering for dilation or grouped conv.
+ * \brief Fallback to simpler lowering for dilation (with a non-zero kernel zero point) or grouped conv.
* \param data The input expr.
* \param weight The weight expr.
* \param input_zero_point The input zero point expr.
* \param kernel_zero_point The kernel zero point expr.
* \param param The qnn conv2d attributes.
* \return The fallback lowered sequence of Relay expr.
- * \note In case of dilation, normal lowering would require a dilated pool.
- * Since, we don't have dilated pool, we fallback to a simpler sequence of
- * Relay operations. This will potentially lead to performance degradation
- * as the convolution is called on int32 tensors instead of int8 tensors.
+ * \note In case of dilation with a non-zero kernel zero point, normal lowering would require a
+ * dilated pool. Since we don't have a dilated pool, we fall back to a simpler sequence of Relay
+ * operations. This can degrade performance because the convolution is called on int32 tensors
+ * instead of int8 tensors.
*/
Expr Conv2DFallBack(const Expr& data, const Expr& weight, const Expr& input_zero_point,
const Expr& kernel_zero_point, const Conv2DAttrs* param) {
auto input_zero_point_int = GetScalarFromConstant<int>(input_zero_point);
auto kernel_zero_point_int = GetScalarFromConstant<int>(kernel_zero_point);
- // Fallback to int32 conv if there is dilation or grouped conv2d
+ // Fall back to int32 conv if there is dilation with a non-zero kernel zero point, or grouped conv2d.
+ // For dilated conv, if the kernel zero point is non-zero, the pooling operator would also have to
+ // traverse the elements in a dilated manner. Since we do not have a dilated pool, we fall back to a
+ // simpler but slower lowering in that case.
CHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation";
auto dilation_h = get_const_int(param->dilation[0]);
auto dilation_w = get_const_int(param->dilation[1]);
- if (dilation_h != 1 || dilation_w != 1 || (param->groups != 1 && !is_depthwise(param))) {
+ if ((kernel_zero_point_int != 0 && (dilation_h != 1 || dilation_w != 1)) ||
+ (param->groups != 1 && !is_depthwise(param))) {
return Conv2DFallBack(data, weight, input_zero_point, kernel_zero_point, param);
} else if (is_depthwise(param)) {
CHECK_NE(channel_multiplier, -1);
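
The comments above can be made concrete with a small sketch. The qnn.conv2d lowering expands
conv(data - input_zero_point, weight - kernel_zero_point) into four terms; the term scaled by the
kernel zero point is a sum of the input over each dilated kernel window, i.e. exactly the dilated
pool that is not available, while the remaining terms need no dilation-aware reduction. When the
kernel zero point is zero that term drops out, which is why the fast lowering still handles
dilation in that case. The snippet below is illustrative only (plain numpy, not the TVM lowering;
all shapes and zero points are made up) and checks the decomposition at a single output position.

# Illustrative sketch, not TVM code: the four-term expansion of a dilated qnn conv2d.
import numpy as np

np.random.seed(0)
C, H, W = 4, 8, 8            # input channels, height, width (one image, one output channel)
KH, KW, dil = 2, 2, 2        # kernel size and dilation
izp, kzp = 5, 3              # input / kernel zero points (made-up values)
data = np.random.randint(0, 256, size=(C, H, W)).astype(np.int32)
weight = np.random.randint(0, 256, size=(C, KH, KW)).astype(np.int32)

oh, ow = 1, 2                # one arbitrary output position
# Dilated receptive field: kernel taps are spaced `dil` apart in the input.
window = data[:, oh:oh + KH * dil:dil, ow:ow + KW * dil:dil]

# Reference: convolution on the zero-point-subtracted int32 tensors (the fallback lowering).
ref = ((window - izp) * (weight - kzp)).sum()

term1 = (window * weight).sum()      # int8 x int8 convolution
term2 = kzp * window.sum()           # data summed over the *dilated* window -> needs a dilated pool
term3 = izp * weight.sum()           # reduction over the weights only
term4 = izp * kzp * C * KH * KW      # constant
assert ref == term1 - term2 - term3 + term4   # term2 and term4 vanish when kzp == 0
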
def test_dilation():
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):
- # uint8 input
+ # Non-zero kernel zero point - falls back to the simpler lowering.
data_shape = (2, 4, 4, 4)
data_dtype = 'uint8'
kernel_shape = (3, 4, 2, 2)
verify(ref_func, qnn_func, data_shape, data_dtype,
kernel_shape, kernel_dtype)
+ # Zero kernel zero point - handled by the normal lowering.
+ data_shape = (2, 4, 4, 4)
+ data_dtype = 'uint8'
+ kernel_shape = (3, 4, 2, 2)
+ kernel_dtype = 'uint8'
+ ref_func, qnn_func = get_funcs(data_shape=data_shape,
+ data_dtype=data_dtype,
+ kernel_shape=kernel_shape,
+ kernel_dtype=kernel_dtype,
+ input_zero_point=0,
+ kernel_zero_point=0,
+ input_scale=1.0,
+ kernel_scale=1.0,
+ kernel_size=(2, 2),
+ padding=(0, 0),
+ strides=(1, 1),
+ dilation=(2, 2),
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_dtype="int32")
+ verify(ref_func, qnn_func, data_shape, data_dtype,
+ kernel_shape, kernel_dtype)
+
def test_const_folding():
with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d):