[VTA] YoloV3 Support (#4887)
authorHua Jiang <huaj@xilinx.com>
Wed, 26 Feb 2020 23:52:28 +0000 (15:52 -0800)
committerGitHub <noreply@github.com>
Wed, 26 Feb 2020 23:52:28 +0000 (15:52 -0800)
* [VTA] YoloV3 Support

Issue:
YoloV3 use some operator and logic that not get good support by
existing vta logic, like nn.pad, upsample, and 255 output channel.

Solution:
add related logic to let darknet YoloV3 can running on VTA

* Fix small(0, or 1 heigh/width) detect frame issue.

* add yolov3-tiny turtorial

* add os import

* address review comments.

* rename tutorial file with a short name.

* rename deploy_vision_on_vta.py into deploy_classification.py.

* address review comment, fix plint eror in deploy_detection.py

vta/python/vta/top/graphpack.py
vta/tutorials/frontend/deploy_classification.py [new file with mode: 0644]
vta/tutorials/frontend/deploy_detection.py [new file with mode: 0644]
vta/tutorials/frontend/deploy_vision_on_vta.py [deleted file]

index 76b3dc54b1133f3ebaa47a7a6ef23fe74d51efc8..2689fbcb6ec7f88bd3d5f85e138d7ecae040c650 100644 (file)
@@ -31,6 +31,8 @@ def run_opt_pass(expr, opt_pass):
     return entry if isinstance(expr, relay.Function) else entry.body
 
 def _to_shape(shape):
+    """ convert shape into tuple.
+    """
     return tuple(int(sh) for sh in shape)
 
 def _pack_batch_channel(data, dshape, bfactor, cfactor):
@@ -55,6 +57,49 @@ def _unpack_batch_channel(data, old_shape):
     return data
 
 
+def _const_shape_match(data, dshape, cfactor_out):
+    """ Pad the constant if the shape[0] not divisible by cfactor_out.
+    """
+    assert len(dshape) == 3
+    pad_width = int(dshape[0]) % cfactor_out
+    if pad_width != 0:
+        pad_width = cfactor_out -pad_width
+        data = op.nn.pad(data, [[0, pad_width], [0, 0], [0, 0]])
+        dshape = tuple([dshape[0] + pad_width, dshape[1], dshape[2]])
+    return data, dshape
+
+def _weight_shape_match(data, dshape, channels, cfactor_out, transpose=False):
+    """ Pad the weight if the shape[0] not divisible by cfactor_out.
+    """
+    assert len(dshape) == 4
+    pad_width = int(dshape[0]) % cfactor_out
+    channels_pad = int(channels) % cfactor_out
+    if pad_width != 0:
+        pad_width = cfactor_out - pad_width
+        data = op.nn.pad(data, [[0, pad_width], [0, 0], [0, 0], [0, 0]])
+        dshape = tuple([dshape[0] + pad_width, dshape[1], dshape[2], dshape[3]])
+
+    if channels_pad != 0:
+        channels = channels + (cfactor_out - channels_pad)
+
+    return data, dshape, channels
+
+def _weight_shape_match_transpose(data, dshape, channels, cfactor_out):
+    """ Pad the weight if the shape[1] not divisible by cfactor_out.
+    """
+    assert len(dshape) == 4
+    pad_width = int(dshape[1]) % cfactor_out
+    channels_pad = int(channels) % cfactor_out
+    if pad_width != 0:
+        pad_width = cfactor_out - pad_width
+        data = op.nn.pad(data, [[0, 0], [0, pad_width], [0, 0], [0, 0]])
+        dshape = tuple(dshape[0], [dshape[1] + pad_width, dshape[2], dshape[3]])
+
+    if channels_pad != 0:
+        channels = channels + (cfactor_out - channels_pad)
+
+    return data, dshape, channels
+
 def _pack_weight(data, dshape, cfactor):
     """Pack the weight into packed format.
     """
@@ -106,10 +151,19 @@ def _pack_const(data, dshape, dtype, bfactor, cfactor):
     return data
 
 
-def _get_shape(node):
-    """Get the shape of a node.
+def _get_tensor_shape(node):
+    """Get node shape.
     """
-    return _to_shape(node.checked_type.shape)
+    if isinstance(node.checked_type, relay.ty.TensorType):
+        return _to_shape(node.checked_type.shape)
+    return []
+
+def _get_tensor_type(node):
+    """Get node type.
+    """
+    if isinstance(node.checked_type, relay.ty.TensorType):
+        return node.checked_type.dtype
+    return "float32"
 
 def _operator_idx_inc(expr, count_meta, operator_current_idx):
     """Increase operator index
@@ -136,14 +190,17 @@ class ExprPack(ExprMutator):
         self.add = op.op.get("add")
         self.multiply = op.op.get("multiply")
         self.bias_add = op.op.get("nn.bias_add")
+        self.pad = op.op.get("nn.pad")
+        self.upsampling = op.op.get("nn.upsampling")
+        self.reshape = op.op.get("reshape")
         self.number_of_conv2d = 0
         super().__init__()
 
     def visit_call(self, call):
         """ Visit the children. """
         # First visit the children.
-        oshape = _get_shape(call)
-        odtype = call.checked_type.dtype
+        oshape = _get_tensor_shape(call)
+        odtype = _get_tensor_type(call)
         input_types = [arg.checked_type for arg in call.args]
         args = [self.visit(arg) for arg in call.args]
 
@@ -156,7 +213,7 @@ class ExprPack(ExprMutator):
             if self.start_pack:
                 self.start_pack = False
                 data = args[0]
-                data_shape = _get_shape(call.args[0])
+                data_shape = _get_tensor_shape(call.args[0])
                 return _unpack_batch_channel(data, data_shape)
         if self.start_pack:
             # Operator cases
@@ -169,11 +226,17 @@ class ExprPack(ExprMutator):
                 data, weight = args
                 data_shape = _to_shape(input_types[0].shape)
                 kernel_shape = _to_shape(input_types[1].shape)
+                channels = call.attrs.channels
+                weight, kernel_shape, channels = _weight_shape_match(weight,
+                                                                     kernel_shape,
+                                                                     channels,
+                                                                     self.cfactor)
                 kernel = _pack_weight(weight, kernel_shape, self.cfactor)
                 # insert bit packing when necessary
                 if w_lanes != 1:
                     assert 8 % w_lanes == 0
                     kernel = op.bitpack(kernel, lanes=w_lanes)
+
                 conv2d = op.nn.conv2d(
                     data,
                     kernel,
@@ -181,7 +244,7 @@ class ExprPack(ExprMutator):
                     padding=call.attrs.padding,
                     dilation=call.attrs.dilation,
                     groups=call.attrs.groups,
-                    channels=call.attrs.channels,
+                    channels=channels,
                     kernel_size=call.attrs.kernel_size,
                     data_layout=data_layout,
                     kernel_layout=kernel_layout,
@@ -198,6 +261,11 @@ class ExprPack(ExprMutator):
                     data, weight = args
                     data_shape = _to_shape(input_types[0].shape)
                     kernel_shape = _to_shape(input_types[1].shape)
+                    channels = call.attrs.channels
+                    weight, kernel_shape, channels = _weight_shape_match_transpose(weight,
+                                                                                   kernel_shape,
+                                                                                   channels,
+                                                                                   self.cfactor)
                     kernel = _pack_weight_conv2d_transpose(weight, kernel_shape, self.cfactor)
                     conv2d = op.nn.conv2d_transpose(
                         data,
@@ -218,8 +286,11 @@ class ExprPack(ExprMutator):
                 pass
             elif call.op == self.add and len(input_types[1].shape) == 3:
                 data, const = args
+                const, input_shape = _const_shape_match(const,
+                                                        input_types[1].shape,
+                                                        self.cfactor)
                 const = _pack_const(const,
-                                    _to_shape(input_types[1].shape),
+                                    _to_shape(input_shape),
                                     input_types[1].dtype,
                                     self.bfactor,
                                     self.cfactor)
@@ -247,6 +318,36 @@ class ExprPack(ExprMutator):
                     input_types[0].dtype == 'int32':
                 cast = relay.Call(op.op.get('cast'), [args[0]], call.attrs)
                 return relay.Call(op.op.get('copy'), [cast])
+            elif call.op == self.pad:
+                pad_width = call.attrs.pad_width
+                if len(pad_width) == 6:
+                    pass
+                elif len(pad_width) == 4:
+                    data, = args
+                    new_pad_width = []
+                    new_pad_width.extend(pad_width)
+                    for _ in range(2):
+                        new_pad_width.append([0, 0])
+                    return op.nn.pad(data,
+                                     pad_value=call.attrs.pad_value,
+                                     pad_width=new_pad_width)
+            elif call.op == self.upsampling:
+                data, = args
+                scale_h = call.attrs.scale_h
+                scale_w = call.attrs.scale_w
+                data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
+                method = call.attrs.method
+                align_corners = call.attrs.align_corners
+                return op.nn.upsampling(data,
+                                        scale_h,
+                                        scale_w,
+                                        data_layout,
+                                        method,
+                                        align_corners)
+            elif call.op == self.reshape and len(input_types[0].shape) == 4:
+                data, = args
+                data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
+                return op.reshape(data, input_types[0].shape)
 
         return relay.Call(
             self.visit(call.op),
diff --git a/vta/tutorials/frontend/deploy_classification.py b/vta/tutorials/frontend/deploy_classification.py
new file mode 100644 (file)
index 0000000..df02b48
--- /dev/null
@@ -0,0 +1,289 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Deploy Pretrained Vision Model from MxNet on VTA
+================================================
+**Author**: `Thierry Moreau <https://homes.cs.washington.edu/~moreau/>`_
+
+This tutorial provides an end-to-end demo, on how to run ImageNet classification
+inference onto the VTA accelerator design to perform ImageNet classification tasks.
+It showcases Relay as a front end compiler that can perform quantization (VTA
+only supports int8/32 inference) as well as graph packing (in order to enable
+tensorization in the core) to massage the compute graph for the hardware target.
+"""
+
+######################################################################
+# Install dependencies
+# --------------------
+# To use the autotvm package in tvm, we need to install some extra dependencies.
+# (change "3" to "2" if you use python2):
+#
+# .. code-block:: bash
+#
+#   pip3 install --user mxnet requests "Pillow<7"
+#
+# Now return to the python code. Import packages.
+
+from __future__ import absolute_import, print_function
+
+import argparse, json, os, requests, sys, time
+from io import BytesIO
+from os.path import join, isfile
+from PIL import Image
+
+from mxnet.gluon.model_zoo import vision
+import numpy as np
+from matplotlib import pyplot as plt
+
+import tvm
+from tvm import rpc, autotvm, relay
+from tvm.contrib import graph_runtime, util, download
+from tvm.contrib.debugger import debug_runtime
+from tvm.relay import transform
+
+import vta
+from vta.testing import simulator
+from vta.top import graph_pack
+
+# Make sure that TVM was compiled with RPC=1
+assert tvm.runtime.enabled("rpc")
+
+######################################################################
+# Define the platform and model targets
+# -------------------------------------
+# Execute on CPU vs. VTA, and define the model.
+
+# Load VTA parameters from the vta/config/vta_config.json file
+env = vta.get_env()
+
+# Set ``device=arm_cpu`` to run inference on the CPU
+# or ``device=vta`` to run inference on the FPGA.
+device = "vta"
+target = env.target if device == "vta" else env.target_vta_cpu
+
+# Dictionary lookup for when to start/end bit packing
+pack_dict = {
+    "resnet18_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet34_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet18_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet34_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet50_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+    "resnet101_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
+}
+
+# Name of Gluon model to compile
+# The ``start_pack`` and ``stop_pack`` labels indicate where
+# to start and end the graph packing relay pass: in other words
+# where to start and finish offloading to VTA.
+model = "resnet18_v1"
+assert model in pack_dict
+
+######################################################################
+# Obtain an execution remote
+# --------------------------
+# When target is 'pynq', reconfigure FPGA and runtime.
+# Otherwise, if target is 'sim', execute locally.
+
+if env.TARGET not in ["sim", "tsim"]:
+
+    # Get remote from tracker node if environment variable is set.
+    # To set up the tracker, you'll need to follow the "Auto-tuning
+    # a convolutional network for VTA" tutorial.
+    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
+    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
+    # Otherwise if you have a device you want to program directly from
+    # the host, make sure you've set the variables below to the IP of
+    # your board.
+    device_host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99")
+    device_port = os.environ.get("VTA_PYNQ_RPC_PORT", "9091")
+    if not tracker_host or not tracker_port:
+        remote = rpc.connect(device_host, int(device_port))
+    else:
+        remote = autotvm.measure.request_remote(env.TARGET, tracker_host, int(tracker_port), timeout=10000)
+
+    # Reconfigure the JIT runtime and FPGA.
+    # You can program the FPGA with your own custom bitstream
+    # by passing the path to the bitstream file instead of None.
+    reconfig_start = time.time()
+    vta.reconfig_runtime(remote)
+    vta.program_fpga(remote, bitstream=None)
+    reconfig_time = time.time() - reconfig_start
+    print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))
+
+# In simulation mode, host the RPC server locally.
+else:
+    remote = rpc.LocalSession()
+
+# Get execution context from remote
+ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
+
+######################################################################
+# Build the inference graph runtime
+# ---------------------------------
+# Grab vision model from Gluon model zoo and compile with Relay.
+# The compilation steps are:
+#
+# 1. Front end translation from MxNet into Relay module.
+# 2. Apply 8-bit quantization: here we skip the first conv layer,
+#    and dense layer which will both be executed in fp32 on the CPU.
+# 3. Perform graph packing to alter the data layout for tensorization.
+# 4. Perform constant folding to reduce number of operators (e.g. eliminate batch norm multiply).
+# 5. Perform relay build to object file.
+# 6. Load the object file onto remote (FPGA device).
+# 7. Generate graph runtime, `m`.
+#
+
+# Load pre-configured AutoTVM schedules
+with autotvm.tophub.context(target):
+
+    # Populate the shape and data type dictionary for ImageNet classifier input
+    dtype_dict = {"data": 'float32'}
+    shape_dict = {"data": (env.BATCH, 3, 224, 224)}
+
+    # Get off the shelf gluon model, and convert to relay
+    gluon_model = vision.get_model(model, pretrained=True)
+
+    # Measure build start time
+    build_start = time.time()
+
+    # Start front end compilation
+    mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict)
+
+    # Update shape and type dictionary
+    shape_dict.update({k: v.shape for k, v in params.items()})
+    dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
+
+    if target.device_name == "vta":
+        # Perform quantization in Relay
+        # Note: We set opt_level to 3 in order to fold batch norm
+        with relay.build_config(opt_level=3):
+            with relay.quantize.qconfig(global_scale=8.0,
+                                        skip_conv_layers=[0]):
+                mod = relay.quantize.quantize(mod, params=params)
+            # Perform graph packing and constant folding for VTA target
+            assert env.BLOCK_IN == env.BLOCK_OUT
+            relay_prog = graph_pack(
+                mod["main"],
+                env.BATCH,
+                env.BLOCK_OUT,
+                env.WGT_WIDTH,
+                start_name=pack_dict[model][0],
+                stop_name=pack_dict[model][1])
+    else:
+        relay_prog = mod["main"]
+
+    # Compile Relay program with AlterOpLayout disabled
+    with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
+        if target.device_name != "vta":
+            graph, lib, params = relay.build(
+                relay_prog, target=target,
+                params=params, target_host=env.target_host)
+        else:
+            with vta.build_config():
+                graph, lib, params = relay.build(
+                    relay_prog, target=target,
+                    params=params, target_host=env.target_host)
+
+    # Measure Relay build time
+    build_time = time.time() - build_start
+    print(model + " inference graph built in {0:.2f}s!".format(build_time))
+
+    # Send the inference library over to the remote RPC server
+    temp = util.tempdir()
+    lib.save(temp.relpath("graphlib.o"))
+    remote.upload(temp.relpath("graphlib.o"))
+    lib = remote.load_module("graphlib.o")
+
+    # Graph runtime
+    m = graph_runtime.create(graph, lib, ctx)
+
+######################################################################
+# Perform image classification inference
+# --------------------------------------
+# We run classification on an image sample from ImageNet
+# We just need to download the categories files, `synset.txt`
+# and an input test image.
+
+# Download ImageNet categories
+categ_url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
+categ_fn = "synset.txt"
+download.download(join(categ_url, categ_fn), categ_fn)
+synset = eval(open(categ_fn).read())
+
+# Download test image
+image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'
+image_fn = 'cat.png'
+download.download(image_url, image_fn)
+
+# Prepare test image for inference
+image = Image.open(image_fn).resize((224, 224))
+plt.imshow(image)
+plt.show()
+image = np.array(image) - np.array([123., 117., 104.])
+image /= np.array([58.395, 57.12, 57.375])
+image = image.transpose((2, 0, 1))
+image = image[np.newaxis, :]
+image = np.repeat(image, env.BATCH, axis=0)
+
+# Set the network parameters and inputs
+m.set_input(**params)
+m.set_input('data', image)
+
+# Perform inference and gather execution statistics
+# More on: https://docs.tvm.ai/api/python/module.html#tvm.runtime.Module.time_evaluator
+num = 4 # number of times we run module for a single measurement
+rep = 3 # number of measurements (we derive std dev from this)
+timer = m.module.time_evaluator("run", ctx, number=num, repeat=rep)
+
+if env.TARGET in ["sim", "tsim"]:
+    simulator.clear_stats()
+    timer()
+    sim_stats = simulator.stats()
+    print("\nExecution statistics:")
+    for k, v in sim_stats.items():
+        # Since we execute the workload many times, we need to normalize stats
+        # Note that there is always one warm up run
+        # Therefore we divide the overall stats by (num * rep + 1)
+        print("\t{:<16}: {:>16}".format(k, v // (num * rep + 1)))
+else:
+    tcost = timer()
+    std = np.std(tcost.results) * 1000
+    mean = tcost.mean * 1000
+    print("\nPerformed inference in %.2fms (std = %.2f) for %d samples" % (mean, std, env.BATCH))
+    print("Average per sample inference time: %.2fms" % (mean/env.BATCH))
+
+# Get classification results
+tvm_output = m.get_output(0, tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0)))
+for b in range(env.BATCH):
+    top_categories = np.argsort(tvm_output.asnumpy()[b])
+    # Report top-5 classification results
+    print("\n{} prediction for sample {}".format(model, b))
+    print("\t#1:", synset[top_categories[-1]])
+    print("\t#2:", synset[top_categories[-2]])
+    print("\t#3:", synset[top_categories[-3]])
+    print("\t#4:", synset[top_categories[-4]])
+    print("\t#5:", synset[top_categories[-5]])
+    # This just checks that one of the 5 top categories
+    # is one variety of cat; this is by no means an accurate
+    # assessment of how quantization affects classification
+    # accuracy but is meant to catch changes to the
+    # quantization pass that would accuracy in the CI.
+    cat_detected = False
+    for k in top_categories[-5:]:
+        if "cat" in synset[k]:
+            cat_detected = True
+    assert(cat_detected)
diff --git a/vta/tutorials/frontend/deploy_detection.py b/vta/tutorials/frontend/deploy_detection.py
new file mode 100644 (file)
index 0000000..09d8465
--- /dev/null
@@ -0,0 +1,330 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Deploy Pretrained Vision Detection Model from Darknet on VTA
+================================================
+**Author**: `Hua Jiang <https://github.com/huajsj>`_
+
+This tutorial provides an end-to-end demo, on how to run Darknet YoloV3-tiny
+inference onto the VTA accelerator design to perform Image detection tasks.
+It showcases Relay as a front end compiler that can perform quantization (VTA
+only supports int8/32 inference) as well as graph packing (in order to enable
+tensorization in the core) to massage the compute graph for the hardware target.
+"""
+
+######################################################################
+# Install dependencies
+# --------------------
+# To use the autotvm package in tvm, we need to install some extra dependencies.
+# (change "3" to "2" if you use python2):
+#
+# .. code-block:: bash
+#
+# pip3 install "Pillow<7"
+#
+# YOLO-V3-tiny Model with Darknet parsing have dependancy with CFFI and CV2 library,
+# we need to install CFFI and CV2 before executing this script.
+#
+# pip3 install "Pillow<7"
+#
+# pip3 install cffi
+# pip3 install opencv-python
+#
+# Now return to the python code. Import packages.
+
+from __future__ import absolute_import, print_function
+
+import sys
+import os
+import time
+import matplotlib.pyplot as plt
+import numpy as np
+import tvm
+import vta
+from tvm import rpc, autotvm, relay
+from tvm.relay.testing import yolo_detection, darknet
+from tvm.relay.testing.darknet import __darknetffi__
+from tvm.contrib import graph_runtime, graph_runtime, util
+from tvm.contrib.download import download_testdata
+from vta.testing import simulator
+from vta.top import graph_pack
+# Make sure that TVM was compiled with RPC=1
+assert tvm.runtime.enabled("rpc")
+
+##############################################################################
+# Download yolo net configure file, weight file, darknet library file based on
+# Model Name
+# ----------------------------------------------------------------------------
+MODEL_NAME = 'yolov3-tiny'
+REPO_URL = 'https://github.com/dmlc/web-data/blob/master/darknet/'
+
+cfg_path = download_testdata('https://github.com/pjreddie/darknet/blob/master/cfg/'
+                             + MODEL_NAME + '.cfg' + '?raw=true',
+                             MODEL_NAME + '.cfg',
+                             module="darknet")
+weights_path = download_testdata('https://pjreddie.com/media/files/'
+                                 + MODEL_NAME + '.weights' + '?raw=true',
+                                 MODEL_NAME + '.weights',
+                                 module="darknet")
+
+if sys.platform in ['linux', 'linux2']:
+    darknet_lib_path = download_testdata(REPO_URL + 'lib/' + 'libdarknet2.0.so' + '?raw=true',
+                                         'libdarknet2.0.so',
+                                         module="darknet")
+elif sys.platform == 'darwin':
+    darknet_lib_path = download_testdata(REPO_URL+'lib_osx/'+'libdarknet_mac2.0.so'+'?raw=true',
+                                         'libdarknet_mac2.0.so',
+                                         module="darknet")
+else:
+    raise NotImplementedError("Darknet lib is not supported on {} platform"
+                              .format(sys.platform))
+
+##################################################
+# Download yolo categories and illustration front.
+# ------------------------------------------------
+coco_path = download_testdata(REPO_URL + 'data/' + 'coco.names' + '?raw=true',
+                              'coco.names',
+                              module='data')
+font_path = download_testdata(REPO_URL + 'data/' + 'arial.ttf' + '?raw=true',
+                              'arial.ttf',
+                              module='data')
+with open(coco_path) as f:
+    content = f.readlines()
+names = [x.strip() for x in content]
+
+########################################
+# Define the platform and model targets.
+# --------------------------------------
+# Execute on CPU vs. VTA, and define the model.
+
+# Load VTA parameters from the vta/config/vta_config.json file
+env = vta.get_env()
+# Set ``device=arm_cpu`` to run inference on the CPU
+# or ``device=vta`` to run inference on the FPGA.
+device = "vta"
+target = env.target if device == "vta" else env.target_vta_cpu
+
+pack_dict = {
+    "yolov3-tiny": ["nn.max_pool2d", "cast", 4, 185],
+}
+
+# Name of Darknet model to compile
+# The ``start_pack`` and ``stop_pack`` labels indicate where
+# to start and end the graph packing relay pass: in other words
+# where to start and finish offloading to VTA.
+# the number 4 indicate the the ``start_pack`` index is 4, the
+# number 185 indicate the ``stop_pack index`` is 185, by using
+# name and index number, here we can located to correct place
+# where to start/end when there are multiple ``nn.max_pool2d``
+# or ``cast``, print(mod.astext(show_meta_data=False)) can help
+# to find operator name and index information.
+assert MODEL_NAME in pack_dict
+
+#############################
+# Obtain an execution remote.
+# ---------------------------
+# When target is 'pynq' or other FPGA backend, reconfigure FPGA and runtime.
+# Otherwise, if target is 'sim', execute locally.
+
+if env.TARGET not in ["sim", "tsim"]:
+    # Get remote from tracker node if environment variable is set.
+    # To set up the tracker, you'll need to follow the "Auto-tuning
+    # a convolutional network for VTA" tutorial.
+    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
+    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
+    # Otherwise if you have a device you want to program directly from
+    # the host, make sure you've set the variables below to the IP of
+    # your board.
+    device_host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99")
+    device_port = os.environ.get("VTA_PYNQ_RPC_PORT", "9091")
+    if not tracker_host or not tracker_port:
+        remote = rpc.connect(device_host, int(device_port))
+    else:
+        remote = autotvm.measure.request_remote(env.TARGET,
+                                                tracker_host,
+                                                int(tracker_port),
+                                                timeout=10000)
+    # Reconfigure the JIT runtime and FPGA.
+    # You can program the FPGA with your own custom bitstream
+    # by passing the path to the bitstream file instead of None.
+    reconfig_start = time.time()
+    vta.reconfig_runtime(remote)
+    vta.program_fpga(remote, bitstream=None)
+    reconfig_time = time.time() - reconfig_start
+    print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))
+
+# In simulation mode, host the RPC server locally.
+else:
+    remote = rpc.LocalSession()
+
+# Get execution context from remote
+ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
+
+####################################
+# Build the inference graph runtime.
+# ----------------------------------
+# Using Darknet library load downloaded vision model and compile with Relay.
+# The compilation steps are:
+#
+# 1. Front end translation from Darknet into Relay module.
+# 2. Apply 8-bit quantization: here we skip the first conv layer,
+#    and dense layer which will both be executed in fp32 on the CPU.
+# 3. Perform graph packing to alter the data layout for tensorization.
+# 4. Perform constant folding to reduce number of operators (e.g. eliminate batch norm multiply).
+# 5. Perform relay build to object file.
+# 6. Load the object file onto remote (FPGA device).
+# 7. Generate graph runtime, `m`.
+#
+
+# Load pre-configured AutoTVM schedules
+with autotvm.tophub.context(target):
+    net = __darknetffi__.dlopen(darknet_lib_path).load_network(cfg_path.encode('utf-8'),
+                                                               weights_path.encode('utf-8'),
+                                                               0)
+    dshape = (env.BATCH, net.c, net.h, net.w)
+    dtype = 'float32'
+
+    # Measure build start time
+    build_start = time.time()
+
+    # Start front end compilation
+    mod, params = relay.frontend.from_darknet(net, dtype=dtype, shape=dshape)
+
+    if target.device_name == "vta":
+    # Perform quantization in Relay
+    # Note: We set opt_level to 3 in order to fold batch norm
+        with relay.build_config(opt_level=3):
+            with relay.quantize.qconfig(global_scale=33.0,
+                                        skip_conv_layers=[0],
+                                        store_lowbit_output=True,
+                                        round_for_shift=True):
+                mod = relay.quantize.quantize(mod, params=params)
+            # Perform graph packing and constant folding for VTA target
+            mod = graph_pack(
+                mod["main"],
+                env.BATCH,
+                env.BLOCK_OUT,
+                env.WGT_WIDTH,
+                start_name=pack_dict[MODEL_NAME][0],
+                stop_name=pack_dict[MODEL_NAME][1],
+                start_name_idx=pack_dict[MODEL_NAME][2],
+                stop_name_idx=pack_dict[MODEL_NAME][3])
+    else:
+        mod = mod["main"]
+
+    # Compile Relay program with AlterOpLayout disabled
+    with vta.build_config(disabled_pass={"AlterOpLayout"}):
+        graph, lib, params = relay.build(
+            mod,
+            target=target,
+            params=params,
+            target_host=env.target_host)
+
+    # Measure Relay build time
+    build_time = time.time() - build_start
+    print(MODEL_NAME + " inference graph built in {0:.2f}s!".format(build_time))
+
+    # Send the inference library over to the remote RPC server
+    temp = util.tempdir()
+    lib.save(temp.relpath("graphlib.o"))
+    remote.upload(temp.relpath("graphlib.o"))
+    lib = remote.load_module("graphlib.o")
+
+    # Graph runtime
+    m = graph_runtime.create(graph, lib, ctx)
+
+####################################
+# Perform image detection inference.
+# ----------------------------------
+# We run detect on an downloaded image
+# Download test image
+[neth, netw] = dshape[2:]
+test_image = 'person.jpg'
+img_url = REPO_URL + 'data/' + test_image + '?raw=true'
+img_path = download_testdata(img_url, test_image, "data")
+data = darknet.load_image(img_path, neth, netw).transpose(1, 2, 0)
+
+# Prepare test image for inference
+plt.imshow(data)
+plt.show()
+data = data.transpose((2, 0, 1))
+data = data[np.newaxis, :]
+data = np.repeat(data, env.BATCH, axis=0)
+
+# Set the network parameters and inputs
+m.set_input('data', data)
+m.set_input(**params)
+
+# Perform inference and gather execution statistics
+# More on: https://docs.tvm.ai/api/python/module.html#tvm.runtime.Module.time_evaluator
+num = 4 # number of times we run module for a single measurement
+rep = 3 # number of measurements (we derive std dev from this)
+timer = m.module.time_evaluator("run", ctx, number=num, repeat=rep)
+
+if env.TARGET in ["sim", "tsim"]:
+    simulator.clear_stats()
+    timer()
+    sim_stats = simulator.stats()
+    print("\nExecution statistics:")
+    for k, v in sim_stats.items():
+        # Since we execute the workload many times, we need to normalize stats
+        # Note that there is always one warm up run
+        # Therefore we divide the overall stats by (num * rep + 1)
+        print("\t{:<16}: {:>16}".format(k, v // (num * rep + 1)))
+else:
+    tcost = timer()
+    std = np.std(tcost.results) * 1000
+    mean = tcost.mean * 1000
+    print("\nPerformed inference in %.2fms (std = %.2f) for %d samples" % (mean, std, env.BATCH))
+    print("Average per sample inference time: %.2fms" % (mean/env.BATCH))
+
+# Get detection results from out
+thresh = 0.5
+nms_thresh = 0.45
+tvm_out = []
+for i in range(2):
+    layer_out = {}
+    layer_out['type'] = 'Yolo'
+    # Get the yolo layer attributes (n, out_c, out_h, out_w, classes, total)
+    layer_attr = m.get_output(i*4+3).asnumpy()
+    layer_out['biases'] = m.get_output(i*4+2).asnumpy()
+    layer_out['mask'] = m.get_output(i*4+1).asnumpy()
+    out_shape = (layer_attr[0], layer_attr[1]//layer_attr[0],
+                 layer_attr[2], layer_attr[3])
+    layer_out['output'] = m.get_output(i*4).asnumpy().reshape(out_shape)
+    layer_out['classes'] = layer_attr[4]
+    tvm_out.append(layer_out)
+    thresh = 0.560
+
+# Show detection results
+img = darknet.load_image_color(img_path)
+_, im_h, im_w = img.shape
+dets = yolo_detection.fill_network_boxes((netw, neth),
+                                         (im_w, im_h),
+                                         thresh,
+                                         1,
+                                         tvm_out)
+last_layer = net.layers[net.n - 1]
+yolo_detection.do_nms_sort(dets, last_layer.classes, nms_thresh)
+yolo_detection.draw_detections(font_path,
+                               img,
+                               dets,
+                               thresh,
+                               names,
+                               last_layer.classes)
+plt.imshow(img.transpose(1, 2, 0))
+plt.show()
diff --git a/vta/tutorials/frontend/deploy_vision_on_vta.py b/vta/tutorials/frontend/deploy_vision_on_vta.py
deleted file mode 100644 (file)
index df02b48..0000000
+++ /dev/null
@@ -1,289 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-Deploy Pretrained Vision Model from MxNet on VTA
-================================================
-**Author**: `Thierry Moreau <https://homes.cs.washington.edu/~moreau/>`_
-
-This tutorial provides an end-to-end demo, on how to run ImageNet classification
-inference onto the VTA accelerator design to perform ImageNet classification tasks.
-It showcases Relay as a front end compiler that can perform quantization (VTA
-only supports int8/32 inference) as well as graph packing (in order to enable
-tensorization in the core) to massage the compute graph for the hardware target.
-"""
-
-######################################################################
-# Install dependencies
-# --------------------
-# To use the autotvm package in tvm, we need to install some extra dependencies.
-# (change "3" to "2" if you use python2):
-#
-# .. code-block:: bash
-#
-#   pip3 install --user mxnet requests "Pillow<7"
-#
-# Now return to the python code. Import packages.
-
-from __future__ import absolute_import, print_function
-
-import argparse, json, os, requests, sys, time
-from io import BytesIO
-from os.path import join, isfile
-from PIL import Image
-
-from mxnet.gluon.model_zoo import vision
-import numpy as np
-from matplotlib import pyplot as plt
-
-import tvm
-from tvm import rpc, autotvm, relay
-from tvm.contrib import graph_runtime, util, download
-from tvm.contrib.debugger import debug_runtime
-from tvm.relay import transform
-
-import vta
-from vta.testing import simulator
-from vta.top import graph_pack
-
-# Make sure that TVM was compiled with RPC=1
-assert tvm.runtime.enabled("rpc")
-
-######################################################################
-# Define the platform and model targets
-# -------------------------------------
-# Execute on CPU vs. VTA, and define the model.
-
-# Load VTA parameters from the vta/config/vta_config.json file
-env = vta.get_env()
-
-# Set ``device=arm_cpu`` to run inference on the CPU
-# or ``device=vta`` to run inference on the FPGA.
-device = "vta"
-target = env.target if device == "vta" else env.target_vta_cpu
-
-# Dictionary lookup for when to start/end bit packing
-pack_dict = {
-    "resnet18_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"],
-    "resnet34_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"],
-    "resnet18_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
-    "resnet34_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
-    "resnet50_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
-    "resnet101_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
-}
-
-# Name of Gluon model to compile
-# The ``start_pack`` and ``stop_pack`` labels indicate where
-# to start and end the graph packing relay pass: in other words
-# where to start and finish offloading to VTA.
-model = "resnet18_v1"
-assert model in pack_dict
-
-######################################################################
-# Obtain an execution remote
-# --------------------------
-# When target is 'pynq', reconfigure FPGA and runtime.
-# Otherwise, if target is 'sim', execute locally.
-
-if env.TARGET not in ["sim", "tsim"]:
-
-    # Get remote from tracker node if environment variable is set.
-    # To set up the tracker, you'll need to follow the "Auto-tuning
-    # a convolutional network for VTA" tutorial.
-    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
-    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
-    # Otherwise if you have a device you want to program directly from
-    # the host, make sure you've set the variables below to the IP of
-    # your board.
-    device_host = os.environ.get("VTA_PYNQ_RPC_HOST", "192.168.2.99")
-    device_port = os.environ.get("VTA_PYNQ_RPC_PORT", "9091")
-    if not tracker_host or not tracker_port:
-        remote = rpc.connect(device_host, int(device_port))
-    else:
-        remote = autotvm.measure.request_remote(env.TARGET, tracker_host, int(tracker_port), timeout=10000)
-
-    # Reconfigure the JIT runtime and FPGA.
-    # You can program the FPGA with your own custom bitstream
-    # by passing the path to the bitstream file instead of None.
-    reconfig_start = time.time()
-    vta.reconfig_runtime(remote)
-    vta.program_fpga(remote, bitstream=None)
-    reconfig_time = time.time() - reconfig_start
-    print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))
-
-# In simulation mode, host the RPC server locally.
-else:
-    remote = rpc.LocalSession()
-
-# Get execution context from remote
-ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
-
-######################################################################
-# Build the inference graph runtime
-# ---------------------------------
-# Grab vision model from Gluon model zoo and compile with Relay.
-# The compilation steps are:
-#
-# 1. Front end translation from MxNet into Relay module.
-# 2. Apply 8-bit quantization: here we skip the first conv layer,
-#    and dense layer which will both be executed in fp32 on the CPU.
-# 3. Perform graph packing to alter the data layout for tensorization.
-# 4. Perform constant folding to reduce number of operators (e.g. eliminate batch norm multiply).
-# 5. Perform relay build to object file.
-# 6. Load the object file onto remote (FPGA device).
-# 7. Generate graph runtime, `m`.
-#
-
-# Load pre-configured AutoTVM schedules
-with autotvm.tophub.context(target):
-
-    # Populate the shape and data type dictionary for ImageNet classifier input
-    dtype_dict = {"data": 'float32'}
-    shape_dict = {"data": (env.BATCH, 3, 224, 224)}
-
-    # Get off the shelf gluon model, and convert to relay
-    gluon_model = vision.get_model(model, pretrained=True)
-
-    # Measure build start time
-    build_start = time.time()
-
-    # Start front end compilation
-    mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict)
-
-    # Update shape and type dictionary
-    shape_dict.update({k: v.shape for k, v in params.items()})
-    dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
-
-    if target.device_name == "vta":
-        # Perform quantization in Relay
-        # Note: We set opt_level to 3 in order to fold batch norm
-        with relay.build_config(opt_level=3):
-            with relay.quantize.qconfig(global_scale=8.0,
-                                        skip_conv_layers=[0]):
-                mod = relay.quantize.quantize(mod, params=params)
-            # Perform graph packing and constant folding for VTA target
-            assert env.BLOCK_IN == env.BLOCK_OUT
-            relay_prog = graph_pack(
-                mod["main"],
-                env.BATCH,
-                env.BLOCK_OUT,
-                env.WGT_WIDTH,
-                start_name=pack_dict[model][0],
-                stop_name=pack_dict[model][1])
-    else:
-        relay_prog = mod["main"]
-
-    # Compile Relay program with AlterOpLayout disabled
-    with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
-        if target.device_name != "vta":
-            graph, lib, params = relay.build(
-                relay_prog, target=target,
-                params=params, target_host=env.target_host)
-        else:
-            with vta.build_config():
-                graph, lib, params = relay.build(
-                    relay_prog, target=target,
-                    params=params, target_host=env.target_host)
-
-    # Measure Relay build time
-    build_time = time.time() - build_start
-    print(model + " inference graph built in {0:.2f}s!".format(build_time))
-
-    # Send the inference library over to the remote RPC server
-    temp = util.tempdir()
-    lib.save(temp.relpath("graphlib.o"))
-    remote.upload(temp.relpath("graphlib.o"))
-    lib = remote.load_module("graphlib.o")
-
-    # Graph runtime
-    m = graph_runtime.create(graph, lib, ctx)
-
-######################################################################
-# Perform image classification inference
-# --------------------------------------
-# We run classification on an image sample from ImageNet
-# We just need to download the categories files, `synset.txt`
-# and an input test image.
-
-# Download ImageNet categories
-categ_url = "https://github.com/uwsaml/web-data/raw/master/vta/models/"
-categ_fn = "synset.txt"
-download.download(join(categ_url, categ_fn), categ_fn)
-synset = eval(open(categ_fn).read())
-
-# Download test image
-image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg'
-image_fn = 'cat.png'
-download.download(image_url, image_fn)
-
-# Prepare test image for inference
-image = Image.open(image_fn).resize((224, 224))
-plt.imshow(image)
-plt.show()
-image = np.array(image) - np.array([123., 117., 104.])
-image /= np.array([58.395, 57.12, 57.375])
-image = image.transpose((2, 0, 1))
-image = image[np.newaxis, :]
-image = np.repeat(image, env.BATCH, axis=0)
-
-# Set the network parameters and inputs
-m.set_input(**params)
-m.set_input('data', image)
-
-# Perform inference and gather execution statistics
-# More on: https://docs.tvm.ai/api/python/module.html#tvm.runtime.Module.time_evaluator
-num = 4 # number of times we run module for a single measurement
-rep = 3 # number of measurements (we derive std dev from this)
-timer = m.module.time_evaluator("run", ctx, number=num, repeat=rep)
-
-if env.TARGET in ["sim", "tsim"]:
-    simulator.clear_stats()
-    timer()
-    sim_stats = simulator.stats()
-    print("\nExecution statistics:")
-    for k, v in sim_stats.items():
-        # Since we execute the workload many times, we need to normalize stats
-        # Note that there is always one warm up run
-        # Therefore we divide the overall stats by (num * rep + 1)
-        print("\t{:<16}: {:>16}".format(k, v // (num * rep + 1)))
-else:
-    tcost = timer()
-    std = np.std(tcost.results) * 1000
-    mean = tcost.mean * 1000
-    print("\nPerformed inference in %.2fms (std = %.2f) for %d samples" % (mean, std, env.BATCH))
-    print("Average per sample inference time: %.2fms" % (mean/env.BATCH))
-
-# Get classification results
-tvm_output = m.get_output(0, tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0)))
-for b in range(env.BATCH):
-    top_categories = np.argsort(tvm_output.asnumpy()[b])
-    # Report top-5 classification results
-    print("\n{} prediction for sample {}".format(model, b))
-    print("\t#1:", synset[top_categories[-1]])
-    print("\t#2:", synset[top_categories[-2]])
-    print("\t#3:", synset[top_categories[-3]])
-    print("\t#4:", synset[top_categories[-4]])
-    print("\t#5:", synset[top_categories[-5]])
-    # This just checks that one of the 5 top categories
-    # is one variety of cat; this is by no means an accurate
-    # assessment of how quantization affects classification
-    # accuracy but is meant to catch changes to the
-    # quantization pass that would accuracy in the CI.
-    cat_detected = False
-    for k in top_categories[-5:]:
-        if "cat" in synset[k]:
-            cat_detected = True
-    assert(cat_detected)