from tvm import relay
from .. import op as reg
+#################################################
+# Register the functions for different operators.
+#################################################
+
# Registering QNN Conv2D legalization function.
@reg.register_qnn_legalize("qnn.conv2d")
def legalize_qnn_conv2d(attrs, inputs, types):
- """Legalizes QNN conv2d op.
+ return qnn_conv2d_legalize(attrs, inputs, types)
+
+# Registering QNN dense legalization function.
+@reg.register_qnn_legalize("qnn.dense")
+def legalize_qnn_dense(attrs, inputs, types):
+ return qnn_dense_legalize(attrs, inputs, types)
+
+# Default to None. If overridden by target, this will not be run.
+# Generic QNN Conv2D legalization function.
+@tvm.target.generic_func
+def qnn_conv2d_legalize(attrs, inputs, types):
+ """Default legalization is None."""
+ return None
+
+# Generic QNN Conv2D legalization function.
+@tvm.target.generic_func
+def qnn_dense_legalize(attrs, inputs, types):
+ """Default legalization is None."""
+ return None
+
+###################
+# Helper functions.
+###################
+
+# Helper function for lowering in the abscence of fast Int8 arithmetic units.
+def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
+ """ Converts QNN operators into a sequence of Relay operators that are friendly to HW that do
+ not have fast Int8 arithmetic. For example, for ARM, LLVM utilizes the assembly instructions
+ much more efficiently if the convolution or dense operator input datatypes are int16 instead of
+ int8. More details are present at https://github.com/apache/incubator-tvm/pull/4277.
Parameters
----------
result : tvm.relay.Expr
The legalized expr
"""
- return qnn_conv2d_legalize(attrs, inputs, types)
-# Generic QNN Conv2D legalization function.
-@tvm.target.generic_func
-def qnn_conv2d_legalize(attrs, inputs, types):
- """Default legalization is None."""
- return None
+ # Collect the input exprs.
+ data, kernel = inputs
-# Intel x86 QNN Conv2D legalization function.
-@qnn_conv2d_legalize.register('cpu')
-def _qnn_conv2d_legalize(attrs, inputs, types):
- """Legalizes QNN conv2d op. VNNI supports u8 x i8 fast conv/MM. If the dtypes are already good,
- we dont transform. Else, we shift the tensor values and zero points to change the dtype.
+ input_zp = attrs['input_zero_point']
+ kernel_zp = attrs['kernel_zero_point']
+
+ shift_data = relay.subtract(relay.cast(data, dtype='int16'),
+ relay.const(input_zp, 'int16'))
+ shift_kernel = relay.subtract(relay.cast(kernel, dtype='int16'),
+ relay.const(kernel_zp, 'int16'))
+ new_attrs = {k : attrs[k] for k in attrs.keys()}
+ del new_attrs['kernel_zero_point']
+ del new_attrs['input_zero_point']
+ return relay_op(shift_data, shift_kernel, **new_attrs)
+
+# Helper function to change dtypes to uint8 x int8. Intel VNNI instructions prefer this setting.
+def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
+ """Legalizes QNN conv2d/dense op for Intel HW. VNNI supports u8 x i8 fast conv/MM. If the dtypes
+ are already good, we dont transform. Else, we shift the tensor values and zero points to change
+ the dtype.
Converting from int8 to uint8 can be done in following manner.
The legalized expr
"""
- def _shift(data, out_dtype):
+ def _shift(data, zero_point, out_dtype):
"""Shifts (add/subtracts) the qnn tensor with +/-128)"""
if out_dtype == 'uint8':
shift = 128
elif out_dtype == 'int8':
shift = -128
else:
- raise ValueError("Unsupport out dtype.")
+ raise ValueError("Unsupported out dtype.")
data_modified = relay.cast(data, 'int32')
data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
data_modified = relay.cast(data_modified, out_dtype)
- return data_modified
-
- def _is_int8_hw_support(target):
- """
- Checks to ensure that we can use Intel DLBoost instructions - Check if the target is skylake
- and above.
- """
- supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
- return supported_arches.intersection(set(target.options))
+ return (data_modified, zero_point + shift)
# Collect the dtypes.
data_dtype = types[0].dtype
# Collect the input exprs.
data, kernel = inputs
- # The VNNI transformations are applicable only Skylake and above.g
- target = tvm.target.current_target(allow_none=False)
- if not _is_int8_hw_support(target):
- return None
-
# VNNI supports u8 x i8 fast conv/MM. Don't do anything if it is already satisfied.
if data_dtype == 'uint8' and kernel_dtype == 'int8':
return None
input_zp = attrs['input_zero_point']
if data_dtype == 'int8':
# Compute (QA + 128) and (zp_a + 128)
- data = _shift(data, 'uint8')
- input_zp = input_zp + 128
+ data, input_zp = _shift(data, input_zp, 'uint8')
# Shift kernel if necessary.
kernel_zp = attrs['kernel_zero_point']
if kernel_dtype == 'uint8':
# Compute (QA - 128) and (zp_a - 128)
- kernel = _shift(kernel, 'int8')
- kernel_zp = kernel_zp - 128
+ kernel, kernel_zp = _shift(kernel, kernel_zp, 'int8')
# Call qnn.conv2d with modified inputs and zero points.
new_attrs = {k : attrs[k] for k in attrs.keys()}
new_attrs['input_zero_point'] = input_zp
new_attrs['kernel_zero_point'] = kernel_zp
- return relay.qnn.op.conv2d(data, kernel, **new_attrs)
+ return relay_op(data, kernel, **new_attrs)
+
+# Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting.
+def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
+ """ Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
+ many devices like ARM prefer the datatypes to be same for the HW units. This helper transforms
+ conv2d/dense such that both the dtypes are same.
+
+ Parameters
+ ----------
+ attrs : tvm.attrs.Attrs
+ Attributes of current convolution
+ inputs : list of tvm.relay.Expr
+ The args of the Relay expr to be legalized
+ types : list of types
+ List of input and output types
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The legalized expr
+ """
+
+ def _shift(data, zero_point, out_dtype):
+ """Shifts (adds/subtracts) the qnn tensor by 128)"""
+ if out_dtype == 'uint8':
+ shift = 128
+ elif out_dtype == 'int8':
+ shift = -128
+ else:
+ raise ValueError("Unsupported out dtype.")
+ data_modified = relay.cast(data, 'int32')
+ data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
+ data_modified = relay.cast(data_modified, out_dtype)
+ return (data_modified, zero_point + shift)
+
+ # Collect the dtypes.
+ data_dtype = types[0].dtype
+ kernel_dtype = types[1].dtype
+
+ if data_dtype == kernel_dtype:
+ return None
+
+ # Collect the input exprs.
+ data, kernel = inputs
+
+ assert 'int8' in data_dtype and 'int8' in kernel_dtype, \
+ "Qnn Conv2D/Dense only accepts uint8 or int8 inputs"
+
+ # Shift input if necessary.
+ input_zp = attrs['input_zero_point']
+ data, input_zp = _shift(data, input_zp, kernel_dtype)
+
+ new_attrs = {k : attrs[k] for k in attrs.keys()}
+ new_attrs['input_zero_point'] = input_zp
+ return relay_op(data, kernel, **new_attrs)
+
+def is_fast_int8_on_intel():
+ """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
+ target = tvm.target.current_target(allow_none=False)
+ intel_supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
+ return intel_supported_arches.intersection(set(target.options))
+
+def is_fast_int8_on_arm():
+ """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
+ target = tvm.target.current_target(allow_none=False)
+ return '+v8.2a,+dotprod' in ' '.join(target.options)
+
+########################
+# ARM CPU legalizations.
+########################
+
+@qnn_conv2d_legalize.register('arm_cpu')
+def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
+ # ARM prefers the dtypes to be same.
+ if is_fast_int8_on_arm():
+ return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
+ return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
+
+@qnn_dense_legalize.register('arm_cpu')
+def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
+ # ARM prefers the dtypes to be same.
+ if is_fast_int8_on_arm():
+ return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
+ return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
+
+##########################
+# Intel CPU legalizations.
+##########################
+
+@qnn_conv2d_legalize.register('cpu')
+def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types):
+ # The VNNI transformations prefer uint8 x int8 datatypes.
+ if is_fast_int8_on_intel():
+ return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.conv2d)
+ return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
+
+@qnn_dense_legalize.register('cpu')
+def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
+ # The VNNI transformations prefer uint8 x int8 datatypes.
+ if is_fast_int8_on_intel():
+ return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense)
+ return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
from tvm.relay.qnn.op import register_qnn_legalize
from tvm.relay import transform, analysis
+def alpha_equal(x, y):
+ """
+ Wrapper around alpha equality which ensures that
+ the hash function respects equality.
+ """
+ x = x['main']
+ y = y['main']
+ return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y)
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
b = run_opt_pass(expected(), transform.InferType())
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
def test_qnn_legalize_qnn_conv2d():
- data_shape = (1, 64, 256, 256)
- kernel_shape = (128, 64, 3, 3)
- for dtype in ['uint8', 'int8']:
- data_dtype = kernel_dtype = dtype
+ def _get_mod(data_dtype, kernel_dtype):
+ data_shape = (1, 64, 256, 256)
+ kernel_shape = (128, 64, 3, 3)
data = relay.var("data", shape=data_shape,
dtype=data_dtype)
kernel = relay.var("kernel", shape=kernel_shape,
mod = relay.Function(relay.analysis.free_vars(func), func)
mod = relay.Module.from_expr(mod)
+ return mod
+
+ # Check uint8 x uint8 and int8 x int8 transformation
+ for dtype in ('uint8', 'int8'):
+ mod = _get_mod(dtype, dtype)
+ #############################################################
+ # Check transformations for platforms with fast Int8 support.
+ #############################################################
+ # Check that Intel VNNI gets picked up.
with tvm.target.create('llvm -mcpu=skylake-avx512'):
- mod = relay.qnn.transform.Legalize()(mod)
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert 'cast' in legalized_mod.astext() and "qnn.conv2d" in legalized_mod.astext()
+
+ # Since same dtype, there should not be any transformation
+ with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert alpha_equal(mod, legalized_mod)
+
+ ################################################################
+ # Check transformations for platforms without fast Int8 support.
+ ################################################################
+ # Older Intel versions.
+ with tvm.target.create('llvm'):
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+ # Older ARM vesions.
+ with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'):
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+ # Check uint8 x int8 transformation
+ mod = _get_mod('uint8', 'int8')
+ #############################################################
+ # Check transformations for platforms with fast Int8 support.
+ #############################################################
+ # Check no transformation for Intel VNNI.
+ with tvm.target.create('llvm -mcpu=skylake-avx512'):
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert alpha_equal(mod, legalized_mod)
+
+ # ARM - so check that transformation has happened.
+ with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert 'cast' in legalized_mod.astext() and "qnn.conv2d" in legalized_mod.astext()
+
+ ################################################################
+ # Check transformations for platforms without fast Int8 support.
+ ################################################################
+ # Older Intel versions.
+ with tvm.target.create('llvm'):
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+ # Older ARM vesions.
+ with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'):
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+
+def test_qnn_legalize_qnn_dense():
+ def _get_mod(data_dtype, kernel_dtype):
+ data_shape = (10, 3)
+ kernel_shape = (20, 3)
+ data = relay.var("data", shape=data_shape,
+ dtype=data_dtype)
+ kernel = relay.var("kernel", shape=kernel_shape,
+ dtype=kernel_dtype)
+ func = relay.qnn.op.dense(
+ data, kernel,
+ input_zero_point=1,
+ kernel_zero_point=1,
+ out_dtype='int32')
+
+ mod = relay.Function(relay.analysis.free_vars(func), func)
+ mod = relay.Module.from_expr(mod)
+ return mod
+
+ # Check uint8 x uint8 and int8 x int8 transformation
+ for dtype in ('uint8', 'int8'):
+ mod = _get_mod(dtype, dtype)
+
+ #############################################################
+ # Check transformations for platforms with fast Int8 support.
+ #############################################################
+ # Check that Intel VNNI gets picked up.
+ with tvm.target.create('llvm -mcpu=skylake-avx512'):
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert 'cast' in legalized_mod.astext() and "qnn.dense" in legalized_mod.astext()
+
+ # Since same dtype, there should not be any transformation
+ with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert alpha_equal(mod, legalized_mod)
+
+ ################################################################
+ # Check transformations for platforms without fast Int8 support.
+ ################################################################
+ # Older Intel versions.
+ with tvm.target.create('llvm'):
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+ # Older ARM vesions.
+ with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'):
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+ # Check uint8 x int8 transformation
+ mod = _get_mod('uint8', 'int8')
+ #############################################################
+ # Check transformations for platforms with fast Int8 support.
+ #############################################################
+ # Check no transformation for Intel VNNI.
+ with tvm.target.create('llvm -mcpu=skylake-avx512'):
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert alpha_equal(mod, legalized_mod)
+
+ # ARM - so check that transformation has happened.
+ with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert 'cast' in legalized_mod.astext() and "qnn.dense" in legalized_mod.astext()
+
+ ################################################################
+ # Check transformations for platforms without fast Int8 support.
+ ################################################################
+ # Older Intel versions.
+ with tvm.target.create('llvm'):
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+ # Older ARM vesions.
+ with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'):
+ legalized_mod = relay.qnn.transform.Legalize()(mod)
+ assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
- assert 'cast' in mod.astext()
if __name__ == "__main__":
test_qnn_legalize()
test_qnn_legalize_qnn_conv2d()
+ test_qnn_legalize_qnn_dense()