[Torch] Upsampling op support and enable registering a user defined op conversion...
authormasahi <masahi129@gmail.com>
Sun, 1 Mar 2020 00:51:49 +0000 (09:51 +0900)
committerGitHub <noreply@github.com>
Sun, 1 Mar 2020 00:51:49 +0000 (09:51 +0900)
* add custom conversion map

* add roi align test using custom convert map

* refactor test

* add support for upsampling op and test on segmentation models

* remove redundant no_grad

* add upsampling test case

* make the default custom map None, instead of empty dict

* updated tests, remove packaging and drop PT 1.2 support

* add better support for aten::to and tests

* add a note on dilation in x86

python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py
tutorials/frontend/from_pytorch.py

index fd66e3c..b256faa 100644 (file)
@@ -19,7 +19,6 @@
 # pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
 """PT: PyTorch frontend."""
 import itertools
-from packaging import version
 
 import numpy as np
 
@@ -31,6 +30,7 @@ from .. import expr as _expr
 from .. import op as _op
 from .common import get_relay_op
 from .common import infer_shape as _infer_shape
+from .common import infer_value as _infer_value
 
 __all__ = ["from_pytorch"]
 
@@ -614,6 +614,61 @@ def _sqrt():
         return _op.tensor.sqrt(data)
     return _impl
 
+def _floor():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        return _op.floor(data)
+    return _impl
+
+def _to():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        if inputs[3] in ["cpu", "cuda"]:
+            return data
+        # special handling for aten::to(data, 6, _, _, _) case
+        # 6 means dtype = float
+        # this happens when converting upsampling with scale factor
+        cast_func = {
+            6: float,
+            3: int,
+        }
+        cast_func_expr = {
+            6: lambda x: _op.cast(x, "float32"),
+            3: lambda x: _op.cast(x, "int32"),
+        }
+        if inputs[1] in cast_func and not isinstance(data, _expr.Expr):
+            return cast_func[inputs[1]](data)
+        elif inputs[1] in cast_func and isinstance(data, _expr.Expr):
+            return cast_func_expr[inputs[1]](data)
+        return data
+
+    return _impl
+
+def _upsample(method):
+    def _impl(inputs, input_types):
+        if isinstance(inputs[1], _expr.Var):
+            out_size = _infer_shape(inputs[1])
+        elif isinstance(inputs[1], list):
+            infer_res = [_infer_value(size, {}) for size in inputs[1]]
+            out_size = [np.asscalar(res.asnumpy().astype(np.int))
+                        for res in infer_res]
+
+        data = inputs[0]
+
+        if len(inputs) > 2:
+            align_corners = inputs[2]
+        else:
+            align_corners = False
+
+        if align_corners:
+            coord_trans = "align_corners"
+        else:
+            coord_trans = "half_pixel"
+
+        return _op.image.resize(data, out_size, "NCHW", method, coord_trans)
+
+    return _impl
+
 # Helper functions for operator implementation
 
 def _convert_data_type(input_type):
@@ -686,7 +741,7 @@ _convert_map = {
     "aten::div_"                            : _elemwise("divide"),
     "aten::ones"                            : _ones(),
     "aten::zeros"                           : _zeros(),
-    "aten::to"                              : _identity(),
+    "aten::to"                              : _to(),
     "aten::unsqueeze"                       : _unsqueeze(),
     "aten::cat"                             : _concatenate(),
     "aten::slice"                           : _slice(),
@@ -729,15 +784,18 @@ _convert_map = {
     "aten::permute"                         : _transpose(),
     "aten::sum"                             : _reduce("sum"),
     "aten::prod"                            : _reduce("prod"),
-    "aten::sqrt"                            : _sqrt()
+    "aten::sqrt"                            : _sqrt(),
+    'aten::floor'                           : _floor(),
+    "aten::detach"                          : _identity(),
+    "aten::upsample_bilinear2d"             : _upsample("bilinear"),
+    "aten::upsample_nearest2d"              : _upsample("nearest_neighbor"),
 }
 
 
 def _run_jit_passes(graph):
     """ The inline pass is necessary to unwrap prim::CallMethod """
     import torch
-    if version.parse(torch.__version__) >= version.parse("1.4.0"):
-        torch._C._jit_pass_inline(graph)
+    torch._C._jit_pass_inline(graph)
 
 
 def _is_int_seq(seq):
@@ -985,8 +1043,7 @@ def parse_operators(operators, outputs, output_index_map, ret_name):
 
 def get_all_op_names(graph):
     """ Return all operator names in the input graph """
-    nodes = list(graph.nodes())
-    return set(node.kind() for node in nodes)
+    return set(node.kind() for node in graph.nodes())
 
 
 def get_graph_input_names(script_module):
@@ -997,7 +1054,7 @@ def get_graph_input_names(script_module):
     return ir_inputs[1:]  # remove self at the 0th arg
 
 
-def from_pytorch(script_module, input_shapes):
+def from_pytorch(script_module, input_shapes, custom_convert_map=None):
     """ Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
     The companion parameters will be handled automatically.
 
@@ -1011,6 +1068,9 @@ def from_pytorch(script_module, input_shapes):
         Graph level input shape dictionary
         The keys should be the same one returned by get_graph_input_names(...) above
 
+    custom_convert_map: Dictionary of str to Relay op
+        A custom op conversion map in the same format as _convert_map above
+
     Returns
     -------
     mod : tvm.relay.Module
@@ -1021,6 +1081,10 @@ def from_pytorch(script_module, input_shapes):
     """
     graph = script_module.graph.copy()
     _run_jit_passes(graph)
+
+    if custom_convert_map:
+        _convert_map.update(custom_convert_map)
+
     op_names = get_all_op_names(graph)
     _report_missing_conversion(op_names)
 
index 831389b..c2ff94d 100644 (file)
 # pylint: disable=import-self, invalid-name, unused-argument
 """Unit tests for various models and operators"""
 from time import time
-import os
 import sys
-from tempfile import TemporaryDirectory
 from scipy.stats import t as tdistr
 import numpy as np
 import torch
 from torch.nn import Module
 import tvm
-from tvm import te
 import torchvision
 
 from tvm import relay
@@ -36,22 +33,6 @@ from tvm.relay.frontend.pytorch import get_graph_input_names
 
 sys.setrecursionlimit(10000)
 
-def _vectorize(ten):
-    return ten.reshape(-1)
-
-def atol(tru, est):
-    def _atol_elt(tru, est):
-        return abs(tru - est)
-    tru = _vectorize(tru)
-    est = _vectorize(est)
-    return max([_atol_elt(x, y) for x, y in zip(tru, est)])
-
-def rtol(tru, est):
-    def _rtol_elt(tru, est):
-        return abs(tru - est) / min(abs(tru), abs(est))
-    tru = _vectorize(tru)
-    est = _vectorize(est)
-    return max([_rtol_elt(x, y) for x, y in zip(tru, est)])
 
 def assert_shapes_match(tru, est):
     if tru.shape != est.shape:
@@ -77,7 +58,7 @@ def load_torchvision(model_name):
             input_data[:, channel] /= std[channel]
         model = getattr(torchvision.models, model_name)(pretrained=True)
         model = model.float().eval()
-        return model, input_data
+        return model, [input_data]
 
 def load_pretrainedmodels(model_name):
     """Given a model name, returns a pretrainedmodels.pytorch model in eval
@@ -89,7 +70,7 @@ def load_pretrainedmodels(model_name):
     for channel in range(3):
         input_data[:, channel] -= model.mean[channel]
         input_data[:, channel] /= model.std[channel]
-    return model, input_data
+    return model, [input_data]
 
 def load_model(model_name):
     """Given a model name, returns a model as well as an example input."""
@@ -116,7 +97,7 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
     latencies = []
     count = 0
     while True:
-        if isinstance(model, torch.nn.Module):
+        if isinstance(model, Module):
             input_data = [torch.rand(shape).float() for shape in input_shapes]
             if torch.cuda.is_available():
                 input_data = list(map(lambda x: x.cuda(), input_data))
@@ -153,23 +134,34 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
             if err < thresh:
                 return est
 
-def verify_model(model_name, input_data=[]):
+def verify_model(model_name, input_data=[],
+                 custom_convert_map={},
+                 ctx_list=ctx_list()):
     """Assert that the output of a compiled model matches with that of its
     baseline."""
-    if len(input_data) == 0:
+    if isinstance(model_name, str):
         baseline_model, baseline_input = load_model(model_name)
-    else:
+    elif isinstance(input_data, list):
         baseline_model = model_name
         baseline_input = input_data
+    elif isinstance(input_data, torch.Tensor) or len(input_data.shape) == 0:
+        baseline_model = model_name
+        baseline_input = [input_data]
+    else:
+        assert False, "Unexpected input format"
+
     if torch.cuda.is_available():
         baseline_model = baseline_model.cuda()
-        baseline_input = baseline_input.cuda()
+        baseline_input = [inp.cuda() for inp in baseline_input]
+
     with torch.no_grad():
-        baseline_outputs = baseline_model(baseline_input)
+        baseline_outputs = baseline_model(*baseline_input)
+
     if isinstance(baseline_outputs, tuple):
         baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs)
     else:
         baseline_outputs = (baseline_outputs.float().cpu().numpy(),)
+
     trace = torch.jit.trace(baseline_model, baseline_input).float().eval()
 
     if torch.cuda.is_available():
@@ -177,17 +169,21 @@ def verify_model(model_name, input_data=[]):
     else:
         trace = trace.cpu()
 
-    input_name = get_graph_input_names(trace)[0]  # only one input
-    input_shapes = {input_name: list(baseline_input.shape)}
-    mod, params = relay.frontend.from_pytorch(trace, input_shapes)
-    compiled_input = {input_name: tvm.nd.array(baseline_input.cpu().numpy())}
+    input_names = get_graph_input_names(trace)
+    input_shapes = dict(zip(input_names,
+                            [inp.shape for inp in baseline_input]))
+    mod, params = relay.frontend.from_pytorch(trace, input_shapes,
+                                              custom_convert_map)
+    compiled_input = dict(zip(input_names,
+                              [inp.cpu().numpy() for inp in baseline_input]))
 
     with relay.build_config(opt_level=3):
-        for target, ctx in ctx_list():
+        for target, ctx in ctx_list:
             relay_graph, relay_lib, relay_params = relay.build(mod, target=target, params=params)
             relay_model = graph_runtime.create(relay_graph, relay_lib, ctx)
             relay_model.set_input(**relay_params)
-            relay_model.set_input(**compiled_input)
+            for name, inp in compiled_input.items():
+                relay_model.set_input(name, inp)
             relay_model.run()
 
             for i, baseline_output in enumerate(baseline_outputs):
@@ -228,12 +224,11 @@ def test_forward_add():
                 ones = ones.cuda()
             return args[0] + ones
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Add1().float().eval(), input_data=input_data)
-        verify_model(Add2().float().eval(), input_data=input_data)
-        verify_model(Add3().float().eval(), input_data=input_data)
-        verify_model(Add4().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Add1().float().eval(), input_data=input_data)
+    verify_model(Add2().float().eval(), input_data=input_data)
+    verify_model(Add3().float().eval(), input_data=input_data)
+    verify_model(Add4().float().eval(), input_data=input_data)
 
 def test_forward_subtract():
     torch.set_grad_enabled(False)
@@ -261,12 +256,11 @@ def test_forward_subtract():
                 ones = ones.cuda()
             return args[0] - ones
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Subtract1().float().eval(), input_data=input_data)
-        verify_model(Subtract2().float().eval(), input_data=input_data)
-        verify_model(Subtract3().float().eval(), input_data=input_data)
-        verify_model(Subtract4().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Subtract1().float().eval(), input_data=input_data)
+    verify_model(Subtract2().float().eval(), input_data=input_data)
+    verify_model(Subtract3().float().eval(), input_data=input_data)
+    verify_model(Subtract4().float().eval(), input_data=input_data)
 
 def test_forward_multiply():
     torch.set_grad_enabled(False)
@@ -294,12 +288,11 @@ def test_forward_multiply():
                 ones = ones.cuda()
             return args[0] * ones
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Multiply1().float().eval(), input_data=input_data)
-        verify_model(Multiply2().float().eval(), input_data=input_data)
-        verify_model(Multiply3().float().eval(), input_data=input_data)
-        verify_model(Multiply4().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Multiply1().float().eval(), input_data=input_data)
+    verify_model(Multiply2().float().eval(), input_data=input_data)
+    verify_model(Multiply3().float().eval(), input_data=input_data)
+    verify_model(Multiply4().float().eval(), input_data=input_data)
 
 def test_forward_unsqueeze():
     torch.set_grad_enabled(False)
@@ -327,10 +320,9 @@ def test_forward_concatenate():
             c = (args[0][:, :, 2] + 5) * 13
             return torch.cat([t.unsqueeze(2) for t in [a, b, c]], 2)
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Concatenate1().float().eval(), input_data=input_data)
-        verify_model(Concatenate2().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Concatenate1().float().eval(), input_data=input_data)
+    verify_model(Concatenate2().float().eval(), input_data=input_data)
 
 def test_forward_relu():
     torch.set_grad_enabled(False)
@@ -340,9 +332,8 @@ def test_forward_relu():
         def forward(self, *args):
             return torch.nn.ReLU()(args[0])
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(ReLU1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(ReLU1().float().eval(), input_data=input_data)
 
 def test_forward_adaptiveavgpool():
     torch.set_grad_enabled(False)
@@ -356,10 +347,9 @@ def test_forward_adaptiveavgpool():
         def forward(self, *args):
             return torch.nn.AdaptiveAvgPool2d([10, 10])(args[0])
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(AdaptiveAvgPool2D1().float().eval(), input_data=input_data)
-        verify_model(AdaptiveAvgPool2D2().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(AdaptiveAvgPool2D1().float().eval(), input_data=input_data)
+    verify_model(AdaptiveAvgPool2D2().float().eval(), input_data=input_data)
 
 def test_forward_maxpool():
     torch.set_grad_enabled(False)
@@ -373,10 +363,9 @@ def test_forward_maxpool():
         def forward(self, *args):
             return torch.nn.MaxPool2d(kernel_size=[10, 10])(args[0])
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(MaxPool2D1().float().eval(), input_data=input_data)
-        verify_model(MaxPool2D2().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(MaxPool2D1().float().eval(), input_data=input_data)
+    verify_model(MaxPool2D2().float().eval(), input_data=input_data)
 
 def test_forward_avgpool():
     torch.set_grad_enabled(False)
@@ -386,9 +375,8 @@ def test_forward_avgpool():
         def forward(self, *args):
             return torch.nn.AvgPool2d(kernel_size=[10, 10])(args[0])
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(AvgPool2D1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(AvgPool2D1().float().eval(), input_data=input_data)
 
 def test_forward_hardtanh():
     torch.set_grad_enabled(False)
@@ -398,9 +386,8 @@ def test_forward_hardtanh():
         def forward(self, *args):
             return torch.nn.Hardtanh()(args[0])
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(HardTanh1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(HardTanh1().float().eval(), input_data=input_data)
 
 def test_forward_conv():
     torch.set_grad_enabled(False)
@@ -433,11 +420,10 @@ def test_forward_conv():
         def forward(self, *args):
             return self.softmax(self.conv(args[0]))
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Conv2D1().float().eval(), input_data=input_data)
-        verify_model(Conv2D2().float().eval(), input_data=input_data)
-        verify_model(Conv2D3().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Conv2D1().float().eval(), input_data=input_data)
+    verify_model(Conv2D2().float().eval(), input_data=input_data)
+    verify_model(Conv2D3().float().eval(), input_data=input_data)
 
 def test_forward_threshold():
     torch.set_grad_enabled(False)
@@ -447,9 +433,8 @@ def test_forward_threshold():
         def forward(self, *args):
             return torch.nn.Threshold(0, 0)(args[0])
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Threshold1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Threshold1().float().eval(), input_data=input_data)
 
 def test_forward_contiguous():
     torch.set_grad_enabled(False)
@@ -459,9 +444,8 @@ def test_forward_contiguous():
         def forward(self, *args):
             return args[0].contiguous()
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Contiguous1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Contiguous1().float().eval(), input_data=input_data)
 
 def test_forward_batchnorm():
     torch.set_grad_enabled(False)
@@ -481,10 +465,9 @@ def test_forward_batchnorm():
         def forward(self, *args):
             return self.batch_norm(args[0])
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(BatchNorm1().float().eval(), input_data=input_data)
-        verify_model(BatchNorm2().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(BatchNorm1().float().eval(), input_data=input_data)
+    verify_model(BatchNorm2().float().eval(), input_data=input_data)
 
 def test_forward_transpose():
     torch.set_grad_enabled(False)
@@ -498,10 +481,9 @@ def test_forward_transpose():
         def forward(self, *args):
             return args[0].transpose(-2, -1)
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Transpose1().float().eval(), input_data=input_data)
-        verify_model(Transpose2().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Transpose1().float().eval(), input_data=input_data)
+    verify_model(Transpose2().float().eval(), input_data=input_data)
 
 def test_forward_size():
     torch.set_grad_enabled(False)
@@ -511,9 +493,8 @@ def test_forward_size():
         def forward(self, *args):
             return float(args[0].size(0)) * args[0]
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Size1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Size1().float().eval(), input_data=input_data)
 
 def test_forward_view():
     torch.set_grad_enabled(False)
@@ -527,10 +508,9 @@ def test_forward_view():
         def forward(self, *args):
             return args[0].view(args[0].shape[0], -1)
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(View1().float().eval(), input_data=input_data)
-        verify_model(View2().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(View1().float().eval(), input_data=input_data)
+    verify_model(View2().float().eval(), input_data=input_data)
 
 def test_forward_select():
     torch.set_grad_enabled(False)
@@ -540,9 +520,8 @@ def test_forward_select():
         def forward(self, *args):
             return args[0].select(1, 1)
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Select1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Select1().float().eval(), input_data=input_data)
 
 def test_forward_clone():
     torch.set_grad_enabled(False)
@@ -552,9 +531,8 @@ def test_forward_clone():
         def forward(self, *args):
             return args[0].clone()
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Clone1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Clone1().float().eval(), input_data=input_data)
 
 def test_forward_logsoftmax():
     torch.set_grad_enabled(False)
@@ -564,9 +542,8 @@ def test_forward_logsoftmax():
         def forward(self, *args):
             return torch.nn.LogSoftmax(dim=1)(args[0][0, 0])
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(LogSoftmax1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(LogSoftmax1().float().eval(), input_data=input_data)
 
 def test_forward_sigmoid():
     torch.set_grad_enabled(False)
@@ -576,9 +553,8 @@ def test_forward_sigmoid():
         def forward(self, *args):
             return torch.nn.Sigmoid()(args[0])
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Sigmoid1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Sigmoid1().float().eval(), input_data=input_data)
 
 def test_forward_dense():
     torch.set_grad_enabled(False)
@@ -598,10 +574,9 @@ def test_forward_dense():
         def forward(self, *args):
             return self.linear(args[0][0, 0])
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Dense1().float().eval(), input_data=input_data)
-        verify_model(Dense2().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Dense1().float().eval(), input_data=input_data)
+    verify_model(Dense2().float().eval(), input_data=input_data)
 
 def test_forward_dropout():
     torch.set_grad_enabled(False)
@@ -611,9 +586,8 @@ def test_forward_dropout():
         def forward(self, *args):
             return torch.nn.functional.dropout(args[0][0, 0], 0.5, False)
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Dropout1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Dropout1().float().eval(), input_data=input_data)
 
 def test_forward_slice():
     torch.set_grad_enabled(False)
@@ -627,10 +601,9 @@ def test_forward_slice():
         def forward(self, *args):
             return args[0][0, :, :, :]
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Slice1().float().eval(), input_data=input_data)
-        verify_model(Slice2().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Slice1().float().eval(), input_data=input_data)
+    verify_model(Slice2().float().eval(), input_data=input_data)
 
 def test_forward_mean():
     torch.set_grad_enabled(False)
@@ -640,9 +613,8 @@ def test_forward_mean():
         def forward(self, *args):
             return args[0].mean(2)
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Mean1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Mean1().float().eval(), input_data=input_data)
 
 def test_forward_expand():
     torch.set_grad_enabled(False)
@@ -652,9 +624,8 @@ def test_forward_expand():
         def forward(self, *args):
             return args[0].expand((3, -1, -1, -1))
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Expand1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Expand1().float().eval(), input_data=input_data)
 
 def test_forward_pow():
     torch.set_grad_enabled(False)
@@ -664,9 +635,8 @@ def test_forward_pow():
         def forward(self, *args):
             return args[0] ** 2
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Pow1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Pow1().float().eval(), input_data=input_data)
 
 def test_forward_chunk():
     torch.set_grad_enabled(False)
@@ -677,9 +647,61 @@ def test_forward_chunk():
             chunks = args[0].chunk(7, 2)
             return torch.cat(chunks, 2)
 
-    with torch.no_grad():
-        input_data = torch.rand(input_shape).float()
-        verify_model(Chunk1().float().eval(), input_data=input_data)
+    input_data = torch.rand(input_shape).float()
+    verify_model(Chunk1().float().eval(), input_data=input_data)
+
+def test_upsample():
+    class Upsample(Module):
+        def __init__(self, size=None, scale=None,
+                     mode="nearest", align_corners=None):
+            super().__init__()
+            self.size = size
+            self.scale = scale
+            self.mode = mode
+            self.align_corners = align_corners
+
+        def forward(self, x):
+            return torch.nn.functional.interpolate(x, size=self.size,
+                                                   scale_factor=self.scale,
+                                                   mode=self.mode,
+                                                   align_corners=self.align_corners)
+    inp = torch.rand((1, 3, 32, 32))
+    verify_model(Upsample(size=(64, 64), mode="nearest"), inp)
+    verify_model(Upsample(scale=2, mode="nearest"), inp)
+    verify_model(Upsample(size=(50, 50), mode="nearest"), inp)
+    verify_model(Upsample(size=(64, 64), mode="bilinear", align_corners=True), inp)
+    verify_model(Upsample(scale=2, mode="bilinear", align_corners=True), inp)
+    verify_model(Upsample(size=(50, 50), mode="bilinear", align_corners=True), inp)
+
+def test_to():
+    """ test for aten::to(...) """
+    class ToCPU(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return x.to("cpu")
+
+    class ToFloat(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return x.float()
+
+    class ToInt(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return x.int()
+
+    verify_model(ToCPU().eval(), torch.rand((1, 3, 32, 32)))
+    verify_model(ToFloat().eval(), torch.zeros((1, 3, 32, 32), dtype=torch.int))
+    verify_model(ToFloat().eval(), torch.tensor(2, dtype=torch.int))
+    verify_model(ToInt().eval(), torch.zeros((1, 3, 32, 32)))
+    verify_model(ToInt().eval(), torch.tensor(2.0))
+
 
 # Model tests
 def test_resnet18():
@@ -730,6 +752,57 @@ def test_vgg11_bn():
 """
 
 
+def test_custom_conversion_map():
+    def get_roi_align():
+        pool_size = 5
+        n_channels = 2 * (pool_size ** 2)
+        x = torch.rand(2, n_channels, 10, 10)
+        rois = torch.tensor([[0, 0, 0, 9, 9],  # format is (xyxy)
+                             [0, 0, 5, 4, 9],
+                             [0, 5, 5, 9, 9],
+                             [1, 0, 0, 9, 9]], dtype=torch.float)
+        roi_align = torchvision.ops.RoIAlign(pool_size, spatial_scale=1,
+                                             sampling_ratio=-1)
+        return roi_align.eval(), [x, rois]
+
+    def convert_roi_align():
+        def _impl(inputs, input_types):
+            spatial_scale = inputs[2]
+            pooled_size = (inputs[3], inputs[4])
+            sampling_ratio = inputs[5]
+            return relay.op.vision.roi_align(inputs[0], inputs[1],
+                                             pooled_size, spatial_scale,
+                                             sampling_ratio)
+        return _impl
+
+    custom_map = {'torchvision::roi_align': convert_roi_align()}
+    model, inputs = get_roi_align()
+
+    verify_model(model, inputs, custom_map)
+
+
+def test_segmentaton_models():
+    class SegmentationModelWrapper(Module):
+        def __init__(self, model):
+            super().__init__()
+            self.model = model
+
+        def forward(self, inp):
+            out = self.model(inp)
+            return out["out"]
+
+    fcn = torchvision.models.segmentation.fcn_resnet101(pretrained=True)
+    deeplab = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
+
+    inp = [torch.rand((1, 3, 300, 300), dtype=torch.float)]
+
+    for model in [fcn, deeplab]:
+        # depthwise + dilated covolution not supported on x86
+        # see https://github.com/apache/incubator-tvm/issues/4962
+        verify_model(SegmentationModelWrapper(model.eval()), inp,
+                     ctx_list=[("cuda", tvm.gpu(0))])
+
+
 if __name__ == "__main__":
     # Single operator tests
     test_forward_add()
@@ -760,6 +833,8 @@ if __name__ == "__main__":
     test_forward_expand()
     test_forward_pow()
     test_forward_chunk()
+    test_upsample()
+    test_to()
 
     # Model tests
     test_resnet18()
@@ -770,3 +845,7 @@ if __name__ == "__main__":
     test_googlenet()
     test_mnasnet0_5()
     test_mobilenet_v2()
+
+    test_custom_conversion_map()
+
+    test_segmentaton_models()
index 503f64a..1c568ce 100644 (file)
@@ -37,7 +37,7 @@ https://pytorch.org/get-started/locally/
 PyTorch versions should be backwards compatible but should be used
 with the proper TorchVision version.
 
-Currently, TVM supports PyTorch 1.4, 1.3, and 1.2. Other versions may
+Currently, TVM supports PyTorch 1.4 and 1.3. Other versions may
 be unstable.
 """