end[dim] = min(end[dim], int(inputs[3]))
else:
if isinstance(inputs[3], _expr.Call):
- end[dim] = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
+ target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
else:
- end[dim] = inputs[3]
+ target_end = inputs[3]
+
+ end[dim] = min(end[dim], target_end)
strides.append(int(inputs[4]))
return _op.transform.strided_slice(data,
begin=_expr.const(begin),
end=_expr.const(end),
strides=_expr.const(strides),
- slice_mode="size")
+ slice_mode="end")
return _impl
def _split():
return _impl
def _index():
    """Convert ``aten::index`` (PyTorch advanced indexing with constant
    index tensors) into a concatenation of per-row ``take`` slices."""
    def _impl(inputs, input_types):
        data = inputs[0]

        # Force every index expression to a Relay constant; only constant
        # advanced indices are supported by this conversion.
        const_indices = []
        for index in inputs[1]:
            if not isinstance(index, _expr.Constant):
                try:
                    index = _expr.const(_infer_value(index, {}))
                except Exception:
                    raise RuntimeError("Only supports constant indices for "
                                       "pytorch advanced indexing ")
            const_indices.append(index)

        # Shorter index tensors are tiled up to the longest one, mirroring
        # PyTorch's broadcast across parallel index lists.
        # NOTE(review): assumes each index tensor is 1-D and that every
        # shorter length divides the longest — TODO confirm against callers.
        longest = max((idx.data.shape[0] for idx in const_indices), default=-1)
        indices = []
        for idx in const_indices:
            arr = idx.data.asnumpy()
            if arr.shape[0] < longest:
                arr = np.tile(arr, longest // arr.shape[0])
            indices.append(arr)

        # Walk the index tuples one output row at a time, memoizing partial
        # takes so shared index prefixes are only sliced once.
        cache = {}
        rows = []
        for i in range(indices[0].shape[0]):
            sliced = data
            prefix = []
            for idx in indices:
                prefix.append(idx[i])
                key = tuple(prefix)
                if key in cache:
                    sliced = cache[key]
                else:
                    sliced = _op.take(sliced, _expr.const(idx[i]), axis=0)
                    cache[key] = sliced
            rows.append(_op.expand_dims(sliced, axis=0))

        return _op.concatenate(rows, axis=0)
    return _impl
+
+
def _meshgrid():
def _impl(inputs, input_types):
data = inputs[0]
"aten::type_as" : _type_as(),
"aten::gather" : _gather(),
"aten::index_select" : _select(),
+ "aten::index" : _index(),
}
return convert_map
class Slice2(Module):
    """Slice test: fixes the batch dim to 0 and uses a negative end (:-3)
    on the third axis, exercising end clamping in the frontend."""

    def forward(self, *args):
        inp = args[0]
        return inp[0, :, :-3, :]
class Slice3(Module):
    """Slice test with tensor-valued (dynamic) bounds: dim 1 from index 1,
    dim 2 over the half-open range [1, 4)."""

    def forward(self, *args):
        begin = torch.tensor(2) - torch.tensor(1)  # evaluates to 1
        end = torch.tensor(3) + torch.tensor(1)    # evaluates to 4
        return args[0][:, begin:, 1:end, :]
input_data = torch.rand(input_shape).float()
verify_model(Slice1().float().eval(), input_data=input_data)
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
def test_forward_index():
    """Test conversion of aten::index (constant advanced indexing)."""
    torch.set_grad_enabled(False)
    ishape = [3, 4, 5, 6]

    class Index0(Module):
        # Mixed indexing: two index lists, then a basic slice and an int.
        def forward(self, x):
            return x[[0, 1], [0, 2], :2, 4]

    class Index1(Module):
        # Four index lists; the length-1 list broadcasts against length 4.
        def forward(self, x):
            return x[[0], [1, 2, 3, 0], [3, 1, 2, 2], [4, 2, 1, 0]]

    for model in (Index0(), Index1()):
        verify_model(model.eval(), input_data=torch.rand(ishape).float())
+
+
def test_forward_pretrained_bert_base_uncased():
######################################################################
# This is an example how to run BERT models using TVM
test_adaptive_pool3d()
test_conv3d()
test_conv3d_transpose()
+ test_forward_index()
# Model tests
test_resnet18()