[QNN][Legalize] Specialize for Platforms without any fast Int8 arithmetic units....
authorAnimesh Jain <anijain@umich.edu>
Wed, 13 Nov 2019 19:18:49 +0000 (11:18 -0800)
committerZhi <5145158+zhiics@users.noreply.github.com>
Wed, 13 Nov 2019 19:18:49 +0000 (11:18 -0800)
python/tvm/relay/qnn/op/legalizations.py
tests/python/relay/test_pass_qnn_legalize.py

index 6b2e073..3f94f98 100644 (file)
@@ -22,10 +22,43 @@ import tvm
 from tvm import relay
 from .. import op as reg
 
+#################################################
+# Register the functions for different operators.
+#################################################
+
 # Registering QNN Conv2D legalization function.
 @reg.register_qnn_legalize("qnn.conv2d")
 def legalize_qnn_conv2d(attrs, inputs, types):
-    """Legalizes QNN conv2d op.
+    return qnn_conv2d_legalize(attrs, inputs, types)
+
+# Registering QNN dense legalization function.
+@reg.register_qnn_legalize("qnn.dense")
+def legalize_qnn_dense(attrs, inputs, types):
+    return qnn_dense_legalize(attrs, inputs, types)
+
+# Default to None. If overridden by target, this will not be run.
+# Generic QNN Conv2D legalization function.
+@tvm.target.generic_func
+def qnn_conv2d_legalize(attrs, inputs, types):
+    """Default legalization is None."""
+    return None
+
+# Generic QNN Conv2D legalization function.
+@tvm.target.generic_func
+def qnn_dense_legalize(attrs, inputs, types):
+    """Default legalization is None."""
+    return None
+
+###################
+# Helper functions.
+###################
+
+# Helper function for lowering in the abscence of fast Int8 arithmetic units.
+def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
+    """ Converts QNN operators into a sequence of Relay operators that are friendly to HW that do
+    not have fast Int8 arithmetic. For example, for ARM, LLVM utilizes the assembly instructions
+    much more efficiently if the convolution or dense operator input datatypes are int16 instead of
+    int8. More details are present at https://github.com/apache/incubator-tvm/pull/4277.
 
     Parameters
     ----------
@@ -41,19 +74,27 @@ def legalize_qnn_conv2d(attrs, inputs, types):
     result : tvm.relay.Expr
         The legalized expr
     """
-    return qnn_conv2d_legalize(attrs, inputs, types)
 
-# Generic QNN Conv2D legalization function.
-@tvm.target.generic_func
-def qnn_conv2d_legalize(attrs, inputs, types):
-    """Default legalization is None."""
-    return None
+    # Collect the input exprs.
+    data, kernel = inputs
 
-# Intel x86 QNN Conv2D legalization function.
-@qnn_conv2d_legalize.register('cpu')
-def _qnn_conv2d_legalize(attrs, inputs, types):
-    """Legalizes QNN conv2d op. VNNI supports u8 x i8 fast conv/MM. If the dtypes are already good,
-    we dont transform. Else, we shift the tensor values and zero points to change the dtype.
+    input_zp = attrs['input_zero_point']
+    kernel_zp = attrs['kernel_zero_point']
+
+    shift_data = relay.subtract(relay.cast(data, dtype='int16'),
+                                relay.const(input_zp, 'int16'))
+    shift_kernel = relay.subtract(relay.cast(kernel, dtype='int16'),
+                                  relay.const(kernel_zp, 'int16'))
+    new_attrs = {k : attrs[k] for k in attrs.keys()}
+    del new_attrs['kernel_zero_point']
+    del new_attrs['input_zero_point']
+    return relay_op(shift_data, shift_kernel, **new_attrs)
+
+# Helper function to change dtypes to uint8 x int8. Intel VNNI instructions prefer this setting.
+def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
+    """Legalizes QNN conv2d/dense op for Intel HW. VNNI supports u8 x i8 fast conv/MM. If the dtypes
+    are already good, we dont transform. Else, we shift the tensor values and zero points to change
+    the dtype.
 
     Converting from int8 to uint8 can be done in following manner.
 
@@ -82,26 +123,18 @@ def _qnn_conv2d_legalize(attrs, inputs, types):
         The legalized expr
     """
 
-    def _shift(data, out_dtype):
+    def _shift(data, zero_point, out_dtype):
         """Shifts (add/subtracts) the qnn tensor with +/-128)"""
         if out_dtype == 'uint8':
             shift = 128
         elif out_dtype == 'int8':
             shift = -128
         else:
-            raise ValueError("Unsupport out dtype.")
+            raise ValueError("Unsupported out dtype.")
         data_modified = relay.cast(data, 'int32')
         data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
         data_modified = relay.cast(data_modified, out_dtype)
-        return data_modified
-
-    def _is_int8_hw_support(target):
-        """
-        Checks to ensure that we can use Intel DLBoost instructions - Check if the target is skylake
-        and above.
-        """
-        supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
-        return supported_arches.intersection(set(target.options))
+        return (data_modified, zero_point + shift)
 
     # Collect the dtypes.
     data_dtype = types[0].dtype
@@ -110,11 +143,6 @@ def _qnn_conv2d_legalize(attrs, inputs, types):
     # Collect the input exprs.
     data, kernel = inputs
 
-    # The VNNI transformations are applicable only Skylake and above.g
-    target = tvm.target.current_target(allow_none=False)
-    if not _is_int8_hw_support(target):
-        return None
-
     # VNNI supports u8 x i8 fast conv/MM. Don't do anything if it is already satisfied.
     if data_dtype == 'uint8' and kernel_dtype == 'int8':
         return None
@@ -123,18 +151,118 @@ def _qnn_conv2d_legalize(attrs, inputs, types):
     input_zp = attrs['input_zero_point']
     if data_dtype == 'int8':
         # Compute (QA + 128) and (zp_a + 128)
-        data = _shift(data, 'uint8')
-        input_zp = input_zp + 128
+        data, input_zp = _shift(data, input_zp, 'uint8')
 
     # Shift kernel if necessary.
     kernel_zp = attrs['kernel_zero_point']
     if kernel_dtype == 'uint8':
         # Compute (QA - 128) and (zp_a - 128)
-        kernel = _shift(kernel, 'int8')
-        kernel_zp = kernel_zp - 128
+        kernel, kernel_zp = _shift(kernel, kernel_zp, 'int8')
 
     # Call qnn.conv2d with modified inputs and zero points.
     new_attrs = {k : attrs[k] for k in attrs.keys()}
     new_attrs['input_zero_point'] = input_zp
     new_attrs['kernel_zero_point'] = kernel_zp
-    return relay.qnn.op.conv2d(data, kernel, **new_attrs)
+    return relay_op(data, kernel, **new_attrs)
+
+# Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting.
+def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
+    """ Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
+    many devices like ARM prefer the datatypes to be same for the HW units. This helper transforms
+    conv2d/dense such that both the dtypes are same.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+
+    def _shift(data, zero_point, out_dtype):
+        """Shifts (adds/subtracts) the qnn tensor by 128)"""
+        if out_dtype == 'uint8':
+            shift = 128
+        elif out_dtype == 'int8':
+            shift = -128
+        else:
+            raise ValueError("Unsupported out dtype.")
+        data_modified = relay.cast(data, 'int32')
+        data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
+        data_modified = relay.cast(data_modified, out_dtype)
+        return (data_modified, zero_point + shift)
+
+    # Collect the dtypes.
+    data_dtype = types[0].dtype
+    kernel_dtype = types[1].dtype
+
+    if data_dtype == kernel_dtype:
+        return None
+
+    # Collect the input exprs.
+    data, kernel = inputs
+
+    assert 'int8' in data_dtype and 'int8' in kernel_dtype, \
+            "Qnn Conv2D/Dense only accepts uint8 or int8 inputs"
+
+    # Shift input if necessary.
+    input_zp = attrs['input_zero_point']
+    data, input_zp = _shift(data, input_zp, kernel_dtype)
+
+    new_attrs = {k : attrs[k] for k in attrs.keys()}
+    new_attrs['input_zero_point'] = input_zp
+    return relay_op(data, kernel, **new_attrs)
+
+def is_fast_int8_on_intel():
+    """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
+    target = tvm.target.current_target(allow_none=False)
+    intel_supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
+    return intel_supported_arches.intersection(set(target.options))
+
+def is_fast_int8_on_arm():
+    """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
+    target = tvm.target.current_target(allow_none=False)
+    return '+v8.2a,+dotprod' in ' '.join(target.options)
+
+########################
+# ARM CPU legalizations.
+########################
+
+@qnn_conv2d_legalize.register('arm_cpu')
+def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
+    # ARM prefers the dtypes to be same.
+    if is_fast_int8_on_arm():
+        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
+    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
+
+@qnn_dense_legalize.register('arm_cpu')
+def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
+    # ARM prefers the dtypes to be same.
+    if is_fast_int8_on_arm():
+        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
+    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
+
+##########################
+# Intel CPU legalizations.
+##########################
+
+@qnn_conv2d_legalize.register('cpu')
+def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types):
+    # The VNNI transformations prefer uint8 x int8 datatypes.
+    if is_fast_int8_on_intel():
+        return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.conv2d)
+    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
+
+@qnn_dense_legalize.register('cpu')
+def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
+    # The VNNI transformations prefer uint8 x int8 datatypes.
+    if is_fast_int8_on_intel():
+        return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense)
+    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
index 55c1fa6..8ace7bc 100644 (file)
@@ -23,6 +23,14 @@ from tvm.contrib import graph_runtime
 from tvm.relay.qnn.op import register_qnn_legalize
 from tvm.relay import transform, analysis
 
+def alpha_equal(x, y):
+    """
+    Wrapper around alpha equality which ensures that
+    the hash function respects equality.
+    """
+    x = x['main']
+    y = y['main']
+    return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y)
 
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
@@ -82,11 +90,11 @@ def test_qnn_legalize():
     b = run_opt_pass(expected(), transform.InferType())
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
 
+
 def test_qnn_legalize_qnn_conv2d():
-    data_shape = (1, 64, 256, 256)
-    kernel_shape = (128, 64, 3, 3)
-    for dtype in ['uint8', 'int8']:
-        data_dtype =  kernel_dtype = dtype
+    def _get_mod(data_dtype, kernel_dtype):
+        data_shape = (1, 64, 256, 256)
+        kernel_shape = (128, 64, 3, 3)
         data = relay.var("data", shape=data_shape,
                 dtype=data_dtype)
         kernel = relay.var("kernel", shape=kernel_shape,
@@ -104,12 +112,145 @@ def test_qnn_legalize_qnn_conv2d():
 
         mod = relay.Function(relay.analysis.free_vars(func), func)
         mod = relay.Module.from_expr(mod)
+        return mod
+
+    # Check uint8 x uint8 and int8 x int8 transformation
+    for dtype in ('uint8', 'int8'):
+        mod = _get_mod(dtype, dtype)
 
+        #############################################################
+        # Check transformations for platforms with fast Int8 support.
+        #############################################################
+        # Check that Intel VNNI gets picked up.
         with tvm.target.create('llvm -mcpu=skylake-avx512'):
-            mod = relay.qnn.transform.Legalize()(mod)
+            legalized_mod = relay.qnn.transform.Legalize()(mod)
+            assert 'cast' in legalized_mod.astext() and "qnn.conv2d" in legalized_mod.astext()
+
+        # Since same dtype, there should not be any transformation
+        with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
+            legalized_mod = relay.qnn.transform.Legalize()(mod)
+            assert alpha_equal(mod, legalized_mod)
+
+        ################################################################
+        # Check transformations for platforms without fast Int8 support.
+        ################################################################
+        # Older Intel versions.
+        with tvm.target.create('llvm'):
+            legalized_mod = relay.qnn.transform.Legalize()(mod)
+            assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+        # Older ARM vesions.
+        with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'):
+            legalized_mod = relay.qnn.transform.Legalize()(mod)
+            assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+    # Check uint8 x int8 transformation
+    mod = _get_mod('uint8', 'int8')
+    #############################################################
+    # Check transformations for platforms with fast Int8 support.
+    #############################################################
+    # Check no transformation for Intel VNNI.
+    with tvm.target.create('llvm -mcpu=skylake-avx512'):
+        legalized_mod = relay.qnn.transform.Legalize()(mod)
+        assert alpha_equal(mod, legalized_mod)
+
+    # ARM - so check that transformation has happened.
+    with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
+        legalized_mod = relay.qnn.transform.Legalize()(mod)
+        assert 'cast' in legalized_mod.astext() and "qnn.conv2d" in legalized_mod.astext()
+
+    ################################################################
+    # Check transformations for platforms without fast Int8 support.
+    ################################################################
+    # Older Intel versions.
+    with tvm.target.create('llvm'):
+        legalized_mod = relay.qnn.transform.Legalize()(mod)
+        assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+    # Older ARM vesions.
+    with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'):
+        legalized_mod = relay.qnn.transform.Legalize()(mod)
+        assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+
+def test_qnn_legalize_qnn_dense():
+    def _get_mod(data_dtype, kernel_dtype):
+        data_shape = (10, 3)
+        kernel_shape = (20, 3)
+        data = relay.var("data", shape=data_shape,
+                dtype=data_dtype)
+        kernel = relay.var("kernel", shape=kernel_shape,
+                dtype=kernel_dtype)
+        func = relay.qnn.op.dense(
+                data, kernel,
+                input_zero_point=1,
+                kernel_zero_point=1,
+                out_dtype='int32')
+
+        mod = relay.Function(relay.analysis.free_vars(func), func)
+        mod = relay.Module.from_expr(mod)
+        return mod
+
+    # Check uint8 x uint8 and int8 x int8 transformation
+    for dtype in ('uint8', 'int8'):
+        mod = _get_mod(dtype, dtype)
+
+        #############################################################
+        # Check transformations for platforms with fast Int8 support.
+        #############################################################
+        # Check that Intel VNNI gets picked up.
+        with tvm.target.create('llvm -mcpu=skylake-avx512'):
+            legalized_mod = relay.qnn.transform.Legalize()(mod)
+            assert 'cast' in legalized_mod.astext() and "qnn.dense" in legalized_mod.astext()
+
+        # Since same dtype, there should not be any transformation
+        with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
+            legalized_mod = relay.qnn.transform.Legalize()(mod)
+            assert alpha_equal(mod, legalized_mod)
+
+        ################################################################
+        # Check transformations for platforms without fast Int8 support.
+        ################################################################
+        # Older Intel versions.
+        with tvm.target.create('llvm'):
+            legalized_mod = relay.qnn.transform.Legalize()(mod)
+            assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+        # Older ARM vesions.
+        with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'):
+            legalized_mod = relay.qnn.transform.Legalize()(mod)
+            assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+    # Check uint8 x int8 transformation
+    mod = _get_mod('uint8', 'int8')
+    #############################################################
+    # Check transformations for platforms with fast Int8 support.
+    #############################################################
+    # Check no transformation for Intel VNNI.
+    with tvm.target.create('llvm -mcpu=skylake-avx512'):
+        legalized_mod = relay.qnn.transform.Legalize()(mod)
+        assert alpha_equal(mod, legalized_mod)
+
+    # ARM - so check that transformation has happened.
+    with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
+        legalized_mod = relay.qnn.transform.Legalize()(mod)
+        assert 'cast' in legalized_mod.astext() and "qnn.dense" in legalized_mod.astext()
+
+    ################################################################
+    # Check transformations for platforms without fast Int8 support.
+    ################################################################
+    # Older Intel versions.
+    with tvm.target.create('llvm'):
+        legalized_mod = relay.qnn.transform.Legalize()(mod)
+        assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
+
+    # Older ARM vesions.
+    with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'):
+        legalized_mod = relay.qnn.transform.Legalize()(mod)
+        assert 'cast' in legalized_mod.astext() and "qnn" not in legalized_mod.astext()
 
-        assert 'cast' in mod.astext()
 
 if __name__ == "__main__":
     test_qnn_legalize()
     test_qnn_legalize_qnn_conv2d()
+    test_qnn_legalize_qnn_dense()