From 5f997a7d2fcd81584d1c9f6e173e30c867892ee8 Mon Sep 17 00:00:00 2001 From: Pavithran Ramachandran Date: Fri, 20 Aug 2021 09:34:53 -0700 Subject: [PATCH] [PyTorch][Edge] Improve InflatableArgs for Bundled Inputs (#62368) MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62368 # Context The bundled inputs accepts an expression in the form of string InflatableArg.fmt that can be applied on the inputs to inflate. The InflatableArg.fmt provides flexibility to have custom transformation to inflate. When the input arguments to a function are not Tensor type, TorchScript casts the inputs from type T to Optional[T] expects the function to handle Nullable (None) clause as well. This becomes tricky to handle in one line code or lambda functions. We propose an alternative way which allows InflatableArg to include the text of a TorchScript function that would be defined on the module as a helper, then use that in its inflation expression. This can be provided by InflatableArg.fmt_fn. Please refer to pytorch/test/test_bundled_inputs.py for example on how to use the same. Also refer JacobSzwejbka comment on the same [here](https://github.com/pytorch/pytorch/pull/62368#issuecomment-892012812) # Mitigation Allow InflatedArg to include the text of a TorchScript function that would be defined on the module as a helper, then use that in its inflation expression. ghstack-source-id: 135158680 Test Plan: To run `test_dict_args` ``` (base) [pavithran@devvm1803.vll0 /data/users/pavithran/fbsource/fbcode] buck test //caffe2/test:test_bundled_inputs -- test_dict_args Action graph will be rebuilt because files have been added or removed. Building: finished in 5.4 sec (100%) 12180/12180 jobs, 0/12180 updated Total time: 5.8 sec More details at https://www.internalfb.com/intern/buck/build/fafcf277-1095-4cba-978d-6022f0d391ad Tpx test run coordinator for Facebook. See https://fburl.com/tpx for details. Running with tpx session id: 5ef9de71-c1b1-406b-a6c0-3321c2368b8d Trace available for this run at /tmp/tpx-20210727-163946.454212/trace.log Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/7036874465805934 ✓ ListingSuccess: caffe2/test:test_bundled_inputs - main (11.365) ✓ Pass: caffe2/test:test_bundled_inputs - test_dict_args (test_bundled_inputs.TestBundledInputs) (12.307) Summary Pass: 1 ListingSuccess: 1 If you need help understanding your runs, please follow the wiki: https://fburl.com/posting_in_tpx_users Finished test run: https://www.internalfb.com/intern/testinfra/testrun/7036874465805934 ``` To check the py code of TS module: P433043973 Reviewed By: dreiss Differential Revision: D29950421 fbshipit-source-id: c819ec5c94429b7fbf6c4beb0259457f169b08ec --- test/test_bundled_inputs.py | 115 +++++++++++++++++++++++++++++++++++++++++- torch/utils/bundled_inputs.py | 75 +++++++++++++++++++++++---- 2 files changed, 180 insertions(+), 10 deletions(-) diff --git a/test/test_bundled_inputs.py b/test/test_bundled_inputs.py index a0fb535..62263e1 100644 --- a/test/test_bundled_inputs.py +++ b/test/test_bundled_inputs.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import io import textwrap -from typing import List +from typing import List, Optional, Dict import torch import torch.utils.bundled_inputs @@ -324,5 +324,118 @@ class TestBundledInputs(TestCase): ) self.assertEqual(bundled_model2.get_all_bundled_inputs(), [(torch.ones(2),)]) + + def test_dict_args(self): + class MyModel(torch.nn.Module): + def forward( + self, + arg1: Optional[Dict[str, torch.Tensor]], + arg2: Optional[List[torch.Tensor]], + arg3: torch.Tensor, + ): + if arg1 is None: + return arg3 + elif arg2 is None: + return arg1["a"] + arg1["b"] + else: + return arg1["a"] + arg1["b"] + arg2[0] + + small_sample = dict( + a=torch.zeros([10, 20]), + b=torch.zeros([1, 1]), + c=torch.zeros([10, 20]), + ) + small_list = [torch.zeros([10, 20])] + + big_sample = dict( + a=torch.zeros([1 << 5, 1 << 8, 1 << 10]), + b=torch.zeros([1 << 5, 1 << 8, 1 << 10]), + c=torch.zeros([1 << 5, 1 << 8, 1 << 10]), + ) + big_list = [torch.zeros([1 << 5, 1 << 8, 1 << 10])] + + def condensed(t): + ret = torch.empty_like(t).flatten()[0].clone().expand(t.shape) + assert ret.storage().size() == 1 + # ret.storage()[0] = 0 + return ret + + def bundle_optional_dict_of_randn(template): + return torch.utils.bundled_inputs.InflatableArg( + value=( + None + if template is None + else {k: condensed(v) for (k, v) in template.items()} + ), + fmt="{}", + fmt_fn=""" + def {}(self, value: Optional[Dict[str, Tensor]]): + if value is None: + return None + output = {{}} + for k, v in value.items(): + output[k] = torch.randn_like(v) + return output + """, + ) + + def bundle_optional_list_of_randn(template): + return torch.utils.bundled_inputs.InflatableArg( + value=(None if template is None else [condensed(v) for v in template]), + fmt="{}", + fmt_fn=""" + def {}(self, value: Optional[List[Tensor]]): + if value is None: + return None + output = [] + for v in value: + output.append(torch.randn_like(v)) + return output + """, + ) + + out : List[str] = [] + sm = torch.jit.script(MyModel()) + original_size = model_size(sm) + small_inputs = ( + bundle_optional_dict_of_randn(small_sample), + bundle_optional_list_of_randn(small_list), + torch.zeros([3, 4]), + ) + big_inputs = ( + bundle_optional_dict_of_randn(big_sample), + bundle_optional_list_of_randn(big_list), + torch.zeros([1 << 5, 1 << 8, 1 << 10]), + ) + + torch.utils.bundled_inputs.augment_model_with_bundled_inputs( + sm, + [ + big_inputs, + small_inputs, + ], + _receive_inflate_expr=out, + ) + augmented_size = model_size(sm) + # assert the size has not increased more than 8KB + self.assertLess(augmented_size, original_size + (1 << 13)) + + loaded = save_and_load(sm) + inflated = loaded.get_all_bundled_inputs() + self.assertEqual(len(inflated[0]), len(small_inputs)) + + methods, _ = torch.utils.bundled_inputs._get_bundled_inputs_attributes_and_methods( + loaded + ) + + # One Function (forward) + # two bundled inputs (big_inputs and small_inputs) + # two args which have InflatableArg with fmt_fn + # 1 * 2 * 2 = 4 + self.assertEqual( + sum([method.startswith("_inflate_helper") for method in methods]), 4 + ) + + if __name__ == '__main__': run_tests() diff --git a/torch/utils/bundled_inputs.py b/torch/utils/bundled_inputs.py index bce658b..8a6d466 100644 --- a/torch/utils/bundled_inputs.py +++ b/torch/utils/bundled_inputs.py @@ -21,13 +21,18 @@ class InflatableArg(NamedTuple): the appropriate input. It can use 'value' as an input to the format str. It must result in a value of the same type as 'value'. + 'fmt_fn' is a formatable function code string that is executed to inflate the compressed + data into the appropriate input. It must result in a value of the same type as 'value'. + The function name should be the formatable part of the string. + Note: Only top level InflatableArgs can be inflated. i.e. you cannot place an inflatable arg inside of some other structure. You should instead create an inflatable arg such that the fmt code string returns the full structure of your input. """ value: Any - fmt: str + fmt: str = "{}" + fmt_fn: str = "" def bundle_inputs( @@ -279,13 +284,21 @@ def augment_many_model_functions_with_bundled_inputs( deflated_args = [] parts.append("(") for arg_idx, arg in enumerate(args): - deflated, inflater = _inflate_expr(arg, f"deflated[{inp_idx}][{arg_idx}]") + inflate_helper_fn_name = _get_inflate_helper_fn_name(arg_idx, inp_idx, function_name) + deflated, inflater, helper_definition = _inflate_expr( + arg, + f"deflated[{inp_idx}][{arg_idx}]", + inflate_helper_fn_name, + ) deflated_args.append(deflated) parts.append(f" {inflater},") + if helper_definition: + model.define(textwrap.dedent(helper_definition)) deflated_inputs.append(tuple(deflated_args)) parts.append("),") parts.append("") expr = "\n".join(parts) + # Back-channel return this expr for debugging. if _receive_inflate_expr is not None: _receive_inflate_expr.append(expr) @@ -332,7 +345,6 @@ def augment_many_model_functions_with_bundled_inputs( return len(self.get_all_bundled_inputs_for_forward()) """)) - # Define some high level helper methods that act on all bundled inputs model.define(textwrap.dedent(""" def get_bundled_inputs_functions_and_info(self): @@ -341,27 +353,44 @@ def augment_many_model_functions_with_bundled_inputs( return all_inputs """.format(template=get_bundled_inputs_functions_and_info_template))) -def _inflate_expr(arg: T, ref: str) -> Tuple[Union[T, torch.Tensor], str]: +def _inflate_expr( + arg: T, ref: str, inflate_helper_fn_name: str +) -> Tuple[Union[T, torch.Tensor], str, Optional[str]]: # Allow custom inflation expressions any object. # For example, calling custom image-decoding ops. # Or just use "{}" as the format string to ignore size limits. if isinstance(arg, InflatableArg): - return arg.value, arg.fmt.format(ref) + if arg.fmt_fn: + if arg.fmt not in ["{}", ""]: + raise Exception( + f"Bundled input argument at position '{ref}' has " + f"both arg.fmt_fn => \n{arg.fmt_fn} " + f"\n and arg.fmt => {arg.fmt}. " + "Please choose `arg.fmt` if the deflater is straightforward or " + "`arg.fmt_fn` if you need a function." + ) + + helper_definition = arg.fmt_fn.format(inflate_helper_fn_name) + expr = f"self.{inflate_helper_fn_name}({ref})" + + return arg.value, expr, helper_definition + else: + return arg.value, arg.fmt.format(ref), None if isinstance(arg, torch.Tensor): # Small-storage tensors can just be saved directly. if arg.storage().size() <= MAX_RAW_TENSOR_SIZE: - return arg, ref + return arg, ref, None # Small contiguous tensors can be cloned to have small storage. # TODO: Should we do this even for non-contiguous tensors? if arg.is_contiguous() and arg.numel() <= MAX_RAW_TENSOR_SIZE: - return arg.clone(), ref + return arg.clone(), ref, None # Example inputs commonly come from torch.zeros, torch.ones, or torch.full. # These can be represented compactly. for fmt in [torch.contiguous_format, torch.channels_last]: if arg.is_contiguous(memory_format=fmt) and (arg == arg.flatten()[0]).all().item(): return (arg.flatten()[0].clone().expand(*arg.size()), - f"{ref}.contiguous(memory_format={fmt})") + f"{ref}.contiguous(memory_format={fmt})", None) # Prevent big tensors from being bundled by default. # TODO: Provide more useful diagnostics. raise Exception( @@ -370,7 +399,7 @@ def _inflate_expr(arg: T, ref: str) -> Tuple[Union[T, torch.Tensor], str]: f"You probably don't want to bundle this as an input. " ) else: - return arg, ref + return arg, ref, None def _get_bundled_inputs_attributes_and_methods(script_module: torch.jit.ScriptModule) -> Tuple[List[str], List[str]]: methods: List[str] = [] @@ -389,9 +418,37 @@ def _get_bundled_inputs_attributes_and_methods(script_module: torch.jit.ScriptMo methods.append("get_all_bundled_inputs_for_" + function_name) methods.append("_generate_bundled_inputs_for_" + function_name) attributes.append("_bundled_inputs_deflated_" + function_name) + + bundled_inputs_fn = getattr( + script_module, + f"get_all_bundled_inputs_for_{function_name}" + ) + num_bundled_inputs: int = len(bundled_inputs_fn()) + + # Check inflate helper functions for each function, argument and bundled input + func = getattr(script_module, function_name, None) + for arg_idx in range(len(func.schema.arguments) - 1): + for input_idx in range(num_bundled_inputs): + helper_fn_name = _get_inflate_helper_fn_name( + arg_idx=arg_idx, + input_idx=input_idx, + function_name=function_name + ) + # if the arg has an InflatableArg with fmt_fn, add the helper function name + if hasattr(script_module, helper_fn_name): + methods.append(helper_fn_name) + return (methods, attributes) +def _get_inflate_helper_fn_name( + arg_idx: int, + input_idx: int, + function_name: str, +) -> str: + return f"_inflate_helper_for_{function_name}_input_{input_idx}_arg_{arg_idx}" + + def bundle_randn(*size, dtype=None): """Generate a tensor that will be inflated with torch.randn.""" -- 2.7.4