return get_relay_op(name)(data0, data1)
return _impl
+def _abs():
+ def _impl(inputs, input_types):
+ data = inputs[0]
+ return _op.abs(data)
+ return _impl
+
+def _arange():
+ def _impl(inputs, input_types):
+ if len(inputs) == 5:
+ dtype = "float" if "float" in input_types[0:1] else _convert_dtype_value(inputs[1])
+ start = _create_typed_const(0, dtype)
+ stop = _create_typed_const(inputs[0], dtype)
+ step = _create_typed_const(1, dtype)
+ elif len(inputs) == 7:
+ dtype = "float" if "float" in input_types[0:3] else _convert_dtype_value(inputs[3])
+ start = _create_typed_const(inputs[0], dtype)
+ stop = _create_typed_const(inputs[1], dtype)
+ step = _create_typed_const(inputs[2], dtype)
+ else:
+ msg = "Unknown number of arguments (%d) to parse." % (len(inputs))
+ raise AssertionError(msg)
+ return _op.transform.arange(start=start,
+ stop=stop,
+ step=step,
+ dtype=_convert_data_type(dtype))
+ return _impl
+
def _squeeze():
def _impl(inputs, input_types):
data = inputs[0]
return _op.tensor.sigmoid(data)
return _impl
+def _softplus():
+ def _impl(inputs, input_types):
+ data = inputs[0]
+ beta = _expr.const(float(inputs[1]))
+ return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.)) / beta
+ return _impl
+
def _avg_pool2d():
def _impl(inputs, input_types):
data = inputs[0]
return _impl
# Helper functions for operator implementation
+def _convert_dtype_value(val):
+ convert_torch_dtype_map = {7:"torch.float64",
+ 6:"torch.float32",
+ 5:"torch.float16",
+ 4:"torch.int64",
+ 3:"torch.int32",
+ 2:"torch.int16",
+ 1:"torch.int8",
+ 0:"torch.unit8",
+ None:"torch.int64"} # Default is torch.int64
+ if val in convert_torch_dtype_map:
+ return convert_torch_dtype_map[val]
+ else:
+ msg = "Torch data type value %d is not handled yet." % (val)
+ raise NotImplementedError(msg)
def _convert_data_type(input_type):
if input_type in ["double", "torch.float64"]:
"aten::pow" : _elemwise("power"),
"aten::div" : _elemwise("divide"),
"aten::div_" : _elemwise("divide"),
+ "aten::abs" : _abs(),
+ "aten::arange" : _arange(),
"aten::ones" : _ones(),
"aten::zeros" : _zeros(),
"aten::reciprocal" : _reciprocal(),
"aten::clone" : _clone(),
"aten::log_softmax" : _log_softmax(),
"aten::sigmoid" : _sigmoid(),
+ "aten::softplus" : _softplus(),
"aten::avg_pool2d" : _avg_pool2d(),
"aten::avg_pool3d" : _avg_pool3d(),
"aten::dropout" : _dropout(),
verify_model(Squeeze1().float().eval(), input_data=input_data)
verify_model(Squeeze2().float().eval(), input_data=input_data)
+def test_forward_arange():
+ torch.set_grad_enabled(False)
+
+ class Arange1(Module):
+ def forward(self, *args):
+ return torch.arange(5)
+ class Arange2(Module):
+ def forward(self, *args):
+ return torch.arange(2.5)
+ class Arange3(Module):
+ def forward(self, *args):
+ return torch.arange(1, 4)
+ class Arange4(Module):
+ def forward(self, *args):
+ return torch.arange(1, 2.5, 0.5)
+ class Arange5(Module):
+ def forward(self, *args):
+ return torch.arange(1, 2, 1, dtype=torch.int32)
+ class Arange6(Module):
+ def forward(self, *args):
+ return torch.arange(start=1, end=6, step=2)
+ class Arange7(Module):
+ def forward(self, *args):
+ return torch.arange(1, 4, dtype=torch.float32)
+ class Arange8(Module):
+ def forward(self, *args):
+ return torch.arange(1, 2, 1, dtype=torch.int16)
+
+ verify_model(Arange1().float().eval())
+ verify_model(Arange2().float().eval())
+ verify_model(Arange3().float().eval())
+ verify_model(Arange4().float().eval())
+ verify_model(Arange5().float().eval())
+ verify_model(Arange6().float().eval())
+ verify_model(Arange7().float().eval())
+ verify_model(Arange8().float().eval())
+
+def test_forward_abs():
+ torch.set_grad_enabled(False)
+ input_shape = [2, 1, 10, 1, 10]
+
+ class Abs1(Module):
+ def forward(self, *args):
+ return args[0].abs()
+
+ input_data = torch.rand(input_shape).float()
+ verify_model(Abs1().float().eval(), input_data=input_data)
+
def test_forward_concatenate():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.SELU().eval(), input_data=input_data)
+def test_forward_softplus():
+ torch.set_grad_enabled(False)
+ input_shape = [1, 3, 10, 10]
+ input_data = torch.rand(input_shape).float()
+ verify_model(torch.nn.Softplus().eval(), input_data=input_data)
+ verify_model(torch.nn.Softplus(beta=1.5, threshold=20).eval(), input_data=input_data)
+ verify_model(torch.nn.Softplus(beta=5, threshold=10).eval(), input_data=input_data)
+
+def test_forward_softsign():
+ torch.set_grad_enabled(False)
+ input_shape = [1, 3, 10, 10]
+ input_data = torch.rand(input_shape).float()
+ verify_model(torch.nn.Softsign().eval(), input_data=input_data)
+
def test_forward_log_sigmoid():
torch.set_grad_enabled(False)
input_shape = [10, 10]
test_forward_view()
test_forward_select()
test_forward_clone()
+ test_forward_softplus()
+ test_forward_softsign()
test_forward_logsoftmax()
test_forward_sigmoid()
test_forward_dense()
test_forward_mean()
test_forward_expand()
test_forward_pow()
+ test_forward_abs()
+ test_forward_arange()
test_forward_chunk()
test_forward_split()
test_upsample()