[TOPI][x86] Legalize - Support int8xint8 convolution to use VNNI instructions. (...
authorAnimesh Jain <anijain@umich.edu>
Fri, 25 Oct 2019 06:24:25 +0000 (23:24 -0700)
committerZhi <5145158+zhiics@users.noreply.github.com>
Fri, 25 Oct 2019 06:24:25 +0000 (23:24 -0700)
tests/python/relay/test_op_level2.py
topi/python/topi/x86/conv2d_alter_op.py

index e097980..9236d6e 100644 (file)
@@ -546,9 +546,11 @@ def test_conv2d_int8_intrinsics():
 
         n, h, w, ch, cw = 1, 64, 64, 3, 3
         if data_layout == 'NCHW':
-            x = relay.var("x", relay.TensorType((n, ic, h, w), input_dtype))
+            data_shape = (n, ic, h, w)
+            x = relay.var("x", relay.TensorType(data_shape, input_dtype))
         elif data_layout == 'NHWC':
-            x = relay.var("x", relay.TensorType((n, h, w, ic), input_dtype))
+            data_shape = (n, h, w, ic)
+            x = relay.var("x", relay.TensorType(data_shape, input_dtype))
         else:
             raise ValueError('Not supported')
 
@@ -559,8 +561,8 @@ def test_conv2d_int8_intrinsics():
         else:
             raise ValueError('Not supported')
 
-        w = relay.var("w", relay.TensorType(kernel_shape, weight_dtype))
-        y = relay.nn.conv2d(x, w,
+        weight = relay.var("weight", relay.TensorType(kernel_shape, weight_dtype))
+        y = relay.nn.conv2d(x, weight,
                             kernel_size=(ch, cw),
                             channels=oc,
                             padding=(1, 1),
@@ -568,11 +570,13 @@ def test_conv2d_int8_intrinsics():
                             data_layout=data_layout,
                             kernel_layout=kernel_layout,
                             out_dtype=output_dtype)
-        func = relay.Function([x, w], y)
+        func = relay.Function([x, weight], y)
         wdata = np.random.rand(*kernel_shape) * 10
-        parameters = {"w": tvm.nd.array(wdata.astype(weight_dtype))}
+        parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))}
+
         with relay.build_config(opt_level=3):
             graph, lib, params = relay.build(func, target, params=parameters)
+
         assembly = lib.get_source("asm")
         return assembly
 
@@ -589,58 +593,63 @@ def test_conv2d_int8_intrinsics():
     llvm_version = tvm.codegen.llvm_version_major()
     for target in targets:
         if llvm_version >= 8:
-            fast_int8_dtypes = ('uint8', 'int8', 'int32')
+            dtypes = ('uint8', 'int8', 'int32')
             # Sweep the input channels to check int8 robustness
             # Input channels should be a multiple of 4 internally.
             for ic in [1, 4, 6]:
-                asm = _compile(ic=ic, oc=32, target=target, data_layout="NCHW",
+                asm = _compile(ic=ic, oc=16, target=target, data_layout="NCHW",
                                kernel_layout='OIHW',
-                               dtypes=fast_int8_dtypes)
+                               dtypes=dtypes)
                 assert _has_fast_int8_instructions(asm, target)
 
             for ic in [1, 4, 6]:
-                asm = _compile(ic=ic, oc=32, target=target, data_layout="NHWC",
+                asm = _compile(ic=ic, oc=16, target=target, data_layout="NHWC",
                                kernel_layout='HWIO',
-                               dtypes=fast_int8_dtypes)
+                               dtypes=dtypes)
                 assert _has_fast_int8_instructions(asm, target)
 
-
             # Sweep the output channels to check int8 robustness
             # Output channels should be a multiple of 16 internally.
             for oc in [4, 16, 20]:
-                asm = _compile(ic=16, oc=oc, target=target, data_layout="NCHW",
+                asm = _compile(ic=8, oc=oc, target=target, data_layout="NCHW",
                                kernel_layout='OIHW',
-                               dtypes=fast_int8_dtypes)
+                               dtypes=dtypes)
                 assert _has_fast_int8_instructions(asm, target)
 
             for oc in [4, 16, 20]:
-                asm = _compile(ic=16, oc=oc, target=target, data_layout="NHWC",
+                asm = _compile(ic=8, oc=oc, target=target, data_layout="NHWC",
                                kernel_layout='HWIO',
-                               dtypes=fast_int8_dtypes)
+                               dtypes=dtypes)
                 assert _has_fast_int8_instructions(asm, target)
 
             # Check that both non-divisible oc and ic work
             asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
-                           dtypes=fast_int8_dtypes)
+                           dtypes=dtypes)
             assert _has_fast_int8_instructions(asm, target)
 
             asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
-                           dtypes=fast_int8_dtypes)
+                           dtypes=dtypes)
             assert _has_fast_int8_instructions(asm, target)
 
-            # Ensure that code is generated when datatypes are not HW supported.
-            dtypes = ('int8', 'int8', 'int32')
-            asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
+    # Check that int8 x int8 goes through legalization so that fast instructions can be picked up.
+    for target in targets:
+        if llvm_version >= 8:
+            dtypes = (('int8', 'int8', 'int32'))
+            # Check that both non-divisible oc and ic work
+            asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
                            dtypes=dtypes)
-            # Check that intrinisic is not present in the assembly.
-            assert not _has_fast_int8_instructions(asm, target)
+            assert _has_fast_int8_instructions(asm, target)
 
-            # Ensure that code is generated when datatypes are not HW supported.
-            dtypes = ('uint8', 'uint8', 'int32')
-            asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
+            asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
                            dtypes=dtypes)
-            # Check that intrinisic is not present in the assembly.
-            assert not _has_fast_int8_instructions(asm, target)
+            assert _has_fast_int8_instructions(asm, target)
+
+    # Ensure that code is generated when datatypes are not HW supported.
+    dtypes = ('uint8', 'uint8', 'int32')
+    asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
+                   dtypes=dtypes)
+    # Check that intrinisic is not present in the assembly.
+    assert not _has_fast_int8_instructions(asm, target)
 
     # Check that a vectorized instruction is generated for older Intel
     # generations, because we default to NCHWc layout.
index aec3efc..f596bc0 100644 (file)
@@ -192,24 +192,72 @@ def _conv2d_legalize(attrs, inputs, arg_types):
         The legalized expr
     """
 
+    # Dilation not supported yet. Return None if dilation is not (1, 1)
+    dilation = attrs.get_int_tuple("dilation")
+    if not (dilation[0] == 1 and dilation[1] == 1):
+        return None
+
     # Collect the input tensors.
     data_tensor, kernel_tensor = arg_types[0], arg_types[1]
+    data_dtype = data_tensor.dtype
+    kernel_dtype = kernel_tensor.dtype
 
     # Collect the output tensor.
     output_tensor = arg_types[2]
 
+    # Collect the input exprs.
+    data, kernel = inputs
+
+    # Get the conv attrs
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
+
+    is_int8_inputs = False
+    # If both the inputs are int8, we can add 128 to make the input dtype uint8, and then adjust the
+    # output. This will help picking up Intel VNNI instructions.
+    # Original --> C = A (conv) B
+    # A and B are int8
+    #   C = (A + 128 - 128) (conv) B
+    #   C = (A' conv B) - 128 (conv) B
+    # where A' = A + 128
+    # and 128 (conv) B is basically a reduce on CRS axis for weights.
+    if data_tensor.dtype == 'int8' and kernel_tensor.dtype == 'int8':
+        is_int8_inputs = True
+        padding = attrs.get_int_tuple("padding")
+
+        if attrs['data_layout'] == 'NHWC' and attrs['kernel_layout'] == 'HWIO':
+            adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(0, 1, 2))
+            pad_width = ((0, 0), (padding[0], padding[0]), (padding[1], padding[1]), (0, 0))
+        elif attrs['data_layout'] == 'NCHW' and attrs['kernel_layout'] == 'OIHW':
+            pad_width = ((0, 0), (0, 0), (padding[0], padding[0]), (padding[1], padding[1]))
+            adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(1, 2, 3))
+            adjust_shift = relay.expand_dims(adjust_shift, axis=1, num_newaxis=2)
+        else:
+            return None
+
+        data = relay.cast(data, 'int32')
+        data = relay.add(data, relay.const(128, 'int32'))
+        data = relay.cast(data, 'uint8')
+
+        # Do external padding as pad value has to be 128.
+        if not (padding[0] == 0 and padding[1] == 0):
+            data = relay.nn.pad(data, pad_width=pad_width, pad_value=128)
+        new_attrs['padding'] = (0, 0)
+
+        # The data type is now shifted to uint8
+        data_dtype = 'uint8'
+
+        # Multiply 128 to adjust shift.
+        adjust_shift = relay.multiply(adjust_shift, relay.const(128, 'int32'))
+
     # Legalize if the datatypes are suitable for fast Int8 instructions.  Int8 instructions require
     # input channel to be a multiple of 4 and output channels to be a multiple of 16. For input
     # channels, we pad both the inputs and weights input channels. For output channels, we pad the
     # weight and stride_slice the output.
-    if _is_int8_hw_support(data_tensor.dtype, kernel_tensor.dtype):
+    if _is_int8_hw_support(data_dtype, kernel_dtype):
         # Flags to remember if the expr is modified
         ic_modified = False
         oc_modified = False
 
-        # Collect the input exprs.
-        data, kernel = inputs
-
         # Find the value of input and output channel.
         in_channel = -1
         out_channel = -1
@@ -250,16 +298,16 @@ def _conv2d_legalize(attrs, inputs, arg_types):
             else:
                 return None
 
-        if not (ic_modified or oc_modified):
-            return None
-
-        if ic_modified and not oc_modified:
-            return relay.nn.conv2d(data, kernel, **attrs)
-
         if oc_modified:
-            new_attrs = {k: attrs[k] for k in attrs.keys()}
             new_attrs['channels'] = new_out_channel
             out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
             original_out_shape = [x.value for x in output_tensor.shape]
-            return relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape)
+            out = relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape)
+        else:
+            out = relay.nn.conv2d(data, kernel, **new_attrs)
+
+        if is_int8_inputs:
+            out = relay.subtract(out, adjust_shift)
+
+        return out
     return None