From: Elias Ellison
Date: Wed, 15 Sep 2021 20:43:12 +0000 (-0700)
Subject: Add Maxpool to shape analysis / Opinfo (#63530)
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~181
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=2626cd3ba4452234ac82501d7959a4c5ca6bf063;p=platform%2Fupstream%2Fpytorch.git

Add Maxpool to shape analysis / Opinfo (#63530)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63530

How to review: check that the generated inputs are a good representation of the op semantics; that should be sufficient for correctness. As a bonus, you can also double-check the op size semantics by going to https://codebrowser.bddppq.com/pytorch/pytorch/, typing in native::{op_name}, and looking at the op implementation.

Test Plan: Imported from OSS

Reviewed By: driazati

Differential Revision: D30738147

Pulled By: eellison

fbshipit-source-id: cf52339e572ee04e0d6167fd95d8a82d58ea7706
---
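Reviewer's note (not part of the patch): the sketch below is a standalone Python restatement of the pooling shape arithmetic added in this patch, cross-checked against eager max_pool2d. The helper name and the example sizes are chosen here for illustration only and do not appear in the diff.

import torch

def pooling_output_shape(input_size: int, kernel: int, pad: int,
                         stride: int, dilation: int, ceil_mode: bool) -> int:
    # Python's // already rounds toward negative infinity, so no special
    # handling is needed (this mirrors div_rtn in the patch).
    out = (input_size + 2 * pad - dilation * (kernel - 1) - 1
           + (stride - 1 if ceil_mode else 0)) // stride + 1
    # In ceil mode the last pooling window must start inside the input
    # (or its left padding); otherwise the extra output element is dropped.
    if ceil_mode and (out - 1) * stride >= input_size + pad:
        out -= 1
    return out

x = torch.randn(4, 32, 16, 16)
expected = pooling_output_shape(16, 3, 1, 2, 1, True)  # -> 9
out = torch.nn.functional.max_pool2d(
    x, kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=True)
assert out.shape == (4, 32, expected, expected)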
diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
index 7691cb6..7841631 100644
--- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
+++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
@@ -139,6 +139,73 @@ const std::string shape_compute_functions =
     def broadcast_one_unused_input(self: List[int], other: List[int], unused: Any):
       return broadcast(self, other)
 
+    # note: python already rounds down towards negative infinity on integer division, special arithmetic not needed
+    def div_rtn(x: int, y: int):
+      return x // y
+
+    def pooling_output_shape_pad_lr(inputSize: int, kernelSize: int, pad_l: int, pad_r: int, stride: int, dilation: int, ceil_mode: bool):
+      outputSize = div_rtn(inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1 + (stride - 1 if ceil_mode else 0), stride) + 1
+      if ceil_mode:
+        if (outputSize - 1) * stride >= inputSize + pad_l:
+          outputSize = outputSize - 1
+      return outputSize
+
+    def pooling_output_shape(inputSize: int, kernelSize: int, pad_l: int, stride: int, dilation: int, ceil_mode: bool):
+      assert stride != 0, "stride should not be zero"
+      return pooling_output_shape_pad_lr(inputSize, kernelSize, pad_l, pad_l, stride, dilation, ceil_mode)
+
+    def pool2d_shape_check(input: List[int], kH: int, kW: int, dH: int, dW: int, padH: int, padW: int,
+                           dilationH: int, dilationW: int, nInputPlane: int, inputHeight: int, inputWidth: int, outputHeight: int, outputWidth: int):
+
+      ndim = len(input)
+      nOutputPlane = nInputPlane
+
+      assert kW > 0 and kH > 0
+      assert dW > 0 and dH > 0
+      assert dilationH > 0 and dilationW > 0
+
+      valid_dims = input[1] != 0 and input[2] != 0
+      assert ndim == 3 and input[0] != 0 and valid_dims or (ndim == 4 and valid_dims and input[3] != 0)
+
+      assert kW // 2 >= padW and kH // 2 >= padH
+      assert outputWidth >= 1 and outputHeight >= 1
+
+    def max_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool):
+      assert len(kernel_size) == 1 or len(kernel_size) == 2, "max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
+      kH = kernel_size[0]
+      kW = kH if len(kernel_size) == 1 else kernel_size[1]
+
+      assert len(stride) == 0 or len(stride) == 1 or len(stride) == 2, "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
+      dH = kH if len(stride) == 0 else stride[0]
+      dW = kW if len(stride) == 0 else dH if len(stride) == 1 else stride[1]
+
+      assert len(padding) == 1 or len(padding) == 2, "max_pool2d: padding must either be a single int, or a tuple of two ints"
+      padH = padding[0]
+      padW = padH if len(padding) == 1 else padding[1]
+
+      assert len(dilation) == 1 or len(dilation) == 2, "max_pool2d: dilation must be either a single int, or a tuple of two ints"
+      dilationH = dilation[0]
+      dilationW = dilationH if len(dilation) == 1 else dilation[1]
+
+      assert len(input) == 3 or len(input) == 4
+
+      nbatch = input[-4] if len(input) == 4 else 1
+      nInputPlane = input[-3]
+      inputHeight = input[-2]
+      inputWidth = input[-1]
+
+      outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
+      outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)
+
+      pool2d_shape_check(input, kH, kW, dH, dW, padH, padW, dilationH, dilationW, nInputPlane,
+                         inputHeight, inputWidth, outputHeight, outputWidth)
+
+      if len(input) == 3:
+        return [nInputPlane, outputHeight, outputWidth]
+      else:
+        return [nbatch, nInputPlane, outputHeight, outputWidth]
+  )"
+  R"(
     def mm(self: List[int] , mat2: List[int]):
       assert len(self) == 2, "self must be a matrix"
       assert len(mat2) == 2, "mat2 must be a matrix"
@@ -306,7 +373,8 @@ const std::string shape_compute_functions =
         else:
           out.append(self[i])
       return out
-
+  )"
+  R"(
     def linear(input: List[int], weight: List[int], bias: Optional[List[int]]):
       out = matmul(input, t(weight))
       if bias is not None:
@@ -461,7 +529,7 @@ const std::string shape_compute_functions =
       return linear(input, weight, bias)
 )"
 #endif
-    ;
+;
 
 // mapping function schema to shape compute graphs allows multiple functions to
 // share the same shape compute graph, which is memory efficient and also will
@@ -518,6 +586,7 @@ static const OperatorMap<std::string>& get_schema_to_function_graph() {
       {"aten::mv(Tensor self, Tensor vec) -> Tensor", "mv"},
       {"aten::matmul(Tensor self, Tensor other) -> Tensor", "matmul"},
      {"aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor", "linear"},
+      {"aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "max_pool2d"},
      {"aten::t(Tensor(a) self) -> Tensor(a)", "t"},
      {"aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", "transpose"},
      {"aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor", "conv1d"},
@@ -572,6 +641,7 @@ void loadModule(const CompilationUnit& module) {
 
 void loadFunctions() {
   auto src = std::make_shared<Source>(shape_compute_functions);
+  std::stringstream ss;
   std::vector<at::IValue> constantTable;
   auto resolver = std::make_shared<SourceImporterImpl>(
       compilation_unit,
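Reviewer's note (not part of the patch): the registered schema declares stride=[] as the default, and the shape function above treats an empty stride as "stride = kernel_size". The short snippet below illustrates that this matches eager behavior; the tensor sizes are arbitrary.

import torch

x = torch.randn(1, 32, 16, 16)
a = torch.nn.functional.max_pool2d(x, kernel_size=3)            # stride omitted
b = torch.nn.functional.max_pool2d(x, kernel_size=3, stride=3)  # stride == kernel_size
assert a.shape == b.shape == (1, 32, 5, 5)                      # (16 - 3) // 3 + 1 == 5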
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 26f64b3..18f7431 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2586,6 +2586,37 @@ def sample_inputs_adaptive_avg_pool2d(op_info, device, dtype, requires_grad, **kwargs):
 
     return list(generator())
 
+def sample_inputs_max_pool2d(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    kerneli = [[3, 2], 3]
+    stridei = [[2, 2]]
+    Ni = [1, 4, None]
+    Ci = [32]
+    Hi = [8, 16]
+    Wi = [8, 16]
+    ceil_modei = [True, False]
+    paddingi = [0, 1]
+    dilationi = [1, (1, 2)]
+    products = product(kerneli, stridei, Ni, Ci, Hi, Wi, ceil_modei, paddingi, dilationi)
+
+    def generator():
+        for kernel, stride, N, C, H, W, ceil_mode, padding, dilation in products:
+            max_pool = torch.nn.MaxPool2d(kernel, stride, ceil_mode=ceil_mode, padding=padding, dilation=dilation)
+            kwargs = {
+                "kernel_size": max_pool.kernel_size,
+                "stride": max_pool.stride,
+                "padding": max_pool.padding,
+                "dilation": max_pool.dilation,
+                "ceil_mode": max_pool.ceil_mode,
+                "return_indices": max_pool.return_indices,
+            }
+            sample_input = make_arg((N, C, H, W)) if N is not None else make_arg((C, H, W))
+
+            yield SampleInput(sample_input, kwargs=kwargs)
+
+    return list(generator())
+
 def sample_inputs_normalize(self, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, low=-1, high=1, device=device, dtype=dtype, requires_grad=requires_grad)
 
@@ -7688,6 +7719,15 @@ op_db: List[OpInfo] = [
            dtypesIfCPU=floating_types_and(torch.int64),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_avgpool2d),
+    OpInfo('nn.functional.max_pool2d',
+           aten_name='max_pool2d',
+           supports_autograd=True,
+           supports_out=False,
+           assert_jit_shape_analysis=True,
+           dtypesIfCPU=floating_types(),
+           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           supports_scripting=False,  # TODO: fix aliasing test
+           sample_inputs_func=sample_inputs_max_pool2d),
     UnaryUfuncInfo(
         'nn.functional.logsigmoid',
         aten_name="log_sigmoid",
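Reviewer's note (not part of the patch): in sample_inputs_max_pool2d the nn.MaxPool2d module is instantiated only to collect a complete, consistent set of keyword arguments (including defaults such as return_indices=False) to hand to the functional op. The loop below replays that round trip for a few configurations taken from the generator; make_tensor and SampleInput from the test suite are not needed for the illustration, and the input sizes are arbitrary.

import torch
from itertools import product

for kernel, stride, ceil_mode in product([[3, 2], 3], [[2, 2]], [True, False]):
    max_pool = torch.nn.MaxPool2d(kernel, stride, ceil_mode=ceil_mode)
    kwargs = {
        "kernel_size": max_pool.kernel_size,
        "stride": max_pool.stride,
        "padding": max_pool.padding,
        "dilation": max_pool.dilation,
        "ceil_mode": max_pool.ceil_mode,
        "return_indices": max_pool.return_indices,
    }
    x = torch.randn(4, 32, 16, 16)
    out = torch.nn.functional.max_pool2d(x, **kwargs)
    print(kwargs["kernel_size"], kwargs["ceil_mode"], tuple(out.shape))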