[Torch] Miscellaneous fix, enable some VM GPU tests (#6418)
authormasahi <masahi129@gmail.com>
Tue, 8 Sep 2020 14:32:34 +0000 (23:32 +0900)
committerGitHub <noreply@github.com>
Tue, 8 Sep 2020 14:32:34 +0000 (07:32 -0700)
* fix strides conversion

* enable gpu target for some vm tests

* fix pooling stride None case

Co-authored-by: masa <masa@pop-os.localdomain>
python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index 99d2dae..8d85009 100644 (file)
@@ -305,7 +305,9 @@ def _slice():
 
             end[dim] = min(end[dim], target_end)
 
-        strides.append(int(inputs[4]))
+        strides = [1] * len(end)
+        strides[dim] = int(inputs[4])
+
         return _op.transform.strided_slice(data,
                                            begin=_expr.const(begin),
                                            end=_expr.const(end),
@@ -683,7 +685,7 @@ def _maxpool_2d():
         data = inputs[0]
 
         pool_size = inputs[1]
-        strides = inputs[2]
+        strides = inputs[2] if inputs[2] else pool_size
         padding = inputs[3]
         dilation = inputs[4]
         ceil_mode = int(inputs[5])
@@ -706,7 +708,7 @@ def _maxpool_1d():
         data = inputs[0]
 
         pool_size = inputs[1]
-        strides = inputs[2]
+        strides = inputs[2] if inputs[2] else pool_size
         padding = inputs[3]
         dilation = inputs[4]
         ceil_mode = int(inputs[5])
@@ -723,7 +725,7 @@ def _maxpool_3d():
         data = inputs[0]
 
         pool_size = inputs[1]
-        strides = inputs[2]
+        strides = inputs[2] if inputs[2] else pool_size
         padding = inputs[3]
         dilation = inputs[4]
         ceil_mode = int(inputs[5])
@@ -1173,14 +1175,8 @@ def _avg_pool2d(prelude):
         data = inputs[0]
 
         pool_size = inputs[1]
-
-        if inputs[2]:
-            strides = inputs[2]
-        else:
-            strides = pool_size
-
+        strides = inputs[2] if inputs[2] else pool_size
         padding = inputs[3]
-
         ceil_mode = int(inputs[4])
         count_include_pad = int(inputs[5])
 
@@ -1204,14 +1200,8 @@ def _avg_pool3d():
         data = inputs[0]
 
         pool_size = inputs[1]
-
-        if inputs[2]:
-            strides = inputs[2]
-        else:
-            strides = pool_size
-
+        strides = inputs[2] if inputs[2] else pool_size
         padding = inputs[3]
-
         ceil_mode = int(inputs[4])
         count_include_pad = int(inputs[5])
 
index b35c7d6..e651700 100644 (file)
@@ -674,6 +674,13 @@ def test_forward_maxpool2d():
                                     stride=2).eval(),
                  input_data)
 
+    # A functional variant (default strides = None case)
+    class MaxPool2D(Module):
+        def forward(self, *args):
+            return torch.nn.functional.max_pool2d(args[0], kernel_size=[10, 10])
+
+    verify_model(MaxPool2D(), input_data=input_data)
+
     class MaxPool2DWithIndices(Module):
         def __init__(self):
             super(MaxPool2DWithIndices, self).__init__()
@@ -700,6 +707,14 @@ def test_forward_maxpool1d():
                                     stride=2).eval(),
                  input_data)
 
+    # A functional variant (default strides = None case)
+    class MaxPool1D(Module):
+        def forward(self, *args):
+            return torch.nn.functional.max_pool1d(args[0], kernel_size=10)
+
+    verify_model(MaxPool1D(), input_data=input_data)
+
+
 @tvm.testing.uses_gpu
 def test_forward_maxpool3d():
     torch.set_grad_enabled(False)
@@ -715,6 +730,14 @@ def test_forward_maxpool3d():
                                     stride=2).eval(),
                  input_data)
 
+    # A functional variant (default strides = None case)
+    class MaxPool3D(Module):
+        def forward(self, *args):
+            return torch.nn.functional.max_pool3d(args[0], kernel_size=[10, 10, 10])
+
+    verify_model(MaxPool3D(), input_data=input_data)
+
+
 @tvm.testing.uses_gpu
 def test_forward_split():
     torch.set_grad_enabled(False)
@@ -1323,10 +1346,20 @@ def test_forward_slice():
             x1 = torch.tensor(3) + torch.tensor(1)
             return args[0][:, x0:, 1:x1, :]
 
+    class SliceWithStride(torch.nn.Module):
+        def forward(self, x):
+            return x[..., 0::2] + x[..., 1::2]
+
+    class SliceWithStride2(torch.nn.Module):
+        def forward(self, x):
+            return x[0::2, 0::2] + x[1::2, 1::2]
+
     input_data = torch.rand(input_shape).float()
-    verify_model(Slice1().float().eval(), input_data=input_data)
-    verify_model(Slice2().float().eval(), input_data=input_data)
-    verify_model(Slice3().float().eval(), input_data=input_data)
+    verify_model(Slice1(), input_data=input_data)
+    verify_model(Slice2(), input_data=input_data)
+    verify_model(Slice3(), input_data=input_data)
+    verify_model(SliceWithStride(), input_data=torch.randn(1, 4))
+    verify_model(SliceWithStride2(), input_data=torch.randn(4, 4))
 
 
 @tvm.testing.uses_gpu
@@ -1584,9 +1617,11 @@ def test_forward_nms():
         scores = torch.rand(num_boxes, dtype=torch.float)
         return boxes, scores
 
+    targets = ["llvm"]  # dynamic nms does not work on gpu
+
     for num_boxes, iou_thres in [(10, 0.3), (100, 0.5), (500, 0.9)]:
         in_boxes, in_scores = _gen_rand_inputs(num_boxes)
-        verify_trace_model(NonMaxSupression(iou_thres), [in_boxes, in_scores])
+        verify_trace_model(NonMaxSupression(iou_thres), [in_boxes, in_scores], targets)
 
 
 @tvm.testing.uses_gpu
@@ -1752,19 +1787,23 @@ def test_3d_models():
     verify_model(resnet3d, [torch.rand(input_shape)], atol=1e-4, rtol=1e-4)
 
 
-def verify_script_model(pt_model, ishapes):
+def _get_default_vm_targets():
+    return [tgt for (tgt, _) in tvm.testing.enabled_targets()]
+
+
+def verify_script_model(pt_model, ishapes, targets):
     script_module = torch.jit.script(pt_model)
-    verify_model_vm(script_module, ishapes)
+    verify_model_vm(script_module, ishapes, targets=targets)
 
 
-def verify_trace_model(pt_model, idata):
+def verify_trace_model(pt_model, idata, targets):
     traced_model = torch.jit.trace(pt_model, idata)
     ishapes = [data.shape for data in idata]
-    verify_model_vm(traced_model, ishapes, idata=idata)
+    verify_model_vm(traced_model, ishapes, idata=idata, targets=targets)
 
 
-def verify_model_vm(imodel, ishapes, idtype=torch.float, idata=None):
-    input_model = imodel
+def verify_model_vm(input_model, ishapes, idtype=torch.float,
+                    idata=None, targets=["llvm"]):
     input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
     input_shapes = list(zip(input_names, ishapes))
     input_data = idata if idata else [torch.randn(shape, dtype=idtype)
@@ -1772,26 +1811,29 @@ def verify_model_vm(imodel, ishapes, idtype=torch.float, idata=None):
     # Compile via VM
     mod, params = relay.frontend.from_pytorch(input_model, input_shapes)
 
-    executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0),
-                                     target="llvm")
-    evaluator = executor.evaluate()
+    for tgt in targets:
+        print("Running on target", tgt)
+        ctx = tvm.context(tgt, 0)
 
-    # Inference
-    for name, inp in zip(input_names, input_data):
-        params[name] = inp.numpy()
-    vm_res = evaluator(**params)
+        executor = relay.create_executor("vm", mod=mod, ctx=ctx, target=tgt)
+        evaluator = executor.evaluate()
 
-    # Baseline result
-    with torch.no_grad():
-        pt_result = input_model(*input_data)
+        # Inference
+        for name, inp in zip(input_names, input_data):
+            params[name] = inp.numpy()
+        vm_res = evaluator(**params)
 
-    # Verify the accuracy
-    if not isinstance(pt_result, torch.Tensor):
-        tvm_res = vm_res.asnumpy().item()
-        assert pt_result == tvm_res
-    else:
-        tvm.testing.assert_allclose(vm_res.asnumpy(), pt_result.numpy(),
-                                    rtol=1e-5, atol=1e-5)
+        # Baseline result
+        with torch.no_grad():
+            pt_result = input_model(*input_data)
+
+        # Verify the accuracy
+        if not isinstance(pt_result, torch.Tensor):
+            tvm_res = vm_res.asnumpy().item()
+            assert pt_result == tvm_res
+        else:
+            tvm.testing.assert_allclose(vm_res.asnumpy(), pt_result.numpy(),
+                                        rtol=1e-5, atol=1e-5)
 
 
 @tvm.testing.uses_gpu
@@ -1905,7 +1947,7 @@ def test_control_flow():
     ]
 
     for pt_model in models:
-        verify_script_model(pt_model.eval(), [(10, 20)])
+        verify_script_model(pt_model.eval(), [(10, 20)], _get_default_vm_targets())
 
 
 @tvm.testing.uses_gpu
@@ -1943,7 +1985,7 @@ def test_simple_rnn():
                 y, h = self.cell(xs[i], h)
             return y
 
-    verify_script_model(RNNLoop().eval(), [(10, 10, 4)])
+    verify_script_model(RNNLoop().eval(), [(10, 10, 4)], _get_default_vm_targets())
 
 
 @tvm.testing.uses_gpu