        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.full(x.shape, 2), x)

+    def test_full_like(self):
+        x = torch.randn(3, 4, requires_grad=True)
+        self.assertONNX(lambda x: torch.full_like(x, 2), x)
+
    def test_max(self):
        x = torch.randn(3, 4, requires_grad=True)
        y = torch.randn(3, 4, requires_grad=True)
@@ ... @@
        x = torch.randn(3, 4)
        self.assertONNX(torch.nn.Linear(4, 5, bias=True), x)

+    def test_zeros_like(self):
+        x = torch.randn(5, 8, requires_grad=True)
+        self.assertONNX(lambda x: torch.zeros_like(x), x)
+
+    def test_ones_like(self):
+        x = torch.randn(6, 10, requires_grad=True)
+        self.assertONNX(lambda x: torch.ones_like(x), x)
+

if __name__ == '__main__':
    no_onnx_dep_flag = '--no-onnx'
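
For reference, a minimal standalone sketch of what the new tests exercise, outside the assertONNX harness (assuming the onnx package is installed; the FullLike wrapper and the expected op list are illustrative, not part of this diff):

import io

import onnx
import torch

class FullLike(torch.nn.Module):
    # Stand-in for the lambda in the test above; torch.onnx.export takes a Module.
    def forward(self, x):
        return torch.full_like(x, 2)

f = io.BytesIO()
torch.onnx.export(FullLike(), (torch.randn(3, 4),), f)
model = onnx.load_model_from_string(f.getvalue())
# With the symbolics below, this should print a single ConstantLike node.
print([node.op_type for node in model.graph.node])

The symbolic-registry side of the change follows (the neighboring zeros/ones/full symbolics are shown as context for comparison).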
return g.op("ConstantFill", sizes, dtype_i=scalar_type_to_onnx[dtype], input_as_shape_i=1, value_f=0)
-def zeros_like(g, input):
- return g.op("Sub", input, input).setType(input.type().contiguous())
+@parse_args('v', 'i', 'v', 'v')
+def zeros_like(g, input, dtype, layout, device):
+ return g.op("ConstantLike", input, dtype_i=scalar_type_to_onnx[dtype], value_f=0.0)
@parse_args('v', 'i', 'v', 'v')
return g.op("ConstantFill", sizes, dtype_i=scalar_type_to_onnx[dtype], input_as_shape_i=1, value_f=1)
+@parse_args('v', 'i', 'v', 'v')
+def ones_like(g, input, dtype, layout, device):
+ return g.op("ConstantLike", input, dtype_i=scalar_type_to_onnx[dtype], value_f=1.0)
+
+
def full(g, sizes, value, dtype, layout, device):
    const_value = _maybe_get_const(value, 't')
    if _is_value(const_value):
        ...
    else:
        ...
        return g.op("ConstantFill", sizes, dtype_i=scalar_type_to_onnx[dtype], input_as_shape_i=1, value_f=const_value)


-def full_like(g, input, fill_value):
-    # TODO: a more efficient implementation (ConstantFill?)
-    return add(g, zeros_like(g, input), fill_value, g.op("Constant", value_t=torch.tensor(1)))
+@parse_args('v', 'f', 'i', 'v', 'v')
+def full_like(g, input, fill_value, dtype, layout, device):
+    return g.op("ConstantLike", input, dtype_i=scalar_type_to_onnx[dtype], value_f=fill_value)


@parse_args('v', 'v', 'v', 'v', 'i')
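
A note on the g.op calls above: keyword attributes carry a type suffix (_i for int, _f for float, _t for tensor) that the exporter strips to form the ONNX attribute name, so value_f=0.0 becomes a float attribute named value. Below is a rough sketch of the node zeros_like now emits, built by hand with the onnx helper API (input/output names are illustrative; ConstantLike was an experimental ONNX op at the time):

from onnx import TensorProto, helper

# Approximately what g.op("ConstantLike", input, dtype_i=..., value_f=0.0)
# lowers to: one node whose dtype attribute is an ONNX TensorProto enum
# (scalar_type_to_onnx[dtype] resolves to such a value) and whose value
# attribute holds the fill constant.
node = helper.make_node(
    "ConstantLike",
    inputs=["input"],
    outputs=["output"],
    dtype=TensorProto.FLOAT,
    value=0.0,
)
print(node)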