def _pad():
def _impl(inputs, input_types):
data = inputs[0]
- padding = inputs[1]
- pad_width = list(zip(padding, padding))
+ if isinstance(inputs[1], list):
+ pad_list = inputs[1]
+ else:
+ pad_list = list(_infer_shape(inputs[1]))
+
+ # initialize paddings based on input len
+ pad_len = len(_infer_shape(data)) * 2
+ paddings = [0] * pad_len
+
+ if len(pad_list) >= 2:
+ paddings[-1] = pad_list[1]
+ paddings[-2] = pad_list[0]
+ if len(pad_list) >= 4:
+ paddings[-3] = pad_list[3]
+ paddings[-4] = pad_list[2]
+ if len(pad_list) >= 6:
+ paddings[-5] = pad_list[5]
+ paddings[-6] = pad_list[4]
+
+ # group into tuple of 2 ints
+ paddings = [paddings[i:i + 2] for i in range(0, len(paddings), 2)]
+
pad_value = inputs[2]
- return _op.nn.pad(data, pad_width, pad_value)
+
+ return _op.nn.pad(data, paddings, pad_value)
return _impl
verify_model(torch.nn.AdaptiveMaxPool3d((7, 8, 9)).eval(), inp)
+def test_forward_functional_pad():
+ torch.set_grad_enabled(False)
+ pad = (0, 0)
+ class Pad1(Module):
+ def forward(self, *args):
+ return torch.nn.functional.pad(args[0], pad, "constant", 0)
+
+ input_data = torch.rand((3, 3, 4, 2))
+ pad = (1, 1)
+ verify_model(Pad1().float().eval(), input_data=input_data)
+
+ pad = (1, 1, 2, 2)
+ verify_model(Pad1().float().eval(), input_data=input_data)
+
+ pad = (0, 1, 2, 1, 3, 3)
+ verify_model(Pad1().float().eval(), input_data=input_data)
+
+
+def test_forward_zero_pad2d():
+ inp = torch.rand((1, 1, 3, 3))
+ verify_model(torch.nn.ZeroPad2d(2).eval(), inp)
+ verify_model(torch.nn.ZeroPad2d((1, 1, 2, 0)).eval(), inp)
+
+
+def test_forward_constant_pad1d():
+ inp = torch.rand((1, 2, 4))
+ verify_model(torch.nn.ConstantPad2d(2, 3.5).eval(), inp)
+
+ inp = torch.rand((1, 2, 3))
+ verify_model(torch.nn.ConstantPad2d((3, 1), 3.5).eval(), inp)
+
+
+def test_forward_constant_pad2d():
+ inp = torch.rand((1, 2, 2, 2))
+ verify_model(torch.nn.ConstantPad2d(2, 3.5).eval(), inp)
+ verify_model(torch.nn.ConstantPad2d((3, 0, 2, 1), 3.5).eval(), inp)
+
+
+def test_forward_constant_pad3d():
+ inp = torch.rand((1, 3, 2, 2, 2))
+ verify_model(torch.nn.ConstantPad3d(3, 3.5).eval(), inp)
+ verify_model(torch.nn.ConstantPad3d((3, 4, 5, 6, 0, 1), 3.5).eval(), inp)
+
+
def test_forward_reflection_pad2d():
inp = torch.rand((1, 1, 3, 3))
verify_model(torch.nn.ReflectionPad2d(2).eval(), inp)
test_upsample()
test_forward_upsample3d()
test_to()
+ test_forward_functional_pad()
+ test_forward_zero_pad2d()
+ test_forward_constant_pad1d()
+ test_forward_constant_pad2d()
+ test_forward_constant_pad3d()
test_forward_reflection_pad2d()
test_adaptive_pool3d()
test_conv3d()