[Relay][Frontend][darknet] Solve tvm parsing darknet resnext failure bug (#3778)
authoryouluexx <hanzh_20@163.com>
Wed, 4 Sep 2019 05:46:29 +0000 (13:46 +0800)
committerYizhi Liu <liuyizhi@apache.org>
Wed, 4 Sep 2019 05:46:29 +0000 (13:46 +0800)
* test_darkent_bug

* test_darkent

* add resnext tests

.gitignore
python/tvm/relay/frontend/darknet.py
tests/python/frontend/darknet/test_forward.py

index f044577..2f124d9 100644 (file)
@@ -231,4 +231,4 @@ conda/pkg
 
 # antlr files
 *.tokens
-*.interp
\ No newline at end of file
+*.interp
index f452146..982bcea 100644 (file)
@@ -458,11 +458,11 @@ class GraphProto(object):
         if layer.nweights == 0:
             return None
 
-        if (layer.n * layer.c * layer.size * layer.size) != layer.nweights:
+        if (layer.n * layer.c // layer.groups * layer.size * layer.size) != layer.nweights:
             raise RuntimeError("layer weights size not matching with n c h w")
 
         params = {}
-        shape = (layer.n, layer.c, layer.size, layer.size)
+        shape = (layer.n, layer.c // layer.groups, layer.size, layer.size)
         weights = self._read_memory_buffer(shape, layer.weights)
 
         biases = self._read_memory_buffer((layer.n, ), layer.biases)
index ebfbbd3..51f05d7 100644 (file)
@@ -189,6 +189,18 @@ def test_forward_resnet50():
     verify_darknet_frontend(net)
     LIB.free_network(net)
 
+def test_forward_resnext50():
+    '''test resnet50 model'''
+    model_name = 'resnext50'
+    cfg_name = model_name + '.cfg'
+    weights_name = model_name + '.weights'
+    cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
+    weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
+    net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
+    verify_darknet_frontend(net)
+    LIB.free_network(net)
+
+
 def test_forward_yolov2():
     '''test yolov2 model'''
     model_name = 'yolov2'
@@ -441,6 +453,7 @@ def test_forward_rnn():
 
 if __name__ == '__main__':
     test_forward_resnet50()
+    test_forward_resnext50()
     test_forward_alexnet()
     test_forward_extraction()
     test_forward_yolov2()