From 028f47ce654d7419e0cebb274541696e960fb90c Mon Sep 17 00:00:00 2001
From: Thierry Moreau
Date: Thu, 5 Sep 2019 11:17:09 -0700
Subject: [PATCH] [VTA][Relay] Extending Vision model coverage compilation for VTA (#3740)

* adding support for graphpack over multiply op

* increasing resnet model coverage

* fix indentation

* lint

* moving recursion limit fix into graphpack pass

* moving recursionlimit to relay init

* pooling on NCHWnc format

* adding more models

* deploy_resnet_on_vta.py

* trailing line

* generalizing to vision models

* merge conflicts

* fix, apply quantization to VTA only

* improving comments

* trimming models that have runtime issues for the moment

* lint

* lint

* lint
---
 python/tvm/relay/__init__.py                       |  4 ++
 src/relay/op/nn/pooling.cc                         |  7 ++-
 vta/python/vta/top/graphpack.py                    | 43 ++++++++++++-------
 ...oy_resnet_on_vta.py => deploy_vision_on_vta.py} | 50 +++++++++++++---------
 4 files changed, 66 insertions(+), 38 deletions(-)
 rename vta/tutorials/frontend/{deploy_resnet_on_vta.py => deploy_vision_on_vta.py} (87%)

diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index 8271244..b56ef65 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -17,6 +17,7 @@
 # pylint: disable=wildcard-import, redefined-builtin, invalid-name
 """The Relay IR namespace containing the IR definition and compiler."""
 from __future__ import absolute_import
+from sys import setrecursionlimit
 from ..api import register_func
 from . import base
 from . import ty
@@ -59,6 +60,9 @@ from . import qnn
 
 from .scope_builder import ScopeBuilder
 
+# Required to traverse large programs
+setrecursionlimit(10000)
+
 # Span
 Span = base.Span
 
diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc
index 72de071..76dec99 100644
--- a/src/relay/op/nn/pooling.cc
+++ b/src/relay/op/nn/pooling.cc
@@ -161,9 +161,12 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
   CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
       << "max_pool2d does not support input split on width";
 
-  CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U)
+  CHECK(inputs[0].ndim() == 4U ||
+        inputs[0].ndim() == 5U ||
+        inputs[0].ndim() == 6U)
       << "Pool2D only support 4-D input (e.g., NCHW)"
-      << " or 5-D input (last dimension is a split of channel)";
+      << " or 5-D input (e.g. NCHWc for vector instructions)"
+      << " or 6-D input (e.g. NCHWnc for tensor accelerators)";
 
   if (param->padding.size() == 1) {
     padding.push_back(padding[0]);
diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py
index d894fc0..a4c0548 100644
--- a/vta/python/vta/top/graphpack.py
+++ b/vta/python/vta/top/graphpack.py
@@ -85,8 +85,8 @@ def _pack_weight_conv2d_transpose(data, dshape, cfactor):
     return data
 
 
-def _pack_bias(data, dshape, dtype, bfactor, cfactor):
-    """Pack the bias parameter.
+def _pack_const(data, dshape, dtype, bfactor, cfactor):
+    """Pack a constant parameter.
""" dshape = _to_shape(dshape) assert len(dshape) == 3 @@ -124,6 +124,7 @@ class ExprPack(ExprMutator): self.conv2d = op.op.get("nn.conv2d") self.conv2d_transpose = op.op.get("nn.conv2d_transpose") self.add = op.op.get("add") + self.multiply = op.op.get("multiply") self.bias_add = op.op.get("nn.bias_add") self.number_of_conv2d = 0 super().__init__() @@ -203,23 +204,35 @@ class ExprPack(ExprMutator): output_padding=call.attrs.output_padding, out_dtype=call.attrs.out_dtype) return conv2d - elif call.op == self.add and tuple(input_types[0].shape) == tuple(input_types[1].shape): + elif call.op == self.add and \ + tuple(input_types[0].shape) == tuple(input_types[1].shape): pass elif call.op == self.add and len(input_types[1].shape) == 3: - data, bias = args - bias = _pack_bias(bias, - _to_shape(input_types[1].shape), - input_types[1].dtype, - self.bfactor, - self.cfactor) - return relay.Call(self.add, [data, bias]) + data, const = args + const = _pack_const(const, + _to_shape(input_types[1].shape), + input_types[1].dtype, + self.bfactor, + self.cfactor) + return relay.Call(self.add, [data, const]) + elif call.op == self.multiply and \ + tuple(input_types[0].shape) == tuple(input_types[1].shape): + pass + elif call.op == self.multiply and len(input_types[1].shape) == 3: + data, const = args + const = _pack_const(const, + _to_shape(input_types[1].shape), + input_types[1].dtype, + self.bfactor, + self.cfactor) + return relay.Call(self.multiply, [data, const]) elif self.start_pack and call.op == self.bias_add: data, bias = args - bias = _pack_bias(bias, - _to_shape(input_types[1].shape), - input_types[1].dtype, - self.bfactor, - self.cfactor) + bias = _pack_const(bias, + _to_shape(input_types[1].shape), + input_types[1].dtype, + self.bfactor, + self.cfactor) return relay.Call(self.add, [data, bias]) elif self.start_pack and call.op == op.op.get('cast') and \ input_types[0].dtype == 'int32': diff --git a/vta/tutorials/frontend/deploy_resnet_on_vta.py b/vta/tutorials/frontend/deploy_vision_on_vta.py similarity index 87% rename from vta/tutorials/frontend/deploy_resnet_on_vta.py rename to vta/tutorials/frontend/deploy_vision_on_vta.py index c01f989..a508fc4 100644 --- a/vta/tutorials/frontend/deploy_resnet_on_vta.py +++ b/vta/tutorials/frontend/deploy_vision_on_vta.py @@ -15,12 +15,12 @@ # specific language governing permissions and limitations # under the License. """ -Deploy Pretrained ResNet Model from MxNet on VTA +Deploy Pretrained Vision Model from MxNet on VTA ================================================ **Author**: `Thierry Moreau `_ -This tutorial provides an end-to-end demo, on how to run ResNet-18 inference -onto the VTA accelerator design to perform ImageNet classification tasks. +This tutorial provides an end-to-end demo, on how to run ImageNet classification +inference onto the VTA accelerator design to perform ImageNet classification tasks. It showcases Relay as a front end compiler that can perform quantization (VTA only supports int8/32 inference) as well as graph packing (in order to enable tensorization in the core) to massage the compute graph for the hardware target. @@ -40,7 +40,7 @@ tensorization in the core) to massage the compute graph for the hardware target. 
 
 from __future__ import absolute_import, print_function
 
-import argparse, json, os, requests, time
+import argparse, json, os, requests, sys, time
 from io import BytesIO
 from os.path import join, isfile
 from PIL import Image
@@ -53,6 +53,7 @@ import tvm
 from tvm import rpc, autotvm, relay
 from tvm.contrib import graph_runtime, util, download
 from tvm.contrib.debugger import debug_runtime
+from tvm.relay import transform
 
 import vta
 from vta.testing import simulator
@@ -61,7 +62,6 @@ from vta.top import graph_pack
 
 # Make sure that TVM was compiled with RPC=1
 assert tvm.module.enabled("rpc")
-
 ######################################################################
 # Define the platform and model targets
 # -------------------------------------
@@ -75,13 +75,22 @@ env = vta.get_env()
 device = "vta"
 target = env.target if device == "vta" else env.target_vta_cpu
 
+# Dictionary lookup for when to start/end graph packing
+pack_dict = {
+    "resnet18_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet34_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet18_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet34_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet50_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet101_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+}
+
 # Name of Gluon model to compile
 # The ``start_pack`` and ``stop_pack`` labels indicate where
 # to start and end the graph packing relay pass: in other words
 # where to start and finish offloading to VTA.
 model = "resnet18_v1"
-start_pack="nn.max_pool2d"
-stop_pack="nn.global_avg_pool2d"
+assert model in pack_dict
 
 ######################################################################
 # Obtain an execution remote
@@ -125,7 +134,7 @@ ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
 ######################################################################
 # Build the inference graph runtime
 # ---------------------------------
-# Grab ResNet-18 model from Gluon model zoo and compile with Relay.
+# Grab a vision model from the Gluon model zoo and compile with Relay.
 # The compilation steps are:
 # 1) Front end translation from MxNet into Relay module.
 # 2) Apply 8-bit quantization: here we skip the first conv layer,
@@ -140,7 +149,7 @@ ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
 # Load pre-configured AutoTVM schedules
 with autotvm.tophub.context(target):
 
-    # Populate the shape and data type dictionary for ResNet input
+    # Populate the shape and data type dictionary for ImageNet classifier input
     dtype_dict = {"data": 'float32'}
     shape_dict = {"data": (env.BATCH, 3, 224, 224)}
 
@@ -157,21 +166,22 @@ with autotvm.tophub.context(target):
     shape_dict.update({k: v.shape for k, v in params.items()})
     dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
 
-    # Perform quantization in Relay
-    with relay.quantize.qconfig(global_scale=8.0,
-                                skip_conv_layers=[0]):
-        relay_prog = relay.quantize.quantize(mod["main"], params=params)
-
-    # Perform graph packing and constant folding for VTA target
     if target.device_name == "vta":
+        # Perform quantization in Relay
+        with relay.quantize.qconfig(global_scale=8.0,
+                                    skip_conv_layers=[0]):
+            relay_prog = relay.quantize.quantize(mod["main"], params=params)
+        # Perform graph packing and constant folding for VTA target
         assert env.BLOCK_IN == env.BLOCK_OUT
         relay_prog = graph_pack(
             relay_prog,
            env.BATCH,
            env.BLOCK_OUT,
            env.WGT_WIDTH,
-            start_name=start_pack,
-            stop_name=stop_pack)
+            start_name=pack_dict[model][0],
+            stop_name=pack_dict[model][1])
+    else:
+        relay_prog = mod["main"]
 
     # Compile Relay program with AlterOpLayout disabled
     with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
@@ -199,8 +209,8 @@ with autotvm.tophub.context(target):
     m = graph_runtime.create(graph, lib, ctx)
 
 ######################################################################
-# Perform ResNet-18 inference
-# ---------------------------
+# Perform image classification inference
+# --------------------------------------
 # We run classification on an image sample from ImageNet
 # We just need to download the categories files, `synset.txt`
 # and an input test image.
@@ -256,7 +266,6 @@ else:
 tvm_output = m.get_output(0, tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0)))
 for b in range(env.BATCH):
     top_categories = np.argsort(tvm_output.asnumpy()[b])
-
     # Report top-5 classification results
     print("\n{} prediction for sample {}".format(model, b))
     print("\t#1:", synset[top_categories[-1]])
@@ -264,7 +273,6 @@ for b in range(env.BATCH):
     print("\t#3:", synset[top_categories[-3]])
     print("\t#4:", synset[top_categories[-4]])
     print("\t#5:", synset[top_categories[-5]])
-
     # This just checks that one of the 5 top categories
     # is one variety of cat; this is by no means an accurate
    # assessment of how quantization affects classification
-- 
2.7.4
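
The setrecursionlimit(10000) call added to python/tvm/relay/__init__.py is there
because Relay's expression visitors (ExprMutator/ExprVisitor, which the graphpack
pass builds on) consume one Python stack frame per expression node, so a deep
enough network overflows CPython's default limit of 1000 frames. A minimal
standalone sketch of the failure mode, using a toy recursive walk rather than
Relay itself:

    import sys

    # Toy stand-in for a recursive AST walk: one frame per "node".
    def visit(nodes_remaining):
        if nodes_remaining == 0:
            return 0
        return 1 + visit(nodes_remaining - 1)

    # A chain of ~2000 nested nodes overflows the default limit of 1000
    # frames, but fits comfortably under the raised limit of 10000.
    sys.setrecursionlimit(10000)
    print(visit(2000))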
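The relaxed ndim check in pooling.cc admits the 6-D NCHWnc layout that graph
packing produces, where n and c are the accelerator's batch and channel tile
sizes (BATCH and BLOCK_OUT in the tutorial). Pooling only windows over H and W,
which remain intact dimensions in the packed layout. A minimal NumPy sketch of
the packing itself, not the TVM compute rule; pack_nchw and the tile sizes here
are illustrative:

    import numpy as np

    def pack_nchw(x, n, c):
        """Tile NCHW data into the 6-D NCHWnc layout."""
        N, C, H, W = x.shape
        assert N % n == 0 and C % c == 0
        x = x.reshape(N // n, n, C // c, c, H, W)
        return x.transpose(0, 2, 4, 5, 1, 3)  # -> (N//n, C//c, H, W, n, c)

    x = np.random.rand(2, 32, 8, 8).astype("float32")
    print(pack_nchw(x, n=1, c=16).shape)  # (2, 2, 8, 8, 1, 16)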
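Renaming _pack_bias to _pack_const reflects that the same packing now serves
both add and multiply: a 3-D (C, 1, 1) constant (a bias, or a per-channel scale
such as a folded batch norm) has its channel axis split into the packed c
dimension so that it broadcasts against NCHWnc data. A NumPy sketch of the
intended layout, assuming the (C, 1, 1) shape the pass asserts; the actual op
sequence inside _pack_const is expressed in Relay and may differ in detail:

    import numpy as np

    def pack_const_np(const, cfactor):
        """Split the channel axis of a (C, 1, 1) constant for NCHWnc broadcast."""
        c, h, w = const.shape
        assert c % cfactor == 0 and h == 1 and w == 1
        # The size-1 axes broadcast over N//n, H, W, and the batch tile (n).
        return const.reshape(1, c // cfactor, 1, 1, 1, cfactor)

    scale = np.arange(32, dtype="float32").reshape(32, 1, 1)
    packed = pack_const_np(scale, cfactor=16)       # (1, 2, 1, 1, 1, 16)
    data = np.ones((1, 2, 8, 8, 1, 16), "float32")  # NCHWnc activations
    print((data * packed).shape)                    # broadcasts cleanly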
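pack_dict keys the graph-packing boundaries per model: everything between the
first op name (start_name) and the second (stop_name) is rewritten into the
packed layout and offloaded to VTA, while the head and tail of the network stay
on the CPU. Extending coverage to another model then takes one entry; the
example below is hypothetical, and the right delimiters must be read off the
model's own Relay graph:

    # Hypothetical entry, assuming the tutorial's pack_dict is in scope and
    # that this model's packable region is really delimited by these two ops.
    pack_dict["resnet50_v1"] = ["nn.max_pool2d", "nn.global_avg_pool2d"]
    model = "resnet50_v1"
    assert model in pack_dict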