[Frontend][Pytorch] Add PyTorch advanced indexing (#6318)
authorYao Wang <kevinthesunwy@gmail.com>
Sat, 22 Aug 2020 01:54:26 +0000 (18:54 -0700)
committerGitHub <noreply@github.com>
Sat, 22 Aug 2020 01:54:26 +0000 (10:54 +0900)
* Add PyTorch advanced indexing

* Minor fix for test

* Fix for cuda
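
For reference, the new converter handles PyTorch's advanced (integer-array) indexing, which the JIT emits as aten::index. Below is a minimal sketch of how such a model would be brought through the Relay PyTorch frontend; the module name, input name "x", shape, and "llvm" target are illustrative assumptions rather than part of this patch:

import torch
import tvm
from tvm import relay

class Index(torch.nn.Module):
    def forward(self, x):
        # Advanced indexing with index lists on the first two dims.
        return x[[0, 1], [0, 2]]

inp = torch.rand(3, 4, 5, 6)
scripted = torch.jit.trace(Index().eval(), inp)
mod, params = relay.frontend.from_pytorch(scripted, [("x", (3, 4, 5, 6))])
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)

The test additions below exercise the same path through verify_model, which also compares the TVM output against PyTorch.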

python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index b75f3f9..7237403 100644
@@ -274,16 +274,18 @@ def _slice():
             end[dim] = min(end[dim], int(inputs[3]))
         else:
             if isinstance(inputs[3], _expr.Call):
-                end[dim] = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
+                target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
             else:
-                end[dim] = inputs[3]
+                target_end = inputs[3]
+
+            end[dim] = min(end[dim], target_end)
 
         strides.append(int(inputs[4]))
         return _op.transform.strided_slice(data,
                                            begin=_expr.const(begin),
                                            end=_expr.const(end),
                                            strides=_expr.const(strides),
-                                           slice_mode="size")
+                                           slice_mode="end")
     return _impl
 
 def _split():
@@ -1759,6 +1761,50 @@ def _one_hot():
     return _impl
 
 
+def _index():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        indices = []
+        raw_indices = []
+        max_indices_len = -1
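+        # Collect the index tensors as constants (folding non-constant ones with
+        # _infer_value) and record the length of the longest index list.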
+        for index in inputs[1]:
+            if not isinstance(index, _expr.Constant):
+                try:
+                    index = _expr.const(_infer_value(index, {}))
+                except Exception:
+                    raise RuntimeError("Only supports constant indices for "
+                                       "pytorch advanced indexing ")
+            raw_indices.append(index)
+            cindex_len = index.data.shape[0]
+            if cindex_len > max_indices_len:
+                max_indices_len = cindex_len
+
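+        # Tile shorter index lists so every index has the same broadcast length.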
+        for index in raw_indices:
+            cnp = index.data.asnumpy()
+            cindex_len = cnp.shape[0]
+            if cindex_len < max_indices_len:
+                cnp = np.tile(cnp, max_indices_len // cindex_len)
+            indices.append(cnp)
+
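+        # For each output position, gather along successive axes with take();
+        # slice_map caches shared index prefixes so intermediate takes are reused.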
+        ret = []
+        slice_map = {}
+        for i in range(indices[0].shape[0]):
+            tmp = data
+            current_indices = []
+            for index in indices:
+                current_indices.append(index[i])
+                index_key = tuple(current_indices)
+                if index_key in slice_map:
+                    tmp = slice_map[index_key]
+                else:
+                    tmp = _op.take(tmp, _expr.const(index[i]), axis=0)
+                    slice_map[index_key] = tmp
+            ret.append(_op.expand_dims(tmp, axis=0))
+
+        return _op.concatenate(ret, axis=0)
+    return _impl
+
+
 def _meshgrid():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -2064,6 +2110,7 @@ def _get_convert_map(prelude):
         "aten::type_as"                         : _type_as(),
         "aten::gather"                          : _gather(),
         "aten::index_select"                    : _select(),
+        "aten::index"                           : _index(),
     }
     return convert_map
 
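The converter above unrolls the gather at compile time: every index must be a constant, shorter index lists are tiled up to the longest one, and each output position is produced by chaining take() along axis 0. A rough NumPy reference of that decomposition (index_reference is a hypothetical helper for illustration, not part of the patch):

import numpy as np

def index_reference(data, index_lists):
    # Tile shorter index lists to the longest length (assumes it divides evenly).
    longest = max(len(idx) for idx in index_lists)
    tiled = [np.tile(idx, longest // len(idx)) for idx in index_lists]
    out = []
    for i in range(longest):
        tmp = data
        # Chain take() along axis 0, one indexed dimension at a time.
        for idx in tiled:
            tmp = np.take(tmp, idx[i], axis=0)
        out.append(np.expand_dims(tmp, axis=0))
    return np.concatenate(out, axis=0)

x = np.random.rand(3, 4, 5, 6)
assert np.allclose(index_reference(x, [[0, 1], [0, 2]]), x[[0, 1], [0, 2]])
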
index e5c9634..ab0a4b0 100644
@@ -1202,13 +1202,13 @@ def test_forward_slice():
 
     class Slice2(Module):
         def forward(self, *args):
-            return args[0][0, :, :, :]
+            return args[0][0, :, :-3, :]
 
     class Slice3(Module):
         def forward(self, *args):
             x0 = torch.tensor(2) - torch.tensor(1)
             x1 = torch.tensor(3) + torch.tensor(1)
-            return args[0][:, x0:, :x1, :]
+            return args[0][:, x0:, 1:x1, :]
 
     input_data = torch.rand(input_shape).float()
     verify_model(Slice1().float().eval(), input_data=input_data)
@@ -2620,6 +2620,25 @@ def test_forward_matmul():
     verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
 
 
+def test_forward_index():
+    torch.set_grad_enabled(False)
+    input_shape = [3, 4, 5, 6]
+
+    class Index0(Module):
+        def forward(self, x):
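+            # index lists on the first two dims, combined with a slice and a scalar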
+            return x[[0, 1], [0, 2], :2, 4]
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Index0().eval(), input_data=input_data)
+
+    class Index1(Module):
+        def forward(self, x):
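+            # index lists on every dim; the length-1 list broadcasts to length 4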
+            return x[[0], [1, 2, 3, 0], [3, 1, 2, 2], [4, 2, 1, 0]]
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Index1().eval(), input_data=input_data)
+
+
 def test_forward_pretrained_bert_base_uncased():
     ######################################################################
     # This is an example how to run BERT models using TVM
@@ -2859,6 +2878,7 @@ if __name__ == "__main__":
     test_adaptive_pool3d()
     test_conv3d()
     test_conv3d_transpose()
+    test_forward_index()
 
     # Model tests
     test_resnet18()