[VTA][Relay] Extending Vision model coverage compilation for VTA (#3740)
authorThierry Moreau <moreau@uw.edu>
Thu, 5 Sep 2019 18:17:09 +0000 (11:17 -0700)
committerJared Roesch <roeschinc@gmail.com>
Thu, 5 Sep 2019 18:17:09 +0000 (11:17 -0700)
* adding support for graphpack over multiply op

* increasing resnet model coverage

* fix indentation

* lint

* moving recursion limit fix into graphpack pass

* moving recursionlimit to relay init

* pooling on NCHWnc format

* adding more models

* deploy_resnet_on_vta.py

* trailing line

* generalizing to vision models

* merge conflicts

* fix, apply quantization to VTA only

* improving comments

* trimming models that have runtime issues for the moment

* lint

* lint

* lint

python/tvm/relay/__init__.py
src/relay/op/nn/pooling.cc
vta/python/vta/top/graphpack.py
vta/tutorials/frontend/deploy_vision_on_vta.py [moved from vta/tutorials/frontend/deploy_resnet_on_vta.py with 87% similarity]

index 8271244..b56ef65 100644 (file)
@@ -17,6 +17,7 @@
 # pylint: disable=wildcard-import, redefined-builtin, invalid-name
 """The Relay IR namespace containing the IR definition and compiler."""
 from __future__ import absolute_import
+from sys import setrecursionlimit
 from ..api import register_func
 from . import base
 from . import ty
@@ -59,6 +60,9 @@ from . import qnn
 
 from .scope_builder import ScopeBuilder
 
+# Required to traverse large programs
+setrecursionlimit(10000)
+
 # Span
 Span = base.Span
 
index 72de071..76dec99 100644 (file)
@@ -161,9 +161,12 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
   CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
       << "max_pool2d does not support input split on width";
 
-  CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
+  CHECK(inputs[0].ndim() == 4U ||
+        inputs[0].ndim() == 5U ||
+        inputs[0].ndim() == 6U)
       << "Pool2D only support 4-D input (e.g., NCHW)"
-      << " or 5-D input (last dimension is a split of channel)";
+      << " or 5-D input (e.g. NCHWc on for vector instructions)"
+      << " or 6-D input (e.g. NCHWnc for tensor accelerators)";
 
   if (param->padding.size() == 1) {
     padding.push_back(padding[0]);
index d894fc0..a4c0548 100644 (file)
@@ -85,8 +85,8 @@ def _pack_weight_conv2d_transpose(data, dshape, cfactor):
     return data
 
 
-def _pack_bias(data, dshape, dtype, bfactor, cfactor):
-    """Pack the bias parameter.
+def _pack_const(data, dshape, dtype, bfactor, cfactor):
+    """Pack a constant parameter.
     """
     dshape = _to_shape(dshape)
     assert len(dshape) == 3
@@ -124,6 +124,7 @@ class ExprPack(ExprMutator):
         self.conv2d = op.op.get("nn.conv2d")
         self.conv2d_transpose = op.op.get("nn.conv2d_transpose")
         self.add = op.op.get("add")
+        self.multiply = op.op.get("multiply")
         self.bias_add = op.op.get("nn.bias_add")
         self.number_of_conv2d = 0
         super().__init__()
@@ -203,23 +204,35 @@ class ExprPack(ExprMutator):
                         output_padding=call.attrs.output_padding,
                         out_dtype=call.attrs.out_dtype)
                 return conv2d
-            elif call.op == self.add and tuple(input_types[0].shape) == tuple(input_types[1].shape):
+            elif call.op == self.add and \
+                    tuple(input_types[0].shape) == tuple(input_types[1].shape):
                 pass
             elif call.op == self.add and len(input_types[1].shape) == 3:
-                data, bias = args
-                bias = _pack_bias(bias,
-                                  _to_shape(input_types[1].shape),
-                                  input_types[1].dtype,
-                                  self.bfactor,
-                                  self.cfactor)
-                return relay.Call(self.add, [data, bias])
+                data, const = args
+                const = _pack_const(const,
+                                    _to_shape(input_types[1].shape),
+                                    input_types[1].dtype,
+                                    self.bfactor,
+                                    self.cfactor)
+                return relay.Call(self.add, [data, const])
+            elif call.op == self.multiply and \
+                    tuple(input_types[0].shape) == tuple(input_types[1].shape):
+                pass
+            elif call.op == self.multiply and len(input_types[1].shape) == 3:
+                data, const = args
+                const = _pack_const(const,
+                                    _to_shape(input_types[1].shape),
+                                    input_types[1].dtype,
+                                    self.bfactor,
+                                    self.cfactor)
+                return relay.Call(self.multiply, [data, const])
             elif self.start_pack and call.op == self.bias_add:
                 data, bias = args
-                bias = _pack_bias(bias,
-                                  _to_shape(input_types[1].shape),
-                                  input_types[1].dtype,
-                                  self.bfactor,
-                                  self.cfactor)
+                bias = _pack_const(bias,
+                                   _to_shape(input_types[1].shape),
+                                   input_types[1].dtype,
+                                   self.bfactor,
+                                   self.cfactor)
                 return relay.Call(self.add, [data, bias])
             elif self.start_pack and call.op == op.op.get('cast') and \
                     input_types[0].dtype == 'int32':
 # specific language governing permissions and limitations
 # under the License.
 """
-Deploy Pretrained ResNet Model from MxNet on VTA
+Deploy Pretrained Vision Model from MxNet on VTA
 ================================================
 **Author**: `Thierry Moreau <https://homes.cs.washington.edu/~moreau/>`_
 
-This tutorial provides an end-to-end demo, on how to run ResNet-18 inference
-onto the VTA accelerator design to perform ImageNet classification tasks.
+This tutorial provides an end-to-end demo, on how to run ImageNet classification
+inference onto the VTA accelerator design to perform ImageNet classification tasks.
 It showcases Relay as a front end compiler that can perform quantization (VTA
 only supports int8/32 inference) as well as graph packing (in order to enable
 tensorization in the core) to massage the compute graph for the hardware target.
@@ -40,7 +40,7 @@ tensorization in the core) to massage the compute graph for the hardware target.
 
 from __future__ import absolute_import, print_function
 
-import argparse, json, os, requests, time
+import argparse, json, os, requests, sys, time
 from io import BytesIO
 from os.path import join, isfile
 from PIL import Image
@@ -53,6 +53,7 @@ import tvm
 from tvm import rpc, autotvm, relay
 from tvm.contrib import graph_runtime, util, download
 from tvm.contrib.debugger import debug_runtime
+from tvm.relay import transform
 
 import vta
 from vta.testing import simulator
@@ -61,7 +62,6 @@ from vta.top import graph_pack
 # Make sure that TVM was compiled with RPC=1
 assert tvm.module.enabled("rpc")
 
-
 ######################################################################
 # Define the platform and model targets
 # -------------------------------------
@@ -75,13 +75,22 @@ env = vta.get_env()
 device = "vta"
 target = env.target if device == "vta" else env.target_vta_cpu
 
+# Dictionary lookup for when to start/end bit packing
+pack_dict = {
+    "resnet18_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet34_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet18_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet34_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet50_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet101_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+}
+
 # Name of Gluon model to compile
 # The ``start_pack`` and ``stop_pack`` labels indicate where
 # to start and end the graph packing relay pass: in other words
 # where to start and finish offloading to VTA.
 model = "resnet18_v1"
-start_pack="nn.max_pool2d"
-stop_pack="nn.global_avg_pool2d"
+assert model in pack_dict
 
 ######################################################################
 # Obtain an execution remote
@@ -125,7 +134,7 @@ ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
 ######################################################################
 # Build the inference graph runtime
 # ---------------------------------
-# Grab ResNet-18 model from Gluon model zoo and compile with Relay.
+# Grab vision model from Gluon model zoo and compile with Relay.
 # The compilation steps are:
 #    1) Front end translation from MxNet into Relay module.
 #    2) Apply 8-bit quantization: here we skip the first conv layer,
@@ -140,7 +149,7 @@ ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
 # Load pre-configured AutoTVM schedules
 with autotvm.tophub.context(target):
 
-    # Populate the shape and data type dictionary for ResNet input
+    # Populate the shape and data type dictionary for ImageNet classifier input
     dtype_dict = {"data": 'float32'}
     shape_dict = {"data": (env.BATCH, 3, 224, 224)}
 
@@ -157,21 +166,22 @@ with autotvm.tophub.context(target):
     shape_dict.update({k: v.shape for k, v in params.items()})
     dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
 
-    # Perform quantization in Relay
-    with relay.quantize.qconfig(global_scale=8.0,
-                                skip_conv_layers=[0]):
-        relay_prog = relay.quantize.quantize(mod["main"], params=params)
-
-    # Perform graph packing and constant folding for VTA target
     if target.device_name == "vta":
+        # Perform quantization in Relay
+        with relay.quantize.qconfig(global_scale=8.0,
+                                    skip_conv_layers=[0]):
+            relay_prog = relay.quantize.quantize(mod["main"], params=params)
+        # Perform graph packing and constant folding for VTA target
         assert env.BLOCK_IN == env.BLOCK_OUT
         relay_prog = graph_pack(
             relay_prog,
             env.BATCH,
             env.BLOCK_OUT,
             env.WGT_WIDTH,
-            start_name=start_pack,
-            stop_name=stop_pack)
+            start_name=pack_dict[model][0],
+            stop_name=pack_dict[model][1])
+    else:
+        relay_prog = mod["main"]
 
     # Compile Relay program with AlterOpLayout disabled
     with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
@@ -199,8 +209,8 @@ with autotvm.tophub.context(target):
     m = graph_runtime.create(graph, lib, ctx)
 
 ######################################################################
-# Perform ResNet-18 inference
-# ---------------------------
+# Perform image classification inference
+# --------------------------------------
 # We run classification on an image sample from ImageNet
 # We just need to download the categories files, `synset.txt`
 # and an input test image.
@@ -256,7 +266,6 @@ else:
 tvm_output = m.get_output(0, tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0)))
 for b in range(env.BATCH):
     top_categories = np.argsort(tvm_output.asnumpy()[b])
-
     # Report top-5 classification results
     print("\n{} prediction for sample {}".format(model, b))
     print("\t#1:", synset[top_categories[-1]])
@@ -264,7 +273,6 @@ for b in range(env.BATCH):
     print("\t#3:", synset[top_categories[-3]])
     print("\t#4:", synset[top_categories[-4]])
     print("\t#5:", synset[top_categories[-5]])
-
     # This just checks that one of the 5 top categories
     # is one variety of cat; this is by no means an accurate
     # assessment of how quantization affects classification