From d3958e114ec435b7829bdb380d7bbe7410831bad Mon Sep 17 00:00:00 2001 From: Siju Date: Sat, 25 May 2019 03:08:08 +0530 Subject: [PATCH] [RELAY]Frontend darknet (#2773) * [RELAY]Frontend darknet * CI test file updated & CI error fixed * avg_pool pad fix * Changed repo_url and doc formatting --- nnvm/python/nnvm/testing/__init__.py | 1 - nnvm/tests/python/frontend/darknet/test_forward.py | 4 +- nnvm/tutorials/from_darknet.py | 16 +- python/tvm/relay/frontend/__init__.py | 1 + python/tvm/relay/frontend/common.py | 2 +- python/tvm/relay/frontend/darknet.py | 847 +++++++++++++++++++++ python/tvm/relay/testing/__init__.py | 1 + .../nnvm => python/tvm/relay}/testing/darknet.py | 0 .../tvm/relay}/testing/yolo_detection.py | 0 tests/python/frontend/darknet/test_forward.py | 462 +++++++++++ tests/scripts/task_python_frontend.sh | 9 +- tutorials/frontend/from_darknet.py | 179 +++++ 12 files changed, 1507 insertions(+), 15 deletions(-) create mode 100644 python/tvm/relay/frontend/darknet.py rename {nnvm/python/nnvm => python/tvm/relay}/testing/darknet.py (100%) rename {nnvm/python/nnvm => python/tvm/relay}/testing/yolo_detection.py (100%) create mode 100644 tests/python/frontend/darknet/test_forward.py create mode 100644 tutorials/frontend/from_darknet.py diff --git a/nnvm/python/nnvm/testing/__init__.py b/nnvm/python/nnvm/testing/__init__.py index 44b8529..41bcf83 100644 --- a/nnvm/python/nnvm/testing/__init__.py +++ b/nnvm/python/nnvm/testing/__init__.py @@ -13,5 +13,4 @@ from . import squeezenet from . import inception_v3 from . import dcgan from . import dqn -from . import yolo_detection from . import check_computation diff --git a/nnvm/tests/python/frontend/darknet/test_forward.py b/nnvm/tests/python/frontend/darknet/test_forward.py index 7f45a61..4e62ff2 100644 --- a/nnvm/tests/python/frontend/darknet/test_forward.py +++ b/nnvm/tests/python/frontend/darknet/test_forward.py @@ -27,8 +27,8 @@ from tvm.contrib import graph_runtime from tvm.contrib.download import download_testdata download_testdata.__test__ = False from nnvm import frontend -from nnvm.testing.darknet import LAYERTYPE -from nnvm.testing.darknet import __darknetffi__ +from tvm.relay.testing.darknet import LAYERTYPE +from tvm.relay.testing.darknet import __darknetffi__ import nnvm.compiler DARKNET_LIB = 'libdarknet2.0.so' diff --git a/nnvm/tutorials/from_darknet.py b/nnvm/tutorials/from_darknet.py index 857ef46..d2ab647 100644 --- a/nnvm/tutorials/from_darknet.py +++ b/nnvm/tutorials/from_darknet.py @@ -33,8 +33,8 @@ Please install CFFI and CV2 before executing this script import nnvm import nnvm.frontend.darknet -import nnvm.testing.yolo_detection -import nnvm.testing.darknet +import tvm.relay.testing.yolo_detection +import tvm.relay.testing.darknet import matplotlib.pyplot as plt import numpy as np import tvm @@ -42,7 +42,7 @@ import sys from ctypes import * from tvm.contrib.download import download_testdata -from nnvm.testing.darknet import __darknetffi__ +from tvm.relay.testing.darknet import __darknetffi__ # Model name MODEL_NAME = 'yolov3' @@ -104,7 +104,7 @@ img_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + \ test_image + '?raw=true' img_path = download_testdata(img_url, test_image, "data") -data = nnvm.testing.darknet.load_image(img_path, netw, neth) +data = tvm.relay.testing.darknet.load_image(img_path, netw, neth) ###################################################################### # Execute on TVM Runtime # ---------------------- @@ -153,12 +153,12 @@ elif MODEL_NAME == 'yolov3': # do the detection and bring up the bounding boxes thresh = 0.5 nms_thresh = 0.45 -img = nnvm.testing.darknet.load_image_color(img_path) +img = tvm.relay.testing.darknet.load_image_color(img_path) _, im_h, im_w = img.shape -dets = nnvm.testing.yolo_detection.fill_network_boxes((netw, neth), (im_w, im_h), thresh, +dets = tvm.relay.testing.yolo_detection.fill_network_boxes((netw, neth), (im_w, im_h), thresh, 1, tvm_out) last_layer = net.layers[net.n - 1] -nnvm.testing.yolo_detection.do_nms_sort(dets, last_layer.classes, nms_thresh) +tvm.relay.testing.yolo_detection.do_nms_sort(dets, last_layer.classes, nms_thresh) coco_name = 'coco.names' coco_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + coco_name + '?raw=true' @@ -172,6 +172,6 @@ with open(coco_path) as f: names = [x.strip() for x in content] -nnvm.testing.yolo_detection.draw_detections(font_path, img, dets, thresh, names, last_layer.classes) +tvm.relay.testing.yolo_detection.draw_detections(font_path, img, dets, thresh, names, last_layer.classes) plt.imshow(img.transpose(1, 2, 0)) plt.show() diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py index 8d308c7..76761fd 100644 --- a/python/tvm/relay/frontend/__init__.py +++ b/python/tvm/relay/frontend/__init__.py @@ -30,3 +30,4 @@ from .tflite import from_tflite from .coreml import from_coreml from .caffe2 import from_caffe2 from .tensorflow import from_tensorflow +from .darknet import from_darknet diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 9b89936..2347762 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -241,7 +241,7 @@ def get_relay_op(op_name): op = None else: # try search op in various modules - for candidate in (_op, _op.nn, _op.image): + for candidate in (_op, _op.nn, _op.image, _op.vision): op = getattr(candidate, op_name, None) if op is not None: break diff --git a/python/tvm/relay/frontend/darknet.py b/python/tvm/relay/frontend/darknet.py new file mode 100644 index 0000000..6da3525 --- /dev/null +++ b/python/tvm/relay/frontend/darknet.py @@ -0,0 +1,847 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +""" +DarkNet symbol frontend for Relay. +""" + +from __future__ import absolute_import as _abs +from enum import Enum +import numpy as np +import tvm +from .. import ir_pass +from .. import expr as _expr +from .common import get_relay_op, new_var + +__all__ = ['from_darknet'] + +def _darknet_not_support(attr, op='relay'): + """Raise error if any operation is not supported.""" + err = "{} is not supported in {}.".format(attr, op) + raise NotImplementedError(err) + +def _get_params_prefix(opname, layer_num): + """Makes the params prefix name from opname and layer number.""" + return str(opname) + str(layer_num) + +def _get_params_name(prefix, item): + """Makes the params name for the k,v pair.""" + return prefix + '_'+ item + +def _get_param_var(params, prefix, item): + name = _get_params_name(prefix, item) + if name not in params: + raise AttributeError("{} not found in params dict.".format(name)) + return new_var(name, shape=params[name].shape, dtype=params[name].dtype) + +def _darknet_maxpooling(inputs, params, attrs, prefix): + """Process the max pool 2d operation.""" + new_attrs = {} + kernel = attrs.get('kernel') + strides = attrs.get('stride', 1) + pads = attrs.get('pad', 1) + new_attrs['pool_size'] = (kernel, kernel) + new_attrs['strides'] = (strides, strides) + new_attrs['padding'] = (pads, pads) + extra_pad_size = attrs.get('extra_pad_size', 0) + if extra_pad_size: + pad_width = ((0, 0), (0, 0), (0, extra_pad_size), (0, extra_pad_size)) + inputs = [get_relay_op('pad')(*inputs, + pad_width=pad_width, + pad_value=np.finfo(np.float32).min)] + return get_relay_op('max_pool2d')(*inputs, **new_attrs) + +def _darknet_avgpooling(inputs, params, attrs, prefix): + """Process the average pool 2d operation.""" + new_attrs = {} + kernel = attrs.get('kernel') + strides = attrs.get('stride', 1) + pads = attrs.get('pad', 0) + + new_attrs['pool_size'] = (kernel, kernel) + new_attrs['strides'] = (strides, strides) + new_attrs['padding'] = (pads, pads) + return get_relay_op('avg_pool2d')(*inputs, **new_attrs) + +def _darknet_conv2d(inputs, params, attrs, prefix): + """Process the convolution 2d operation.""" + new_attrs = {} + kernel = attrs.get('kernel') + strides = attrs.get('stride', 1) + pads = attrs.get('pad', 0) + + new_attrs['channels'] = attrs.get('num_filter') + new_attrs['kernel_size'] = (kernel, kernel) + new_attrs['strides'] = (strides, strides) + new_attrs['padding'] = (pads, pads) + new_attrs['dilation'] = attrs.get('dilate', (1, 1)) + new_attrs['groups'] = attrs.get('num_group', 1) + + weight = _get_param_var(params, prefix, 'weight') + out = get_relay_op('conv2d')(*inputs, weight=weight, **new_attrs) + + use_bias = not attrs.get('use_batchNorm', False) + if use_bias: + new_attrs = {} + new_attrs['axis'] = 1 + bias = _get_param_var(params, prefix, 'bias') + out = get_relay_op('bias_add')(out, bias=bias, **new_attrs) + else: + new_attrs = {} + new_attrs['epsilon'] = 0.000001 + gamma = _get_param_var(params, prefix, 'gamma') + beta = _get_param_var(params, prefix, 'beta') + moving_mean = _get_param_var(params, prefix, 'moving_mean') + moving_var = _get_param_var(params, prefix, 'moving_var') + out = get_relay_op('batch_norm')(out, gamma, beta, moving_mean, moving_var, **new_attrs) + + if 'activation' in attrs: + new_attrs = {} + new_attrs['activation'] = attrs['activation'] + new_attrs['slope'] = 0.1 + out = _darknet_activations(out, None, new_attrs) + return out + +def _darknet_shortcut(inputs, params, attrs, prefix): + """Process the shortcut operation.""" + input_0 = inputs[0] + input_1 = inputs[1] + + input_0_channel = int(attrs['out_channel']) + input_1_channel = int(attrs['add_out_channel']) + input_0_size = int(attrs['out_size']) + input_1_size = int(attrs['add_out_size']) + + if input_0_size > input_1_size: + scale = int(input_0_size/input_1_size) + input_1 = get_relay_op('upsampling')(input_1, scale=scale) + + elif input_0_size < input_1_size: + stride = int(input_1_size/input_0_size) + input_1 = get_relay_op('avg_pool2d')(input_1, + pool_size=(1, 1), + strides=(stride, stride), + padding=(0, 0)) + + if input_0_channel != input_1_channel: + pad_channel = input_0_channel - input_1_channel + input_1 = get_relay_op('pad')(input_1, + pad_width=((0, 0), (0, pad_channel), (0, 0), (0, 0)), + pad_value=0.) + sym = input_0 + input_1 + if 'activation' in attrs: + new_attrs = {} + new_attrs['activation'] = attrs['activation'] + sym = _darknet_activations(sym, None, new_attrs) + return sym + +def _darknet_dense(inputs, params, attrs, prefix): + """Process the dense operation.""" + new_attrs = {} + new_attrs['units'] = attrs.get('num_hidden') + data = inputs[0] + + if attrs.get('use_flatten', False) is True: + data = get_relay_op('batch_flatten')(data) + + weight = _get_param_var(params, prefix, 'weight') + data = get_relay_op('dense')(data, weight, **new_attrs) + + use_bias = attrs.get('use_bias', False) + if use_bias: + bias = _get_param_var(params, prefix, 'bias') + data = get_relay_op('bias_add')(data, bias, axis=1) + + if 'use_batchNorm' in attrs: + new_attrs = {} + new_attrs['epsilon'] = 0.000001 + gamma = _get_param_var(params, prefix, 'gamma') + beta = _get_param_var(params, prefix, 'beta') + moving_mean = _get_param_var(params, prefix, 'moving_mean') + moving_var = _get_param_var(params, prefix, 'moving_var') + data = get_relay_op('batch_norm')(data, gamma, beta, moving_mean, moving_var, **new_attrs) + if 'activation' in attrs: + new_attrs = {} + new_attrs['activation'] = attrs['activation'] + data = _darknet_activations(data, None, new_attrs) + return data + +def _darknet_dropout(inputs, params, attrs, prefix): + """Process the dropout operation, its a blank operation.""" + new_attrs = {} + new_attrs['rate'] = attrs.get('p', 0.5) + return get_relay_op('dropout')(*inputs, **new_attrs) + +def _darknet_reshape(inputs, params, attrs, prefix): + """Process the reshape operation.""" + new_attrs = {} + new_attrs['shape'] = attrs.get('shape') + return get_relay_op('reshape')(*inputs, **new_attrs) + +def _darknet_upsampling(inputs, params, attrs, prefix): + """Process the upsampling operation.""" + new_attrs = {} + new_attrs['scale'] = attrs.get('scale', 1) + return get_relay_op('upsampling')(*inputs, **new_attrs) + +def _darknet_l2normalize(inputs, params, attrs, prefix): + """Process the l2 normalization operation.""" + new_attrs = {} + new_attrs['eps'] = attrs.get('eps', 0.0) + new_attrs['axis'] = [attrs.get('axis', 1)] + return get_relay_op('l2_normalize')(*inputs, **new_attrs) + +def _darknet_softmax_output(inputs, params, attrs, prefix): + """Process the softmax operation.""" + temperature = attrs.get('temperature', 1) + data = inputs[0] + if temperature != 1: + data = data / _expr.const(float(temperature)) + + if attrs.get('use_flatten', False) is True: + data = get_relay_op('batch_flatten')(data) + + new_attrs = {} + if attrs.get('multi_output', False): + new_attrs['axis'] = 1 + return get_relay_op('softmax')(data, **new_attrs) + +def _darknet_route(inputs, params, attrs, prefix): + """Process the route operation, which is equivalent to concat.""" + new_attrs = {'axis': attrs.get('dim', 1)} + return get_relay_op('concatenate')((inputs[0], inputs[1]), **new_attrs) + +def _darknet_reorg(inputs, params, attrs, prefix): + """Process the reorg operation.""" + new_attrs = {} + if 'stride' in attrs: + new_attrs = {'stride': attrs.get('stride', 1)} + return get_relay_op('yolo_reorg')(*inputs, **new_attrs) + +def _darknet_region(inputs, params, attrs, prefix): + """Process the region operation.""" + num = attrs.get('n', 1) + classes = attrs.get('classes', 1) + coords = attrs.get('coords', 0) + background = attrs.get('background', 0) + softmax = attrs.get('softmax', True) + input_shape = attrs.get('shape') + + split_size = classes + coords + 1 + intermediate_shape = (input_shape[0], num, split_size, input_shape[2], input_shape[3]) + data_block = get_relay_op('reshape')(inputs[0], newshape=intermediate_shape) + split_indices = (2, 4, 5) + split_res = get_relay_op('split')(data_block, indices_or_sections=split_indices, axis=2) + split_res0 = get_relay_op('sigmoid')(split_res[0]) + split_res2 = split_res[2] if background else get_relay_op('sigmoid')(split_res[2]) + split_res3 = get_relay_op('softmax')(split_res[3], axis=2) if softmax else split_res[3] + out = get_relay_op('concatenate')((split_res0, split_res[1], split_res2, split_res3), axis=2) + return get_relay_op('reshape')(out, newshape=input_shape) + +def _darknet_yolo(inputs, params, attrs, prefix): + """Process the yolo operation.""" + num = attrs.get('n', 1) + classes = attrs.get('classes', 1) + input_shape = attrs.get('shape') + split_size = classes + 5 + intermediate_shape = (input_shape[0], num, split_size, input_shape[2], input_shape[3]) + data_block = get_relay_op('reshape')(inputs[0], newshape=intermediate_shape) + split_indices = (2, 4) + split_res = get_relay_op('split')(data_block, indices_or_sections=split_indices, axis=2) + split_res0 = get_relay_op('sigmoid')(split_res[0]) + split_res2 = get_relay_op('sigmoid')(split_res[2]) + out = get_relay_op('concatenate')((split_res0, split_res[1], split_res2), axis=2) + return get_relay_op('reshape')(out, newshape=input_shape) + +class ACTIVATION(object): + """Darknet ACTIVATION Class constant.""" + LOGISTIC = 0 + RELU = 1 + RELIE = 2 + LINEAR = 3 + RAMP = 4 + TANH = 5 + PLSE = 6 + LEAKY = 7 + ELU = 8 + LOGGY = 9 + STAIR = 10 + HARDTAN = 11 + LHTAN = 12 + +def _darknet_activations(inputs, params, attrs): + """Process the activation function.""" + act = attrs.get('activation') + data = inputs[0] if isinstance(inputs, _expr.TupleWrapper) else inputs + + def _const(val): + return _expr.const(val) + + def _relu(data): + return get_relay_op('relu')(data) + + def _exp(data): + return get_relay_op('exp')(data) + + def _tanh(data): + return get_relay_op('tanh')(data) + + def _sigmoid(data): + return get_relay_op('sigmoid')(data) + + def _elu(data): + alpha = _const(-1.0) + return alpha * _relu(_const(1.0) - _exp(data)) + _relu(data) + + def _leaky_relu(data, slope): + new_attrs = {} + new_attrs['alpha'] = slope + return get_relay_op('leaky_relu')(data, **new_attrs) + + if ACTIVATION.LOGISTIC == act: + data = _sigmoid(data) + elif ACTIVATION.RELU == act: + data = _relu(data) + elif ACTIVATION.TANH == act: + data = _tanh(data) + elif ACTIVATION.LINEAR == act: + return data + elif ACTIVATION.LEAKY == act: + data = _leaky_relu(data, attrs.get('slope', 0.1)) + elif ACTIVATION.ELU == act: + data = _elu(data) + else: + _darknet_not_support('act: ' + attrs) + return data + +class LAYERTYPE(Enum): + """Darknet LAYERTYPE Class constant.""" + CONVOLUTIONAL = 0 + DECONVOLUTIONAL = 1 + CONNECTED = 2 + MAXPOOL = 3 + SOFTMAX = 4 + DETECTION = 5 + DROPOUT = 6 + CROP = 7 + ROUTE = 8 + COST = 9 + NORMALIZATION = 10 + AVGPOOL = 11 + LOCAL = 12 + SHORTCUT = 13 + ACTIVE = 14 + RNN = 15 + GRU = 16 + LSTM = 17 + CRNN = 18 + BATCHNORM = 19 + NETWORK = 20 + XNOR = 21 + REGION = 22 + YOLO = 23 + REORG = 24 + UPSAMPLE = 25 + LOGXENT = 26 + L2NORM = 27 + BLANK = 28 + +_DARKNET_CONVERT_MAP = { + LAYERTYPE.CONVOLUTIONAL : _darknet_conv2d, + LAYERTYPE.CONNECTED : _darknet_dense, + LAYERTYPE.MAXPOOL : _darknet_maxpooling, + LAYERTYPE.SOFTMAX : _darknet_softmax_output, + LAYERTYPE.DROPOUT : _darknet_dropout, + LAYERTYPE.AVGPOOL : _darknet_avgpooling, + LAYERTYPE.ROUTE : _darknet_route, + LAYERTYPE.REORG : _darknet_reorg, + LAYERTYPE.REGION : _darknet_region, + LAYERTYPE.SHORTCUT : _darknet_shortcut, + LAYERTYPE.UPSAMPLE : _darknet_upsampling, + LAYERTYPE.L2NORM : _darknet_l2normalize, + LAYERTYPE.YOLO : _darknet_yolo, + LAYERTYPE.DECONVOLUTIONAL : _darknet_not_support, + LAYERTYPE.BATCHNORM : _darknet_not_support, + LAYERTYPE.DETECTION : _darknet_not_support, + LAYERTYPE.CROP : _darknet_not_support, + LAYERTYPE.COST : _darknet_not_support, + LAYERTYPE.NORMALIZATION : _darknet_not_support, + LAYERTYPE.LOCAL : _darknet_not_support, + LAYERTYPE.ACTIVE : _darknet_not_support, + LAYERTYPE.RNN : _darknet_not_support, + LAYERTYPE.GRU : _darknet_not_support, + LAYERTYPE.LSTM : _darknet_not_support, + LAYERTYPE.CRNN : _darknet_not_support, + LAYERTYPE.NETWORK : _darknet_not_support, + LAYERTYPE.XNOR : _darknet_not_support, + LAYERTYPE.BLANK : _darknet_not_support, +} + +def _darknet_convert_symbol(op_name, inputs, params, attrs, params_prefix): + """Convert from darknet op to relay op. + Parameters + ---------- + op_name : str + Operator name, such as Convolution, Connected, etc + inputs : list of relay.Function + List of input symbols. + attrs : dict + Dict of operator attributes + params_prefix: str + Params name for this operation + + Returns + ------- + out_name : converted out name of operation + sym : tvm.relay.Function + Converted relay function + """ + + if op_name in _DARKNET_CONVERT_MAP: + sym = _DARKNET_CONVERT_MAP[op_name](inputs, params, attrs, params_prefix) + else: + _darknet_not_support('Operator type ' + str(op_name)) + return sym + +def _as_list(arr): + """Force being a list, ignore if already is.""" + if isinstance(arr, list): + return arr + return [arr] + +class GraphProto(object): + """A helper class for handling relay functions from darknet model. + """ + + def __init__(self, net, shape, dtype='float32'): + self._net = net + self._shape = shape + self._dtype = dtype + self._sym_array = {} + self._tvmparams = {} + self._outs = [] + self._state_ctr = {} + self._state_ctr['rnn'] = 0 + self._state_ctr['crnn'] = 0 + self._state_ctr['lstm'] = 0 + self._state_ctr['cell_state'] = 0 + self._state_ctr['gru'] = 0 + + def _read_memory_buffer(self, shape, data, dtype=None): + if dtype is None: + dtype = self._dtype + length = 1 + for x in shape: + length *= x + data_np = np.zeros(length, dtype=dtype) + for i in range(length): + data_np[i] = data[i] + return data_np.reshape(shape) + + def _get_convolution_weights(self, layer, opname): + """Get the convolution layer weights and biases.""" + if layer.nweights == 0: + return None + + if (layer.n * layer.c * layer.size * layer.size) != layer.nweights: + raise RuntimeError("layer weights size not matching with n c h w") + + params = {} + shape = (layer.n, layer.c, layer.size, layer.size) + weights = self._read_memory_buffer(shape, layer.weights) + + biases = self._read_memory_buffer((layer.n, ), layer.biases) + + k = _get_params_name(opname, 'weight') + params[k] = tvm.nd.array(weights) + + if layer.batch_normalize == 1 and layer.dontloadscales != 1: + params.update(self._get_batchnorm_weights(layer, opname, layer.n)) + k = _get_params_name(opname, 'beta') + params[k] = tvm.nd.array(biases) + else: + k = _get_params_name(opname, 'bias') + params[k] = tvm.nd.array(biases) + return params + + def _get_connected_weights(self, layer, opname): + """Parse the weights and biases for fully connected or dense layer.""" + size = layer.outputs * layer.inputs + if size == 0: + return None + + weights = self._read_memory_buffer((layer.outputs, layer.inputs), layer.weights) + biases = self._read_memory_buffer((layer.outputs, ), layer.biases) + + params = {} + k = _get_params_name(opname, 'weight') + params[k] = tvm.nd.array(weights) + + if layer.batch_normalize == 1 and layer.dontloadscales != 1: + params.update(self._get_batchnorm_weights(layer, opname, layer.outputs)) + k = _get_params_name(opname, 'beta') + params[k] = tvm.nd.array(biases) + else: + k = _get_params_name(opname, 'bias') + params[k] = tvm.nd.array(biases) + return params + + def _get_region_weights(self, layer, opname): + """Parse the biases for region layer.""" + biases = self._read_memory_buffer((layer.n*2, ), layer.biases) + attributes = np.array([layer.n, layer.out_c, layer.out_h, layer.out_w, + layer.classes, layer.coords, layer.background], + dtype=np.int32) + params = {} + k = _get_params_name(opname, 'bias') + params[k] = tvm.nd.array(biases) + k = _get_params_name(opname, 'attr') + params[k] = tvm.nd.array(attributes) + return params + + def _get_yolo_weights(self, layer, opname): + """Parse the biases and mask for yolo layer.""" + biases = self._read_memory_buffer((layer.total*2, ), layer.biases) + mask = self._read_memory_buffer((layer.n, ), layer.mask, dtype='int32') + attributes = np.array([layer.n, layer.out_c, layer.out_h, layer.out_w, + layer.classes, layer.total], + dtype=np.int32) + params = {} + k = _get_params_name(opname, 'bias') + params[k] = tvm.nd.array(biases) + k = _get_params_name(opname, 'mask') + params[k] = tvm.nd.array(mask) + k = _get_params_name(opname, 'attr') + params[k] = tvm.nd.array(attributes) + return params + + def _get_batchnorm_weights(self, layer, opname, size): + """Parse the weights for batchnorm, which includes, scales, moving mean + and moving variances.""" + scales = self._read_memory_buffer((size, ), layer.scales) + rolling_mean = self._read_memory_buffer((size, ), layer.rolling_mean) + rolling_variance = self._read_memory_buffer((size, ), layer.rolling_variance) + + params = {} + k = _get_params_name(opname, 'moving_mean') + params[k] = tvm.nd.array(rolling_mean) + k = _get_params_name(opname, 'moving_var') + params[k] = tvm.nd.array(rolling_variance) + k = _get_params_name(opname, 'gamma') + params[k] = tvm.nd.array(scales) + return params + + def _get_darknet_attrs(self, layer, layer_num): + """Parse attributes of each layer and return.""" + attr = {} + use_flatten = True + layer_type = LAYERTYPE(layer.type) + if LAYERTYPE.CONVOLUTIONAL == layer_type: + attr.update({'pad' : layer.pad}) + attr.update({'num_group' : layer.groups}) + attr.update({'num_filter' : layer.n}) + attr.update({'stride' : layer.stride}) + attr.update({'kernel' : layer.size}) + attr.update({'activation' : (layer.activation)}) + + if layer.nbiases == 0: + attr.update({'use_bias' : False}) + else: + attr.update({'use_bias' : True}) + + if layer.batch_normalize == 1 and layer.dontloadscales != 1: + attr.update({'use_batchNorm' : True}) + attr.update({'use_scales' : True}) + + elif LAYERTYPE.CONNECTED == layer_type: + attr.update({'num_hidden' : layer.outputs}) + attr.update({'activation' : (layer.activation)}) + if layer_num != 0: + layer_prev = self._net.layers[layer_num - 1] + if (layer_prev.out_h == layer.h and + layer_prev.out_w == layer.w and + layer_prev.out_c == layer.c): + use_flatten = False + attr.update({'use_flatten' : use_flatten}) + attr.update({'use_bias' : True}) + if layer.batch_normalize == 1 and layer.dontloadscales != 1: + attr.update({'use_batchNorm' : True}) + attr.update({'use_scales' : True}) + attr.update({'use_bias' : False}) + + elif LAYERTYPE.MAXPOOL == layer_type: + attr.update({'pad' : layer.pad}) + attr.update({'stride' : layer.stride}) + attr.update({'kernel' : layer.size}) + max_output = (layer.w - layer.size + 2 * layer.pad)/float(layer.stride) + 1 + if max_output < layer.out_w: + extra_pad = (layer.out_w - max_output)*layer.stride + attr.update({'extra_pad_size' : int(extra_pad)}) + elif LAYERTYPE.AVGPOOL == layer_type: + attr.update({'pad' : layer.pad}) + if layer.stride == 0: + attr.update({'stride' : 1}) + else: + attr.update({'stride' : layer.stride}) + if layer.size == 0 and layer.h == layer.w: + attr.update({'kernel' : layer.h}) + else: + attr.update({'kernel' : layer.size}) + + elif LAYERTYPE.DROPOUT == layer_type: + attr.update({'p' : layer.probability}) + + elif LAYERTYPE.SOFTMAX == layer_type: + attr.update({'axis' : 1}) + attr.update({'use_flatten' : True}) + if layer.temperature: + attr.update({'temperature' : str(layer.temperature)}) + + elif LAYERTYPE.SHORTCUT == layer_type: + add_layer = self._net.layers[layer.index] + attr.update({'activation' : layer.activation}) + attr.update({'out_channel' : layer.out_c}) + attr.update({'out_size' : layer.out_h}) + attr.update({'add_out_channel' : add_layer.out_c}) + attr.update({'add_out_size' : add_layer.out_h}) + + elif LAYERTYPE.ROUTE == layer_type: + pass + + elif LAYERTYPE.COST == layer_type: + pass + + elif LAYERTYPE.REORG == layer_type: + attr.update({'stride' : layer.stride}) + + elif LAYERTYPE.REGION == layer_type: + attr.update({'n' : layer.n}) + attr.update({'classes' : layer.classes}) + attr.update({'coords' : layer.coords}) + attr.update({'background' : layer.background}) + attr.update({'softmax' : layer.softmax}) + attr.update({'shape' : (1, layer.c, layer.h, layer.w)}) + + elif LAYERTYPE.YOLO == layer_type: + attr.update({'n' : layer.n}) + attr.update({'classes' : layer.classes}) + attr.update({'shape' : (1, layer.c, layer.h, layer.w)}) + + elif LAYERTYPE.UPSAMPLE == layer_type: + attr.update({'scale' : layer.stride}) + + elif LAYERTYPE.L2NORM == layer_type: + pass + + else: + err = "Darknet layer type {} is not supported in relay.".format(layer_type) + raise NotImplementedError(err) + + return attr + + def _get_darknet_params(self, layer, opname): + """To parse and get the darknet params.""" + layer_type = LAYERTYPE(layer.type) + params = None + if LAYERTYPE.CONVOLUTIONAL == layer_type: + params = self._get_convolution_weights(layer, opname) + elif LAYERTYPE.CONNECTED == layer_type: + params = self._get_connected_weights(layer, opname) + elif LAYERTYPE.REGION == layer_type: + params = self._get_region_weights(layer, opname) + elif LAYERTYPE.YOLO == layer_type: + params = self._get_yolo_weights(layer, opname) + return params + + def _preproc_layer(self, layer, layer_num): + """To preprocess each darknet layer, some layer doesnt need processing.""" + if layer_num == 0: + name = 'data' + sym = new_var(name, shape=self._shape, dtype=self._dtype) + else: + sym = self._sym_array[layer_num - 1] + skip_layer = False + layer_type = LAYERTYPE(layer.type) + if LAYERTYPE.ROUTE == layer_type: + sym = [] + for j in range(layer.n): + sym.append(self._sym_array[layer.input_layers[j]]) + if layer.n == 1: + skip_layer = True + + elif LAYERTYPE.COST == layer_type: + skip_layer = True + + elif LAYERTYPE.SHORTCUT == layer_type: + sym = [sym, self._sym_array[layer.index]] + + elif LAYERTYPE.BLANK == layer_type: + skip_layer = True + + if skip_layer is True: + self._sym_array[layer_num] = sym + + return skip_layer, sym + + def _get_opname(self, layer): + """Returs the layer name.""" + return LAYERTYPE(layer.type) + + def _new_rnn_state_var(self, state=None, name='rnn'): + """Returs a symbol for state""" + sym_name = name + "%d_state" % self._state_ctr[name] + self._state_ctr[name] += 1 + return new_var(sym_name, shape=state.shape, dtype=str(state.dtype)) + + def _get_rnn_state_buffer(self, layer, name): + """Get the state buffer for rnn.""" + buffer = np.zeros((1, layer.outputs), self._dtype) + return self._new_rnn_state_var(buffer, name) + + def _get_darknet_rnn_attrs(self, layer, name, sym): + """Get the rnn converted symbol from attributes.""" + attr = self._get_darknet_attrs(layer, 0) + op_name = self._get_opname(layer) + prefix = _get_params_prefix(op_name, name) + params = self._get_darknet_params(layer, prefix) + sym = _darknet_convert_symbol(op_name, _as_list(sym), params, attr, prefix) + if params: + self._tvmparams.update(params) + return sym + + def _handle_darknet_rnn_layers(self, layer_num, sym): + """Parse attributes and handle the rnn layers.""" + attr = {} + layer = self._net.layers[layer_num] + processed = False + + layer_type = LAYERTYPE(layer.type) + if LAYERTYPE.RNN == layer_type: + attr.update({'n' : layer.n}) + attr.update({'batch' : layer.batch}) + attr.update({'num_hidden' : str(layer.outputs)}) + state = self._get_rnn_state_buffer(layer, 'rnn') + for _ in range(layer.steps): + input_layer = layer.input_layer + prefix = "_input_" + str(layer_num) + sym = self._get_darknet_rnn_attrs(input_layer, prefix, sym) + + self_layer = layer.self_layer + prefix = "_self_" + str(layer_num) + state = self._get_darknet_rnn_attrs(self_layer, prefix, state) + + state = sym + state + self._outs.append(state) + + output_layer = layer.output_layer + prefix = "_output_" + str(layer_num) + sym = self._get_darknet_rnn_attrs(output_layer, prefix, state) + + self._sym_array[layer_num] = sym + processed = True + return processed, sym + + def _make_outlist(self, sym, op_name, layer, layer_num): + layer_type = LAYERTYPE(layer.type) + if layer_type == LAYERTYPE.REGION: + #Add attributes + k = _get_params_name(op_name, 'attr') + dshape = self._tvmparams[k].shape + dtype = self._tvmparams[k].dtype + self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype)) + + #Add bias + k = _get_params_name(op_name, 'bias') + dshape = self._tvmparams[k].shape + dtype = self._tvmparams[k].dtype + self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype)) + if layer_num != self._net.n-1: + self._outs.insert(0, sym) + + elif layer_type == LAYERTYPE.YOLO: + #Add attributes + k = _get_params_name(op_name, 'attr') + dshape = self._tvmparams[k].shape + dtype = self._tvmparams[k].dtype + self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype)) + + #Add bias + k = _get_params_name(op_name, 'bias') + dshape = self._tvmparams[k].shape + dtype = self._tvmparams[k].dtype + self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype)) + + #Add mask + k = _get_params_name(op_name, 'mask') + dshape = self._tvmparams[k].shape + dtype = self._tvmparams[k].dtype + self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype)) + + if layer_num != self._net.n-1: + self._outs.insert(0, sym) + + def from_darknet(self): + """To convert the darknet symbol to relay functions.""" + for i in range(self._net.n): + layer = self._net.layers[i] + need_skip, sym = self._preproc_layer(layer, i) + if need_skip: + continue + + processed, sym = self._handle_darknet_rnn_layers(i, sym) + if processed: + continue + + attr = self._get_darknet_attrs(layer, i) + op_name = self._get_opname(layer) + prefix = _get_params_prefix(op_name, i) + params = self._get_darknet_params(self._net.layers[i], prefix) + sym = _darknet_convert_symbol(op_name, _as_list(sym), params, attr, prefix) + + if params: + self._tvmparams.update(params) + self._sym_array[i] = sym + self._make_outlist(sym, prefix, layer, i) + + outputs = _as_list(sym) + self._outs + outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) + sym = _expr.Function(ir_pass.free_vars(outputs), outputs) + return sym, self._tvmparams + +def from_darknet(net, + shape=None, + dtype="float32"): + """Convert from Darknet's model into compatible relay Function. + + Parameters + ---------- + net : Darknet net parameter + Darknet net structure. + shape : dict of str to tuple, optional + The input shape to the graph + dtype : str or dict of str to str + The input types to the graph + + Returns + ------- + sym : tvm.relay.Function + Compatible relay Function + params : dict of str to tvm.NDArray + The parameter dict to be used by relay + """ + + return GraphProto(net, shape, dtype).from_darknet() diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 192afe1..7a5007b 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -27,6 +27,7 @@ from . import inception_v3 from . import squeezenet from . import vgg from . import densenet +from . import yolo_detection from .config import ctx_list from .init import create_workload diff --git a/nnvm/python/nnvm/testing/darknet.py b/python/tvm/relay/testing/darknet.py similarity index 100% rename from nnvm/python/nnvm/testing/darknet.py rename to python/tvm/relay/testing/darknet.py diff --git a/nnvm/python/nnvm/testing/yolo_detection.py b/python/tvm/relay/testing/yolo_detection.py similarity index 100% rename from nnvm/python/nnvm/testing/yolo_detection.py rename to python/tvm/relay/testing/yolo_detection.py diff --git a/tests/python/frontend/darknet/test_forward.py b/tests/python/frontend/darknet/test_forward.py new file mode 100644 index 0000000..3545e8a --- /dev/null +++ b/tests/python/frontend/darknet/test_forward.py @@ -0,0 +1,462 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Test Darknet Models +=================== +This article is a test script to test darknet models with Relay. +All the required models and libraries will be downloaded from the internet +by the script. +""" +import numpy as np +import tvm +from tvm.contrib import graph_runtime +from tvm.contrib.download import download_testdata +download_testdata.__test__ = False +from tvm.relay.testing.darknet import LAYERTYPE +from tvm.relay.testing.darknet import __darknetffi__ +from tvm.relay.frontend.darknet import ACTIVATION +from tvm import relay + +REPO_URL = 'https://github.com/dmlc/web-data/blob/master/darknet/' +DARKNET_LIB = 'libdarknet2.0.so' +DARKNETLIB_URL = REPO_URL + 'lib/' + DARKNET_LIB + '?raw=true' +LIB = __darknetffi__.dlopen(download_testdata(DARKNETLIB_URL, DARKNET_LIB, module='darknet')) + +DARKNET_TEST_IMAGE_NAME = 'dog.jpg' +DARKNET_TEST_IMAGE_URL = REPO_URL + 'data/' + DARKNET_TEST_IMAGE_NAME +'?raw=true' +DARKNET_TEST_IMAGE_PATH = download_testdata(DARKNET_TEST_IMAGE_URL, DARKNET_TEST_IMAGE_NAME, module='data') + +def _read_memory_buffer(shape, data, dtype='float32'): + length = 1 + for x in shape: + length *= x + data_np = np.zeros(length, dtype=dtype) + for i in range(length): + data_np[i] = data[i] + return data_np.reshape(shape) + +def _get_tvm_output(net, data, build_dtype='float32', states=None): + '''Compute TVM output''' + dtype = 'float32' + sym, params = relay.frontend.from_darknet(net, data.shape, dtype) + target = 'llvm' + shape_dict = {'data': data.shape} + graph, library, params = relay.build(sym, target, params=params) + + # Execute on TVM + ctx = tvm.cpu(0) + m = graph_runtime.create(graph, library, ctx) + # set inputs + m.set_input('data', tvm.nd.array(data.astype(dtype))) + if states: + for name in states.keys(): + m.set_input(name, tvm.nd.array(states[name].astype(dtype))) + m.set_input(**params) + m.run() + # get outputs + tvm_out = [] + for i in range(m.get_num_outputs()): + tvm_out.append(m.get_output(i).asnumpy()) + return tvm_out + +def _load_net(cfg_url, cfg_name, weights_url, weights_name): + cfg_path = download_testdata(cfg_url, cfg_name, module='darknet') + weights_path = download_testdata(weights_url, weights_name, module='darknet') + net = LIB.load_network(cfg_path.encode('utf-8'), weights_path.encode('utf-8'), 0) + return net + +def verify_darknet_frontend(net, build_dtype='float32'): + '''Test network with given input image on both darknet and tvm''' + def get_darknet_output(net, img): + LIB.network_predict_image(net, img) + out = [] + for i in range(net.n): + layer = net.layers[i] + if layer.type == LAYERTYPE.REGION: + attributes = np.array([layer.n, layer.out_c, layer.out_h, + layer.out_w, layer.classes, + layer.coords, layer.background], + dtype=np.int32) + out.insert(0, attributes) + out.insert(0, _read_memory_buffer((layer.n*2, ), layer.biases)) + layer_outshape = (layer.batch, layer.out_c, + layer.out_h, layer.out_w) + out.insert(0, _read_memory_buffer(layer_outshape, layer.output)) + elif layer.type == LAYERTYPE.YOLO: + attributes = np.array([layer.n, layer.out_c, layer.out_h, + layer.out_w, layer.classes, + layer.total], + dtype=np.int32) + out.insert(0, attributes) + out.insert(0, _read_memory_buffer((layer.total*2, ), layer.biases)) + out.insert(0, _read_memory_buffer((layer.n, ), layer.mask, dtype='int32')) + layer_outshape = (layer.batch, layer.out_c, + layer.out_h, layer.out_w) + out.insert(0, _read_memory_buffer(layer_outshape, layer.output)) + elif i == net.n-1: + if layer.type == LAYERTYPE.CONNECTED: + darknet_outshape = (layer.batch, layer.out_c) + elif layer.type in [LAYERTYPE.SOFTMAX]: + darknet_outshape = (layer.batch, layer.outputs) + else: + darknet_outshape = (layer.batch, layer.out_c, + layer.out_h, layer.out_w) + out.insert(0, _read_memory_buffer(darknet_outshape, layer.output)) + return out + + dtype = 'float32' + + img = LIB.letterbox_image(LIB.load_image_color(DARKNET_TEST_IMAGE_PATH.encode('utf-8'), 0, 0), net.w, net.h) + darknet_output = get_darknet_output(net, img) + batch_size = 1 + data = np.empty([batch_size, img.c, img.h, img.w], dtype) + i = 0 + for c in range(img.c): + for h in range(img.h): + for k in range(img.w): + data[0][c][h][k] = img.data[i] + i = i + 1 + + tvm_out = _get_tvm_output(net, data, build_dtype) + for tvm_outs, darknet_out in zip(tvm_out, darknet_output): + tvm.testing.assert_allclose(darknet_out, tvm_outs, rtol=1e-3, atol=1e-3) + +def _test_rnn_network(net, states): + '''Test network with given input data on both darknet and tvm''' + def get_darknet_network_predict(net, data): + return LIB.network_predict(net, data) + from cffi import FFI + ffi = FFI() + np_arr = np.zeros([1, net.inputs], dtype='float32') + np_arr[0, 2] = 1 + cffi_arr = ffi.cast('float*', np_arr.ctypes.data) + tvm_out = _get_tvm_output(net, np_arr, states=states)[0] + darknet_output = get_darknet_network_predict(net, cffi_arr) + darknet_out = np.zeros(net.outputs, dtype='float32') + for i in range(net.outputs): + darknet_out[i] = darknet_output[i] + last_layer = net.layers[net.n-1] + darknet_outshape = (last_layer.batch, last_layer.outputs) + darknet_out = darknet_out.reshape(darknet_outshape) + tvm.testing.assert_allclose(darknet_out, tvm_out, rtol=1e-4, atol=1e-4) + +def test_forward_extraction(): + '''test extraction model''' + model_name = 'extraction' + cfg_name = model_name + '.cfg' + weights_name = model_name + '.weights' + cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true' + weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true' + net = _load_net(cfg_url, cfg_name, weights_url, weights_name) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_alexnet(): + '''test alexnet model''' + model_name = 'alexnet' + cfg_name = model_name + '.cfg' + weights_name = model_name + '.weights' + cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true' + weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true' + net = _load_net(cfg_url, cfg_name, weights_url, weights_name) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_resnet50(): + '''test resnet50 model''' + model_name = 'resnet50' + cfg_name = model_name + '.cfg' + weights_name = model_name + '.weights' + cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true' + weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true' + net = _load_net(cfg_url, cfg_name, weights_url, weights_name) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_yolov2(): + '''test yolov2 model''' + model_name = 'yolov2' + cfg_name = model_name + '.cfg' + weights_name = model_name + '.weights' + cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true' + weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true' + net = _load_net(cfg_url, cfg_name, weights_url, weights_name) + build_dtype = {} + verify_darknet_frontend(net, build_dtype) + LIB.free_network(net) + +def test_forward_yolov3(): + '''test yolov3 model''' + model_name = 'yolov3' + cfg_name = model_name + '.cfg' + weights_name = model_name + '.weights' + cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true' + weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true' + net = _load_net(cfg_url, cfg_name, weights_url, weights_name) + build_dtype = {} + verify_darknet_frontend(net, build_dtype) + LIB.free_network(net) + +def test_forward_convolutional(): + '''test convolutional layer''' + net = LIB.make_network(1) + layer = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0) + net.layers[0] = layer + net.w = net.h = 224 + LIB.resize_network(net, 224, 224) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_dense(): + '''test fully connected layer''' + net = LIB.make_network(1) + layer = LIB.make_connected_layer(1, 75, 20, 1, 0, 0) + net.layers[0] = layer + net.w = net.h = 5 + LIB.resize_network(net, 5, 5) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_dense_batchnorm(): + '''test fully connected layer with batchnorm''' + net = LIB.make_network(1) + layer = LIB.make_connected_layer(1, 12, 2, 1, 1, 0) + for i in range(5): + layer.rolling_mean[i] = np.random.rand(1) + layer.rolling_variance[i] = np.random.rand(1) + layer.scales[i] = np.random.rand(1) + net.layers[0] = layer + net.w = net.h = 2 + LIB.resize_network(net, 2, 2) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_maxpooling(): + '''test maxpooling layer''' + net = LIB.make_network(1) + layer = LIB.make_maxpool_layer(1, 224, 224, 3, 2, 2, 0) + net.layers[0] = layer + net.w = net.h = 224 + LIB.resize_network(net, 224, 224) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_avgpooling(): + '''test avgerage pooling layer''' + net = LIB.make_network(1) + layer = LIB.make_avgpool_layer(1, 224, 224, 3) + net.layers[0] = layer + net.w = net.h = 224 + LIB.resize_network(net, 224, 224) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_conv_batch_norm(): + '''test batch normalization layer''' + net = LIB.make_network(1) + layer = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 1, 0, 0, 0) + for i in range(32): + layer.rolling_mean[i] = np.random.rand(1) + layer.rolling_variance[i] = np.random.rand(1) + net.layers[0] = layer + net.w = net.h = 224 + LIB.resize_network(net, 224, 224) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_shortcut(): + '''test shortcut layer''' + net = LIB.make_network(3) + layer_1 = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0) + layer_2 = LIB.make_convolutional_layer(1, 111, 111, 32, 32, 1, 1, 1, 0, 1, 0, 0, 0, 0) + layer_3 = LIB.make_shortcut_layer(1, 0, 111, 111, 32, 111, 111, 32) + layer_3.activation = ACTIVATION.RELU + layer_3.alpha = 1 + layer_3.beta = 1 + net.layers[0] = layer_1 + net.layers[1] = layer_2 + net.layers[2] = layer_3 + net.w = net.h = 224 + LIB.resize_network(net, 224, 224) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_reorg(): + '''test reorg layer''' + net = LIB.make_network(2) + layer_1 = LIB.make_convolutional_layer(1, 222, 222, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0) + layer_2 = LIB.make_reorg_layer(1, 110, 110, 32, 2, 0, 0, 0) + net.layers[0] = layer_1 + net.layers[1] = layer_2 + net.w = net.h = 222 + LIB.resize_network(net, 222, 222) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_region(): + '''test region layer''' + net = LIB.make_network(2) + layer_1 = LIB.make_convolutional_layer(1, 19, 19, 3, 425, 1, 1, 1, 0, 1, 0, 0, 0, 0) + layer_2 = LIB.make_region_layer(1, 19, 19, 5, 80, 4) + layer_2.softmax = 1 + net.layers[0] = layer_1 + net.layers[1] = layer_2 + net.w = net.h = 19 + LIB.resize_network(net, 19, 19) + build_dtype = {} + verify_darknet_frontend(net, build_dtype) + LIB.free_network(net) + +def test_forward_yolo_op(): + '''test yolo layer''' + net = LIB.make_network(2) + layer_1 = LIB.make_convolutional_layer(1, 224, 224, 3, 14, 1, 3, 2, 0, 1, 0, 0, 0, 0) + layer_2 = LIB.make_yolo_layer(1, 111, 111, 2, 9, __darknetffi__.NULL, 2) + net.layers[0] = layer_1 + net.layers[1] = layer_2 + net.w = net.h = 224 + LIB.resize_network(net, 224, 224) + build_dtype = {} + verify_darknet_frontend(net, build_dtype) + LIB.free_network(net) + +def test_forward_upsample(): + '''test upsample layer''' + net = LIB.make_network(1) + layer = LIB.make_upsample_layer(1, 19, 19, 3, 3) + layer.scale = 1 + net.layers[0] = layer + net.w = net.h = 19 + LIB.resize_network(net, 19, 19) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_l2normalize(): + '''test l2 normalization layer''' + net = LIB.make_network(1) + layer = LIB.make_l2norm_layer(1, 224*224*3) + layer.c = layer.out_c = 3 + layer.h = layer.out_h = 224 + layer.w = layer.out_w = 224 + net.layers[0] = layer + net.w = net.h = 224 + LIB.resize_network(net, 224, 224) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_elu(): + '''test elu activation layer''' + net = LIB.make_network(1) + layer_1 = LIB.make_convolutional_layer(1, 224, 224, 3, 32, 1, 3, 2, 0, 1, 0, 0, 0, 0) + layer_1.activation = ACTIVATION.ELU + net.layers[0] = layer_1 + net.w = net.h = 224 + LIB.resize_network(net, 224, 224) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_softmax(): + '''test softmax layer''' + net = LIB.make_network(1) + layer_1 = LIB.make_softmax_layer(1, 75, 1) + layer_1.temperature = 1 + net.layers[0] = layer_1 + net.w = net.h = 5 + LIB.resize_network(net, net.w, net.h) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_softmax_temperature(): + '''test softmax layer''' + net = LIB.make_network(1) + layer_1 = LIB.make_softmax_layer(1, 75, 1) + layer_1.temperature = 0.8 + net.layers[0] = layer_1 + net.w = net.h = 5 + LIB.resize_network(net, net.w, net.h) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_activation_logistic(): + '''test logistic activation layer''' + net = LIB.make_network(1) + batch = 1 + h = 224 + w = 224 + c = 3 + n = 32 + groups = 1 + size = 3 + stride = 2 + padding = 0 + activation = ACTIVATION.LOGISTIC + batch_normalize = 0 + binary = 0 + xnor = 0 + adam = 0 + layer_1 = LIB.make_convolutional_layer(batch, h, w, c, n, groups, size, stride, padding, + activation, batch_normalize, binary, xnor, adam) + net.layers[0] = layer_1 + net.w = w + net.h = h + LIB.resize_network(net, net.w, net.h) + verify_darknet_frontend(net) + LIB.free_network(net) + +def test_forward_rnn(): + '''test RNN layer''' + net = LIB.make_network(1) + batch = 1 + inputs = 4 + outputs = 4 + steps = 1 + activation = ACTIVATION.RELU + batch_normalize = 0 + adam = 0 + layer_1 = LIB.make_rnn_layer(batch, inputs, outputs, steps, activation, batch_normalize, adam) + net.layers[0] = layer_1 + net.inputs = inputs + net.outputs = outputs + net.w = net.h = 0 + LIB.resize_network(net, net.w, net.h) + states = {"rnn0_state": np.zeros([1, net.inputs])} + _test_rnn_network(net, states) + LIB.free_network(net) + +if __name__ == '__main__': + test_forward_resnet50() + test_forward_alexnet() + test_forward_extraction() + test_forward_yolov2() + test_forward_yolov3() + test_forward_convolutional() + test_forward_maxpooling() + test_forward_avgpooling() + test_forward_conv_batch_norm() + test_forward_shortcut() + test_forward_dense() + test_forward_dense_batchnorm() + test_forward_softmax() + test_forward_softmax_temperature() + test_forward_reorg() + test_forward_region() + test_forward_yolo_op() + test_forward_upsample() + test_forward_l2normalize() + test_forward_elu() + test_forward_rnn() + test_forward_activation_logistic() diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh index 37159db..609b001 100755 --- a/tests/scripts/task_python_frontend.sh +++ b/tests/scripts/task_python_frontend.sh @@ -62,10 +62,10 @@ python3 -m nose -v tests/python/frontend/mxnet echo "Running relay Keras frontend test..." python3 -m nose -v tests/python/frontend/keras -echo "Running relay ONNX frondend test..." +echo "Running relay ONNX frontend test..." python3 -m nose -v tests/python/frontend/onnx -echo "Running relay CoreML frondend test..." +echo "Running relay CoreML frontend test..." python3 -m nose -v tests/python/frontend/coreml echo "Running nnvm to relay frontend test..." @@ -74,5 +74,8 @@ python3 -m nose -v tests/python/frontend/nnvm_to_relay echo "Running relay Tensorflow frontend test..." python3 -m nose -v tests/python/frontend/tensorflow -echo "Running relay caffe2 frondend test..." +echo "Running relay caffe2 frontend test..." python3 -m nose -v tests/python/frontend/caffe2 + +echo "Running relay DarkNet frontend test..." +python3 -m nose -v tests/python/frontend/darknet || exit -1 diff --git a/tutorials/frontend/from_darknet.py b/tutorials/frontend/from_darknet.py new file mode 100644 index 0000000..2658a35 --- /dev/null +++ b/tutorials/frontend/from_darknet.py @@ -0,0 +1,179 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Compile YOLO-V2 and YOLO-V3 in DarkNet Models +============================================= +**Author**: `Siju Samuel `_ + +This article is an introductory tutorial to deploy darknet models with TVM. +All the required models and libraries will be downloaded from the internet by the script. +This script runs the YOLO-V2 and YOLO-V3 Model with the bounding boxes +Darknet parsing have dependancy with CFFI and CV2 library +Please install CFFI and CV2 before executing this script + +.. code-block:: bash + + pip install cffi + pip install opencv-python +""" + +# numpy and matplotlib +import numpy as np +import matplotlib.pyplot as plt +import sys + +# tvm, relay +import tvm +from tvm import relay +from ctypes import * +from tvm.contrib.download import download_testdata +from tvm.relay.testing.darknet import __darknetffi__ +import tvm.relay.testing.yolo_detection +import tvm.relay.testing.darknet + +# Model name +MODEL_NAME = 'yolov3' + +###################################################################### +# Download required files +# ----------------------- +# Download cfg and weights file if first time. +CFG_NAME = MODEL_NAME + '.cfg' +WEIGHTS_NAME = MODEL_NAME + '.weights' +REPO_URL = 'https://github.com/dmlc/web-data/blob/master/darknet/' +CFG_URL = REPO_URL + 'cfg/' + CFG_NAME + '?raw=true' +WEIGHTS_URL = 'https://pjreddie.com/media/files/' + WEIGHTS_NAME + +cfg_path = download_testdata(CFG_URL, CFG_NAME, module="darknet") +weights_path = download_testdata(WEIGHTS_URL, WEIGHTS_NAME, module="darknet") + +# Download and Load darknet library +if sys.platform in ['linux', 'linux2']: + DARKNET_LIB = 'libdarknet2.0.so' + DARKNET_URL = REPO_URL + 'lib/' + DARKNET_LIB + '?raw=true' +elif sys.platform == 'darwin': + DARKNET_LIB = 'libdarknet_mac2.0.so' + DARKNET_URL = REPO_URL + 'lib_osx/' + DARKNET_LIB + '?raw=true' +else: + err = "Darknet lib is not supported on {} platform".format(sys.platform) + raise NotImplementedError(err) + +lib_path = download_testdata(DARKNET_URL, DARKNET_LIB, module="darknet") + +DARKNET_LIB = __darknetffi__.dlopen(lib_path) +net = DARKNET_LIB.load_network(cfg_path.encode('utf-8'), weights_path.encode('utf-8'), 0) +dtype = 'float32' +batch_size = 1 + +data = np.empty([batch_size, net.c, net.h, net.w], dtype) +shape_dict = {'data': data.shape} +print("Converting darknet to relay functions...") +sym, params = relay.frontend.from_darknet(net, dtype=dtype, shape=data.shape) + +###################################################################### +# Import the graph to Relay +# ------------------------- +# compile the model +target = 'llvm' +target_host = 'llvm' +ctx = tvm.cpu(0) +data = np.empty([batch_size, net.c, net.h, net.w], dtype) +shape = {'data': data.shape} +print("Compiling the model...") +with relay.build_config(opt_level=3): + graph, lib, params = relay.build(sym, target=target, target_host=target_host, params=params) + +[neth, netw] = shape['data'][2:] # Current image shape is 608x608 +###################################################################### +# Load a test image +# ----------------- +test_image = 'dog.jpg' +print("Loading the test image...") +img_url = REPO_URL + 'data/' + test_image + '?raw=true' +img_path = download_testdata(img_url, test_image, "data") + +data = tvm.relay.testing.darknet.load_image(img_path, netw, neth) +###################################################################### +# Execute on TVM Runtime +# ---------------------- +# The process is no different from other examples. +from tvm.contrib import graph_runtime + +m = graph_runtime.create(graph, lib, ctx) + +# set inputs +m.set_input('data', tvm.nd.array(data.astype(dtype))) +m.set_input(**params) +# execute +print("Running the test image...") + +m.run() +# get outputs +tvm_out = [] +if MODEL_NAME == 'yolov2': + layer_out = {} + layer_out['type'] = 'Region' + # Get the region layer attributes (n, out_c, out_h, out_w, classes, coords, background) + layer_attr = m.get_output(2).asnumpy() + layer_out['biases'] = m.get_output(1).asnumpy() + out_shape = (layer_attr[0], layer_attr[1]//layer_attr[0], + layer_attr[2], layer_attr[3]) + layer_out['output'] = m.get_output(0).asnumpy().reshape(out_shape) + layer_out['classes'] = layer_attr[4] + layer_out['coords'] = layer_attr[5] + layer_out['background'] = layer_attr[6] + tvm_out.append(layer_out) + +elif MODEL_NAME == 'yolov3': + for i in range(3): + layer_out = {} + layer_out['type'] = 'Yolo' + # Get the yolo layer attributes (n, out_c, out_h, out_w, classes, total) + layer_attr = m.get_output(i*4+3).asnumpy() + layer_out['biases'] = m.get_output(i*4+2).asnumpy() + layer_out['mask'] = m.get_output(i*4+1).asnumpy() + out_shape = (layer_attr[0], layer_attr[1]//layer_attr[0], + layer_attr[2], layer_attr[3]) + layer_out['output'] = m.get_output(i*4).asnumpy().reshape(out_shape) + layer_out['classes'] = layer_attr[4] + tvm_out.append(layer_out) + +# do the detection and bring up the bounding boxes +thresh = 0.5 +nms_thresh = 0.45 +img = tvm.relay.testing.darknet.load_image_color(img_path) +_, im_h, im_w = img.shape +dets = tvm.relay.testing.yolo_detection.fill_network_boxes((netw, neth), (im_w, im_h), thresh, + 1, tvm_out) +last_layer = net.layers[net.n - 1] +tvm.relay.testing.yolo_detection.do_nms_sort(dets, last_layer.classes, nms_thresh) + +coco_name = 'coco.names' +coco_url = REPO_URL + 'data/' + coco_name + '?raw=true' +font_name = 'arial.ttf' +font_url = REPO_URL + 'data/' + font_name + '?raw=true' +coco_path = download_testdata(coco_url, coco_name, module='data') +font_path = download_testdata(font_url, font_name, module='data') + +with open(coco_path) as f: + content = f.readlines() + +names = [x.strip() for x in content] + +tvm.relay.testing.yolo_detection.draw_detections(font_path, img, dets, thresh, names, last_layer.classes) +plt.imshow(img.transpose(1, 2, 0)) +plt.show() -- 2.7.4