Add Maxpool to shape analysis / Opinfo (#63530)
authorElias Ellison <eellison@devfair044.h1.fair>
Wed, 15 Sep 2021 20:43:12 +0000 (13:43 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 15 Sep 2021 20:44:33 +0000 (13:44 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63530

how to review: pretty much just check that the inputs generated are a good representation of the op semantics, that should be sufficient for correctness, and then you can also double check the op size semantics by going to https://codebrowser.bddppq.com/pytorch/pytorch/ typing in native::{op_name} and looking at the op implementation as a bonus if you want

Test Plan: Imported from OSS

Reviewed By: driazati

Differential Revision: D30738147

Pulled By: eellison

fbshipit-source-id: cf52339e572ee04e0d6167fd95d8a82d58ea7706

torch/csrc/jit/runtime/symbolic_shape_registry.cpp
torch/testing/_internal/common_methods_invocations.py

index 7691cb6..7841631 100644 (file)
@@ -139,6 +139,73 @@ const std::string shape_compute_functions =
         def broadcast_one_unused_input(self: List[int], other: List[int], unused: Any):
           return broadcast(self, other)
 
+        # note: python already rounds down towards negative infinity on integer division, special arithmetic not needed
+        def div_rtn(x: int, y: int):
+          return x // y
+
+        def pooling_output_shape_pad_lr(inputSize: int, kernelSize: int, pad_l: int, pad_r: int, stride: int, dilation: int, ceil_mode: bool):
+          outputSize = div_rtn(inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1 + (stride - 1 if ceil_mode else 0), stride) + 1
+          if ceil_mode:
+            if (outputSize - 1) * stride >= inputSize + pad_l:
+              outputSize = outputSize - 1
+          return outputSize
+
+        def pooling_output_shape(inputSize: int, kernelSize: int, pad_l: int, stride: int, dilation: int, ceil_mode: bool):
+          assert stride != 0, "stride should not be zeero"
+          return pooling_output_shape_pad_lr(inputSize, kernelSize, pad_l, pad_l, stride, dilation, ceil_mode)
+
+        def pool2d_shape_check(input: List[int], kH: int, kW: int, dH: int, dW: int, padH: int, padW: int,
+                               dilationH: int, dilationW: int, nInputPlane: int, inputHeight: int, inputWidth: int, outputHeight: int, outputWidth: int):
+
+          ndim = len(input)
+          nOutputPlane = nInputPlane
+
+          assert kW > 0 and kH > 0
+          assert dW > 0 and dH > 0
+          assert dilationH > 0 and dilationW > 0
+
+          valid_dims = input[1] != 0 and input[2] != 0
+          assert ndim == 3 and input[0] != 0 and valid_dims or (ndim == 4 and valid_dims and input[3] != 0)
+
+          assert kW // 2 >= padW and kH // 2 >= padH
+          assert outputWidth >= 1 and outputHeight >= 1
+
+        def max_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool):
+          assert len(kernel_size) == 1 or len(kernel_size) == 2, "max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
+          kH = kernel_size[0]
+          kW = kH if len(kernel_size) == 1 else kernel_size[1]
+
+          assert len(stride) == 0 or len(stride) == 1 or len(stride) == 2, "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
+          dH = kH if len(stride) == 0 else stride[0]
+          dW = kW if len(stride) == 0 else dH if len(stride) == 1 else stride[1]
+
+          assert len(padding) == 1 or len(padding) == 2, "max_pool2d: padding must be either be a single int, or a tuple of two ints"
+          padH = padding[0]
+          padW = padH if len(padding) == 1 else padding[1]
+
+          assert len(dilation) == 1 or len(dilation) == 2, "max_pool2d: dilation must be either a single int, or a tuple of two ints"
+          dilationH = dilation[0]
+          dilationW = dilationH if len(dilation) == 1 else dilation[1]
+
+          assert len(input) == 3 or len(input) == 4
+
+          nbatch = input[-4] if len(input) == 4 else 1
+          nInputPlane = input[-3]
+          inputHeight = input[-2]
+          inputWidth = input[-1]
+
+          outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
+          outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)
+
+          pool2d_shape_check(input, kH, kW, dH, dW, padH, padW, dilationH, dilationW, nInputPlane,
+                             inputHeight, inputWidth, outputHeight, outputWidth)
+
+          if len(input) == 3:
+            return [nInputPlane, outputHeight, outputWidth]
+          else:
+            return [nbatch, nInputPlane, outputHeight, outputWidth]
+    )"
+    R"(
         def mm(self: List[int] , mat2: List[int]):
           assert len(self) == 2, "self must be a matrix"
           assert len(mat2) == 2, "mat2 must be a matrix"
@@ -306,7 +373,8 @@ const std::string shape_compute_functions =
             else:
               out.append(self[i])
           return out
-
+    )"
+    R"(
         def linear(input: List[int], weight: List[int], bias: Optional[List[int]]):
           out = matmul(input, t(weight))
           if bias is not None:
@@ -461,7 +529,7 @@ const std::string shape_compute_functions =
           return linear(input, weight, bias)
     )"
 #endif
-    ;
+;
 
 // mapping function schema to shape compute graphs allows multiple functions to
 // share the same shape compute graph, which is memory efficient and also will
@@ -518,6 +586,7 @@ static const OperatorMap<std::string>& get_schema_to_function_graph() {
       {"aten::mv(Tensor self, Tensor vec) -> Tensor", "mv"},
       {"aten::matmul(Tensor self, Tensor other) -> Tensor", "matmul"},
       {"aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor", "linear"},
+      {"aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "max_pool2d"},
       {"aten::t(Tensor(a) self) -> Tensor(a)", "t"},
       {"aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", "transpose"},
       {"aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor", "conv1d"},
@@ -572,6 +641,7 @@ void loadModule(const CompilationUnit& module) {
 
 void loadFunctions() {
   auto src = std::make_shared<Source>(shape_compute_functions);
+  std::stringstream ss;
   std::vector<at::IValue> constantTable;
   auto resolver = std::make_shared<SourceImporterImpl>(
       compilation_unit,
index 26f64b3..18f7431 100644 (file)
@@ -2586,6 +2586,37 @@ def sample_inputs_adaptive_avg_pool2d(op_info, device, dtype, requires_grad, **k
 
     return list(generator())
 
+def sample_inputs_max_pool2d(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    kerneli = [[3, 2], 3]
+    stridei = [[2, 2]]
+    Ni = [1, 4, None]
+    Ci = [32]
+    Hi = [8, 16]
+    Wi = [8, 16]
+    ceil_modei = [True, False]
+    paddingi = [0, 1]
+    dilationi = [1, (1, 2)]
+    products = product(kerneli, stridei, Ni, Ci, Hi, Wi, ceil_modei, paddingi, dilationi)
+
+    def generator():
+        for kernel, stride, N, C, H, W, ceil_mode, padding, dilation in products:
+            max_pool = torch.nn.MaxPool2d(kernel, stride, ceil_mode=ceil_mode, padding=padding, dilation=dilation)
+            kwargs = {
+                "kernel_size": max_pool.kernel_size,
+                "stride": max_pool.stride,
+                "padding": max_pool.padding,
+                "dilation": max_pool.dilation,
+                "ceil_mode": max_pool.ceil_mode,
+                "return_indices": max_pool.return_indices,
+            }
+            sample_input = make_arg((N, C, H, W)) if N is not None else (make_arg((C, H, W)))
+
+            yield SampleInput(sample_input, kwargs=kwargs)
+
+    return list(generator())
+
 def sample_inputs_normalize(self, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, low=-1, high=1, device=device, dtype=dtype, requires_grad=requires_grad)
 
@@ -7688,6 +7719,15 @@ op_db: List[OpInfo] = [
            dtypesIfCPU=floating_types_and(torch.int64),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_avgpool2d),
+    OpInfo('nn.functional.max_pool2d',
+           aten_name='max_pool2d',
+           supports_autograd=True,
+           supports_out=False,
+           assert_jit_shape_analysis=True,
+           dtypesIfCPU=floating_types(),
+           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           supports_scripting=False,  # TODO: fix aliasing test
+           sample_inputs_func=sample_inputs_max_pool2d),
     UnaryUfuncInfo(
         'nn.functional.logsigmoid',
         aten_name="log_sigmoid",