return list(generator())
+def sample_inputs_conv_transpose2d(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Each case is ordered as (input shape, weight shape, bias shape), followed by
+    # a dict of keyword values for (stride, padding, output_padding, groups, dilation).
+    # The commented hand check after this function verifies the first case.
+    cases: Tuple[Tuple[int], Tuple[int], Tuple[int], dict] = (  # type: ignore[assignment]
+        ((1, 3, 4, 4), (3, 3, 3, 3), (3,),
+         {'stride': (2, 2), 'padding': 2, 'output_padding': (1, 1), 'groups': 1}),
+        ((2, 2, 4, 4), (2, 2, 4, 5), (4,),
+         {'stride': (3, 2), 'padding': (1, 2), 'output_padding': (2, 3), 'groups': 2, 'dilation': (4, 4)}),
+        ((1, 1, 4, 5), (1, 1, 4, 3), (1,),
+         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2, 3)}),
+        ((1, 1, 4, 3), (1, 2, 3, 4), None,
+         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}),
+        ((1, 4, 5, 5), (4, 8, 3, 3), None,
+         {})
+    )
+
+    def generator():
+        for input_shape, weight, bias, kwargs in cases:
+            yield SampleInput(make_arg(input_shape), args=(
+                make_arg(weight),
+                make_arg(bias) if bias is not None else bias
+            ), kwargs=kwargs)
+
+    return list(generator())
+
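+# Hand check for the first case above (a sketch for the reader, not part of the
+# test suite): each spatial output size of conv_transpose2d follows
+#     out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1
+# so input (1, 3, 4, 4) with kernel 3, stride 2, padding 2, output_padding 1 and
+# the default dilation 1 gives (4 - 1) * 2 - 4 + 2 + 1 + 1 = 6, i.e. an output of
+# shape (1, 3, 6, 6):
+#
+#     import torch
+#     import torch.nn.functional as F
+#     out = F.conv_transpose2d(torch.randn(1, 3, 4, 4), torch.randn(3, 3, 3, 3),
+#                              torch.randn(3), stride=(2, 2), padding=2,
+#                              output_padding=(1, 1))
+#     assert tuple(out.shape) == (1, 3, 6, 6)
+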
def sample_inputs_hardswish(self, device, dtype, requires_grad):
N = 5
# make sure we are testing the -3 -> 3 range; the default is -10 -> 10, so this may be unnecessary
dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_nn_activation_relu,
supports_out=False),
+    OpInfo('nn.functional.conv_transpose2d',
+           aten_name='conv_transpose2d',
+           aliases=('conv_transpose2d',),
+           dtypesIfCPU=floating_types_and(torch.int64),
+           dtypesIfCUDA=floating_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
+           sample_inputs_func=sample_inputs_conv_transpose2d,
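+           # Presumably needed because transposed convolution can dispatch to
+           # nondeterministic backward kernels (e.g. in cuDNN), so gradcheck is
+           # allowed a small run-to-run tolerance.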
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           skips=(
+               # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
+               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch.
+               SkipInfo('TestJit', 'test_variant_consistency_jit'),
+           ),
+           supports_out=False),
OpInfo('nn.functional.hardswish',
aten_name="hardswish",
supports_autograd=True,