return _op.transform.take(data, index, axis=dim)
return _impl
+def _reciprocal():
+ def _impl(inputs, input_types):
+ data = inputs[0]
+ return _expr.const(1.0) / data
+ return _impl
+
+def _repeat():
+ def _impl(inputs, input_types):
+ data = inputs[0]
+ reps = _get_dims(inputs[1])
+ return _op.transform.tile(data, reps=reps)
+ return _impl
+
+def _repeat_interleave():
+ def _impl(inputs, input_types):
+ data = inputs[0]
+ if isinstance(inputs[1], int):
+ repeats = inputs[1]
+ axis = inputs[2]
+ else:
+ msg = "Only repeat with one value as repeat is currently supported."
+ raise AssertionError(msg)
+ if axis is None: # Flatten the data if no axis is given from torch
+ data = _op.transform.reshape(data, [-1])
+ axis = 0
+ return _op.transform.repeat(data, repeats=repeats, axis=axis)
+ return _impl
+
def _ones():
def _impl(inputs, input_types):
data = inputs[0]
return _op.transform.reshape(data, new_shape)
return _impl
+def _reshape():
+ def _impl(inputs, input_types):
+ data = inputs[0]
+ if isinstance(inputs[1], list):
+ new_shape = inputs[1]
+ else:
+ new_shape = _infer_shape(inputs[1])
+ return _op.transform.reshape(data, new_shape)
+ return _impl
+
def _clone():
def _impl(inputs, input_types):
data = inputs[0]
"aten::div_" : _elemwise("divide"),
"aten::ones" : _ones(),
"aten::zeros" : _zeros(),
+ "aten::reciprocal" : _reciprocal(),
+ "aten::repeat" : _repeat(),
+ "aten::repeat_interleave" : _repeat_interleave(),
"aten::to" : _to(),
"aten::squeeze" : _squeeze(),
"aten::unsqueeze" : _unsqueeze(),
"aten::addmm" : _dense(),
"aten::size" : _size(),
"aten::view" : _view(),
+ "aten::reshape" : _reshape(),
"aten::clone" : _clone(),
"aten::log_softmax" : _log_softmax(),
"aten::sigmoid" : _sigmoid(),
verify_model(Multiply3().float().eval(), input_data=input_data)
verify_model(Multiply4().float().eval(), input_data=input_data)
+def test_forward_reciprocal():
+ torch.set_grad_enabled(False)
+ input_shape = [2, 1, 10, 1, 10]
+ class Reciprocal1(Module):
+ def forward(self, *args):
+ return args[0].reciprocal()
+
+ input_data = torch.rand(input_shape).float()
+ verify_model(Reciprocal1().float().eval(), input_data=input_data)
+
+def test_forward_repeat():
+ torch.set_grad_enabled(False)
+ input_shape = [1, 3]
+ class Repeat1(Module):
+ def forward(self, *args):
+ return args[0].repeat(1, 1)
+
+ class Repeat2(Module):
+ def forward(self, *args):
+ return args[0].repeat(4, 2)
+
+ class Repeat3(Module):
+ def forward(self, *args):
+ return args[0].repeat(4, 2, 1)
+
+ input_data = torch.rand(input_shape).float()
+ verify_model(Repeat1().float().eval(), input_data=input_data)
+ verify_model(Repeat2().float().eval(), input_data=input_data)
+ verify_model(Repeat3().float().eval(), input_data=input_data)
+
+def test_forward_repeat_interleave():
+ torch.set_grad_enabled(False)
+ input_shape = [2, 2, 3]
+ class RepeatInterleave1(Module):
+ def forward(self, *args):
+ return args[0].repeat_interleave(2)
+
+ class RepeatInterleave2(Module):
+ def forward(self, *args):
+ return args[0].repeat_interleave(3, dim=0)
+
+ class RepeatInterleave3(Module):
+ def forward(self, *args):
+ return args[0].repeat_interleave(2, dim=1)
+
+ class RepeatInterleave4(Module):
+ def forward(self, *args):
+ return args[0].repeat_interleave(4, dim=2)
+
+ input_data = torch.rand(input_shape).float()
+ verify_model(RepeatInterleave1().float().eval(), input_data=input_data)
+ verify_model(RepeatInterleave2().float().eval(), input_data=input_data)
+ verify_model(RepeatInterleave3().float().eval(), input_data=input_data)
+ verify_model(RepeatInterleave4().float().eval(), input_data=input_data)
+
def test_forward_unsqueeze():
torch.set_grad_enabled(False)
input_shape = [10, 10]
init_weight(ln.eval())
verify_model(ln.eval(), input_data=inp)
+def test_forward_reshape():
+ torch.set_grad_enabled(False)
+ input_shape = [2, 1, 10, 1, 10]
+ new_shape = [2, 1, 10, 10]
+ class Reshape1(Module):
+ def forward(self, *args):
+ return args[0].reshape(new_shape)
+
+ class Reshape2(Module):
+ def forward(self, *args):
+ return args[0].reshape([-1])
+
+ input_data = torch.rand(input_shape).float()
+ verify_model(Reshape1().float().eval(), input_data=input_data)
+ verify_model(Reshape2().float().eval(), input_data=input_data)
+
def test_forward_transpose():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
test_forward_add()
test_forward_subtract()
test_forward_multiply()
+ test_forward_reshape()
+ test_forward_reciprocal()
+ test_forward_repeat()
+ test_forward_repeat_interleave()
test_forward_squeeze()
test_forward_unsqueeze()
test_forward_concatenate()