Merge script and _script_pdt API (#62420)
authornikithamalgi <nikithamalgi@devvm146.prn0.facebook.com>
Fri, 27 Aug 2021 01:54:51 +0000 (18:54 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 27 Aug 2021 01:58:19 +0000 (18:58 -0700)
Summary:
Merge `torch.jit.script` and `torch.jit._script_pdt` API. This PR merges profile directed typing with script api

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62420

Reviewed By: iramazanli

Differential Revision: D30579015

Pulled By: nikithamalgifb

fbshipit-source-id: 99ba6839d235d61b2dd0144b466b2063a53ccece

test/jit/test_pdt.py
torch/jit/__init__.py
torch/jit/_script.py

index 57cd74f..468eb27 100644 (file)
@@ -40,7 +40,7 @@ class TestPDT(JitTestCase):
         make_global(TestPDTModel)
         pdt_model = TestPDTModel()
         inp: List[Tuple[Any, ...]] = [(20, ), (2.7, ), (False, ), ]
-        scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: inp})
+        scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model: inp})
         self.assertEqual(scripted_pdt_model(50), pdt_model(50))
         self.assertEqual(scripted_pdt_model(1.8), pdt_model(1.8))
         self.assertTrue(scripted_pdt_model(True), pdt_model(True))
@@ -67,7 +67,7 @@ class TestPDT(JitTestCase):
         inner_pdt_model = NestedPDTInner()
         wrapped_pdt_model = NestedModulePDTWrapper(inner_pdt_model)
         inp: List[Tuple[Any, ...]] = [(20, ), (False, )]
-        scripted_pdt_model = torch.jit._script_pdt(wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp})
+        scripted_pdt_model = torch.jit.script(wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp})
         self.assertEqual(scripted_pdt_model(30), wrapped_pdt_model(30))
         self.assertEqual(scripted_pdt_model(1.9), wrapped_pdt_model(1.9))
         self.assertTrue(scripted_pdt_model(True), wrapped_pdt_model(True))
@@ -95,8 +95,8 @@ class TestPDT(JitTestCase):
         outer_pdt_model = NestedModulePDTOuter(inner_pdt_model)
         inner_input: List[Tuple[Any, ...]] = [(10, 10), (1.9, 20), ]
         outer_input: List[Tuple[Any, ...]] = [(20, ), (False, )]
-        scripted_pdt_model = torch.jit._script_pdt(outer_pdt_model, example_inputs={inner_pdt_model: inner_input,
-                                                   outer_pdt_model: outer_input, })
+        scripted_pdt_model = torch.jit.script(outer_pdt_model, example_inputs={inner_pdt_model: inner_input,
+                                              outer_pdt_model: outer_input, })
         self.assertEqual(scripted_pdt_model(30), outer_pdt_model(30))
         self.assertEqual(scripted_pdt_model(1.9), outer_pdt_model(1.9))
         self.assertTrue(scripted_pdt_model(True), outer_pdt_model(True))
@@ -119,7 +119,7 @@ class TestPDT(JitTestCase):
         make_global(NestedFunctionInForward)
         pdt_model = NestedFunctionInForward()
         inp: List[Tuple[Any, ...]] = [(-1, ), (False, )]
-        scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: inp})
+        scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model: inp})
         self.assertEqual(scripted_pdt_model(30), pdt_model(30))
         self.assertEqual(scripted_pdt_model(True), pdt_model(True))
 
@@ -142,7 +142,7 @@ class TestPDT(JitTestCase):
         make_global(TestModelWithExport)
         pdt_model = TestModelWithExport()
         inp: List[Tuple[Any, ...]] = [(20, 10, ), (2.7, 8.9, ), ]
-        scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model.fn: inp})
+        scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model.fn: inp})
         self.assertEqual(scripted_pdt_model.fn(10, 90), pdt_model.fn(10, 90))
         self.assertEqual(scripted_pdt_model.fn(1.8, 2.2), pdt_model.fn(1.8, 2.2))
         self.assertTrue(scripted_pdt_model.fn(torch.ones(1), 2), pdt_model.fn(torch.ones(1), 2))
@@ -155,7 +155,7 @@ class TestPDT(JitTestCase):
         make_global(PDTModel)
         pdt_model = PDTModel()
         inp: List[Tuple[Any, ...]] = [([10, 20, ], ), ]
-        scripted_pdt_model = torch.jit._script_pdt(PDTModel, example_inputs={pdt_model.test_sum: inp})
+        scripted_pdt_model = torch.jit.script(PDTModel, example_inputs={pdt_model.test_sum: inp})
         script_model = scripted_pdt_model()
         self.assertEqual(script_model.test_sum([10, 20, 30, ], ), pdt_model.test_sum([10, 20, 30, ], ))
 
@@ -174,8 +174,8 @@ class TestPDT(JitTestCase):
         pdt_model = PDTModelWithManyMethods()
         list_inp: List[Tuple[Any, ...]] = [([1.2, 2.3, ], ), ]
         str_inp: List[Tuple[Any, ...]] = [("abc", "b", ), ]
-        scripted_pdt_model = torch.jit._script_pdt(PDTModelWithManyMethods, example_inputs={pdt_model.test_list_to_dict: list_inp,
-                                                   pdt_model.test_substring: str_inp})
+        scripted_pdt_model = torch.jit.script(PDTModelWithManyMethods, example_inputs={pdt_model.test_list_to_dict: list_inp,
+                                              pdt_model.test_substring: str_inp})
         script_model = scripted_pdt_model()
         self.assertEqual(script_model.test_list_to_dict([1.1, 2.2, 3.3, ], ), pdt_model.test_list_to_dict([1.1, 2.2, 3.3, ], ))
         self.assertEqual(script_model.test_substring("helloworld", "world", ), pdt_model.test_substring("helloworld", "world", ))
@@ -195,8 +195,8 @@ class TestPDT(JitTestCase):
         pdt_model_two = PDTModelTwo()
         dict_inp: List[Tuple[Any, ...]] = [({1.2: True, 2.3: False, }, 1.2), ]
         list_inp: List[Tuple[Any, ...]] = [(["abc", "b", ], "c"), ]
-        scripted_pdt_model_one = torch.jit._script_pdt(PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp})
-        scripted_pdt_model_two = torch.jit._script_pdt(PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp})
+        scripted_pdt_model_one = torch.jit.script(PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp})
+        scripted_pdt_model_two = torch.jit.script(PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp})
 
         script_model_one, script_model_two = scripted_pdt_model_one(), scripted_pdt_model_two()
         self.assertEqual(script_model_one.test_find({1.1: True, 2.2: True, 3.3: False, }, 4.4),
@@ -209,28 +209,28 @@ class TestPDT(JitTestCase):
             return a + b
 
         make_global(test_sum)
-        scripted_fn_add = torch.jit._script_pdt(test_sum, example_inputs=[(3, 4)])
+        scripted_fn_add = torch.jit.script(test_sum, example_inputs=[(3, 4)])
         self.assertEqual(scripted_fn_add(10, 2), test_sum(10, 2))
 
         def test_sub(a, b):
             return a - b
 
         make_global(test_sub)
-        scripted_fn_sub = torch.jit._script_pdt(test_sub, example_inputs=[(3.9, 4.10)])
+        scripted_fn_sub = torch.jit.script(test_sub, example_inputs=[(3.9, 4.10)])
         self.assertEqual(scripted_fn_sub(6.5, 2.9), test_sub(6.5, 2.9))
 
         def test_mul(a, b):
             return a * b
 
         make_global(test_mul)
-        scripted_fn_mul = torch.jit._script_pdt(test_mul, example_inputs=[(-10, 9)])
+        scripted_fn_mul = torch.jit.script(test_mul, example_inputs=[(-10, 9)])
         self.assertEqual(scripted_fn_mul(-1, 3), test_mul(-1, 3))
 
         def test_args_complex(real, img):
             return torch.complex(real, img)
 
         make_global(test_args_complex)
-        scripted_fn_complex = torch.jit._script_pdt(test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))])
+        scripted_fn_complex = torch.jit.script(test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))])
         arg1, arg2 = torch.rand(3, 4), torch.rand(3, 4)
         self.assertEqual(scripted_fn_complex(arg1, arg2), test_args_complex(arg1, arg2))
 
@@ -241,7 +241,7 @@ class TestPDT(JitTestCase):
                 return 0
 
         make_global(test_bool)
-        scripted_fn_bool = torch.jit._script_pdt(test_bool, example_inputs=[(True,)])
+        scripted_fn_bool = torch.jit.script(test_bool, example_inputs=[(True,)])
         self.assertEqual(scripted_fn_bool(True), test_bool(True))
 
         def test_str(a):
@@ -251,7 +251,7 @@ class TestPDT(JitTestCase):
                 return True
 
         make_global(test_str)
-        scripted_fn_str = torch.jit._script_pdt(test_str, example_inputs=[("",)])
+        scripted_fn_str = torch.jit.script(test_str, example_inputs=[("",)])
         self.assertEqual(scripted_fn_str("abc"), test_str("abc"))
 
     def test_pdt_list_and_tuple(self):
@@ -260,24 +260,24 @@ class TestPDT(JitTestCase):
 
         make_global(test_list_and_tuple)
 
-        scripted_fn_float_list_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[([4.9, 8.9],)])
+        scripted_fn_float_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([4.9, 8.9],)])
         self.assertEqual(scripted_fn_float_list_input([11.9, 7.6]), test_list_and_tuple([11.9, 7.6]))
 
-        scripted_fn_bool_list_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[([True, False, True],)])
+        scripted_fn_bool_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([True, False, True],)])
         self.assertEqual(scripted_fn_bool_list_input([True, True, True]), test_list_and_tuple([True, True, True]))
 
-        scripted_fn_int_list_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[([3, 4, 5], )])
+        scripted_fn_int_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([3, 4, 5], )])
         self.assertEqual(scripted_fn_int_list_input([1, 2, 3]), test_list_and_tuple([1, 2, 3]))
 
-        scripted_fn_float_tuple_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[((4.9, 8.9),)])
+        scripted_fn_float_tuple_input = torch.jit.script(test_list_and_tuple, example_inputs=[((4.9, 8.9),)])
         self.assertEqual(scripted_fn_float_tuple_input((11.9, 7.6)), test_list_and_tuple((11.9, 7.6)))
 
-        scripted_fn_bool_tuple_input = torch.jit._script_pdt(test_list_and_tuple,
-                                                             example_inputs=[((True, False, True),)])
+        scripted_fn_bool_tuple_input = torch.jit.script(test_list_and_tuple,
+                                                        example_inputs=[((True, False, True),)])
         self.assertEqual(scripted_fn_bool_tuple_input((True, True, True)),
                          test_list_and_tuple((True, True, True)))
 
-        scripted_fn_int_tuple_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[((3, 4, 5), )])
+        scripted_fn_int_tuple_input = torch.jit.script(test_list_and_tuple, example_inputs=[((3, 4, 5), )])
         self.assertEqual(scripted_fn_int_tuple_input((1, 2, 3)), test_list_and_tuple((1, 2, 3)))
 
     def test_nested_list_and_tuple(self):
@@ -295,22 +295,22 @@ class TestPDT(JitTestCase):
         make_global(test_nested_list, test_nested_tuple)
 
         list_inp = [[1, 2, 3, ], [5, 6, 7, ]]
-        scripted_fn = torch.jit._script_pdt(test_nested_list, example_inputs=[(list_inp, ), ])
+        scripted_fn = torch.jit.script(test_nested_list, example_inputs=[(list_inp, ), ])
         inp = [[0, 4, 7, ], [8, 11, ], [6, -1, -20, ]]
         self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, ))
 
         list_inp = ([1, 2, 3, ], [5, 6, 7, ])
-        scripted_fn = torch.jit._script_pdt(test_nested_list, example_inputs=[(list_inp, ), ])
+        scripted_fn = torch.jit.script(test_nested_list, example_inputs=[(list_inp, ), ])
         inp = ([0, 4, 7, ], [8, 11, ], [6, -1, -20, ])
         self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, ))
 
         tup_inp = [(1.0, 2.6, 3.7, ), (5.7, 6.1, 1.7, )]
-        scripted_fn = torch.jit._script_pdt(test_nested_tuple, example_inputs=[(tup_inp, ), ])
+        scripted_fn = torch.jit.script(test_nested_tuple, example_inputs=[(tup_inp, ), ])
         inp = [(1.0, 4.1, 7.4, ), (4.8, 1.1, -1.2, ), (6.3, -1.3, -2.0, )]
         self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, ))
 
         tup_inp = ((True, False, True, ), (False, False, False, ))
-        scripted_fn = torch.jit._script_pdt(test_nested_tuple, example_inputs=[(tup_inp, ), ])
+        scripted_fn = torch.jit.script(test_nested_tuple, example_inputs=[(tup_inp, ), ])
         inp = ((True, True, True, ), (False, False, True, ))
         self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, ))
 
@@ -324,11 +324,11 @@ class TestPDT(JitTestCase):
         make_global(test_dict, test_dict_int_list)
 
         str_bool_inp = {'foo' : True, 'bar': False}
-        scripted_fn = torch.jit._script_pdt(test_dict, example_inputs=[(str_bool_inp,)])
+        scripted_fn = torch.jit.script(test_dict, example_inputs=[(str_bool_inp,)])
         self.assertEqual(scripted_fn({'foo' : False, 'bar': True}, ), test_dict({'foo' : False, 'bar': True}, ))
 
         str_list_inp = {0 : [True, False], 1: [False, True]}
-        scripted_fn = torch.jit._script_pdt(test_dict_int_list, example_inputs=[(str_list_inp,)])
+        scripted_fn = torch.jit.script(test_dict_int_list, example_inputs=[(str_list_inp,)])
         self.assertEqual(scripted_fn({0 : [False, False], 1: [True, True]}, ),
                          test_dict_int_list({0 : [False, False], 1: [True, True]}, ))
 
@@ -349,14 +349,14 @@ class TestPDT(JitTestCase):
 
         make_global(test_multiple_types, test_multiple_type_refinement)
 
-        scripted_fn = torch.jit._script_pdt(test_multiple_types, example_inputs=[(1,), ("abc", ), (8.9,), ([3, 4, 5], )])
+        scripted_fn = torch.jit.script(test_multiple_types, example_inputs=[(1,), ("abc", ), (8.9,), ([3, 4, 5], )])
         self.assertEqual(scripted_fn(10), test_multiple_types(10))
         self.assertEqual(scripted_fn("def"), test_multiple_types("def"))
         self.assertEqual(scripted_fn(7.89999), test_multiple_types(7.89999))
         self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_types([10, 11, 14]))
 
-        scripted_fn = torch.jit._script_pdt(test_multiple_type_refinement, example_inputs=[(1,), ("abc", ), (8.9,),
-                                            ([3, 4, 5],), (True, ), ({"a": True}, ), ])
+        scripted_fn = torch.jit.script(test_multiple_type_refinement, example_inputs=[(1,), ("abc", ), (8.9,),
+                                       ([3, 4, 5],), (True, ), ({"a": True}, ), ])
         self.assertEqual(scripted_fn(10), test_multiple_type_refinement(10))
         self.assertEqual(scripted_fn("def"), test_multiple_type_refinement("def"))
         self.assertEqual(scripted_fn(7.89999), test_multiple_type_refinement(7.89999))
@@ -381,7 +381,7 @@ class TestPDT(JitTestCase):
         make_global(UserDefinedClass, test_model)
 
         user_class = UserDefinedClass()
-        scripted_fn = torch.jit._script_pdt(test_model, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
+        scripted_fn = torch.jit.script(test_model, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
         self.assertEqual(scripted_fn(100, user_class, ), test_model(100, user_class))
         self.assertEqual(scripted_fn(1.9, user_class, ), test_model(1.9, user_class))
 
@@ -403,7 +403,7 @@ class TestPDT(JitTestCase):
         make_global(ClassWithArgs, test_model_with_args)
 
         user_class = ClassWithArgs(False)
-        scripted_fn = torch.jit._script_pdt(test_model_with_args, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
+        scripted_fn = torch.jit.script(test_model_with_args, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
         self.assertEqual(scripted_fn(100, ClassWithArgs(True), ), test_model_with_args(100, ClassWithArgs(True)))
 
     def test_nn_parameter_as_arg(self):
@@ -420,7 +420,7 @@ class TestPDT(JitTestCase):
 
         make_global(TestNNParameter)
         pdt_model = TestNNParameter()
-        scripted_fn = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: [(10, ), ], })
+        scripted_fn = torch.jit.script(pdt_model, example_inputs={pdt_model: [(10, ), ], })
         self.assertEqual(scripted_fn(20), pdt_model(20))
 
     def test_fx_tracing_with_typing(self):
@@ -434,7 +434,7 @@ class TestPDT(JitTestCase):
 
         make_global(FXModel, FXModelOutput)
         pdt_model = FXModel()
-        scripted_fn = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })
+        scripted_fn = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })
         self.assertEqual(scripted_fn([20]), pdt_model([20]))
 
     def test_nonetype_as_optional_of_type(self):
@@ -446,11 +446,11 @@ class TestPDT(JitTestCase):
 
         make_global(test_none)
 
-        scripted_fn = torch.jit._script_pdt(test_none, example_inputs=[(None, ), (10.6, )])
+        scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (10.6, )])
         self.assertEqual(scripted_fn(30.9, ), test_none(30.9, ))
 
-        scripted_fn = torch.jit._script_pdt(test_none, example_inputs=[(None, ), (10, )])
+        scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (10, )])
         self.assertEqual(scripted_fn(2, ), test_none(2, ))
 
-        scripted_fn = torch.jit._script_pdt(test_none, example_inputs=[(None, ), (torch.Tensor(1), )])
+        scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (torch.Tensor(1), )])
         self.assertEqual(scripted_fn(torch.ones(1), ), test_none(torch.ones(1), ))
index c9fd886..f7fa58b 100644 (file)
@@ -20,7 +20,6 @@ from torch._jit_internal import (
 )
 from torch.jit._script import (
     script,
-    _script_pdt,
     Attribute,
     ScriptModule,
     script_method,
index 3d173ae..09801ba 100644 (file)
@@ -984,57 +984,6 @@ def call_prepare_scriptable_func(obj):
     memo: Dict[int, torch.nn.Module] = {}
     return call_prepare_scriptable_func_impl(obj, memo)
 
-
-def _script_pdt(obj, optimize=None, _frames_up=0, _rcb=None,
-                example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None):
-    # This is a private API, intended for internal use only. Usage of this API is only for experimental
-    # purposes only and is highly discouraged.
-    global type_trace_db
-    if not _enabled:
-        return obj
-
-    if optimize is not None:
-        warnings.warn(
-            "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead"
-        )
-
-    # No-op for modules and functions that are already scripted
-    if isinstance(obj, ScriptModule):
-        return obj
-    if isinstance(obj, ScriptFunction):
-        return obj
-
-    if example_inputs:
-        # If MonkeyType is installed, enable profile directed type annotation
-        # Check if example_inputs are defined and generate call traces
-        # for the method by running eager mode version of the method with
-        # the provide example inputs. This logs all the traces in type_trace_db
-        type_trace_db = JitTypeTraceStore()
-        if monkeytype_trace:
-            monkeytype_config = JitTypeTraceConfig(type_trace_db)
-            with monkeytype_trace(monkeytype_config):
-                if isinstance(example_inputs, Dict):
-                    # If the obj is an nn.Module or a class, then each method is
-                    # executed with the arguments provided in the example inputs.
-                    # example inputs here will be of type Dict(class.method, (arguments))
-                    # This is used to infer type annotations for those methods
-                    # which are not called directly under the hood of monkeytype.
-                    for module, example_input in example_inputs.items():
-                        for example in example_input:
-                            module(*example)
-                elif isinstance(example_inputs, List):
-                    for examples in example_inputs:
-                        obj(*examples)
-                else:
-                    warnings.warn("Error: Unable to infer types. Please format the inputs to type `List[Tuple]`"
-                                  " or `Dict[Callable, List[Tuple]]` to be run with MonkeyType.")
-        else:
-            warnings.warn("Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType "
-                          "to enable Profile-Directed Typing in TorchScript. Refer to "
-                          "https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. ")
-    return script(obj, optimize, _frames_up, _rcb)
-
-
 def create_script_dict(obj):
     """
     Create a ``torch._C.ScriptDict`` instance with the data from ``obj``.
@@ -1065,7 +1014,8 @@ def create_script_list(obj, type_hint=None):
     return torch._C.ScriptList(obj)  # type: ignore[attr-defined]
 
 
-def script(obj, optimize=None, _frames_up=0, _rcb=None):
+def script(obj, optimize=None, _frames_up=0, _rcb=None,
+           example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None):
     r"""
     Scripting a function or ``nn.Module`` will inspect the source code, compile
     it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or
@@ -1083,6 +1033,8 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None):
     Args:
         obj (callable, class, or ``nn.Module``):  The ``nn.Module``, function, class type,
                                                   dictionary, or list to compile.
+        example_inputs (Union[List[Tuple], Dict[Callable, List[Tuple]], None]): Provide example inputs
+            to annotate the arguments for a function or ``nn.Module``.
 
     Returns:
         If ``obj`` is ``nn.Module``, ``script`` returns
@@ -1124,6 +1076,34 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None):
 
             ...
 
+    ****Scripting a function using example_inputs**
+        Example inputs can be used to annotate a function arguments.
+
+        Example (annotating a function before scripting):
+
+        .. testcode::
+
+            import torch
+
+            def test_sum(a, b):
+                return a + b
+
+            # Annotate the arguments to be int
+            scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)])
+
+            print(type(scripted_fn))  # torch.jit.ScriptFunction
+
+            # See the compiled graph as Python code
+            print(scripted_fn.code)
+
+            # Call the function using the TorchScript interpreter
+            scripted_fn(20, 100)
+
+        .. testoutput::
+            :hide:
+
+            ...
+
     **Scripting an nn.Module**
         Scripting an ``nn.Module`` by default will compile the ``forward`` method and recursively
         compile any methods, submodules, and functions called by ``forward``. If a ``nn.Module`` only uses
@@ -1210,7 +1190,30 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None):
             scripted_module = torch.jit.script(MyModule())
             print(scripted_module.some_entry_point(torch.randn(2, 2)))
             print(scripted_module(torch.randn(2, 2)))
+
+        Example ( Annotating forward of nn.Module using example_inputs)::
+
+            import torch
+            import torch.nn as nn
+            from typing import NamedTuple
+
+            class MyModule(NamedTuple):
+            result: List[int]
+
+            class TestNNModule(torch.nn.Module):
+                def forward(self, a) -> MyModule:
+                    result = MyModule(result=a)
+                    return result
+
+            pdt_model = TestNNModule()
+
+            # Runs the pdt_model in eager model with the inputs provided and annotates the arguments of forward
+            scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })
+
+            # Run the scripted_model with actual inputs
+            print(scripted_model([20]))
     """
+    global type_trace_db
     if not _enabled:
         return obj
 
@@ -1227,6 +1230,35 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None):
     if isinstance(obj, ScriptFunction):
         return obj
 
+    if example_inputs:
+        # If MonkeyType is installed, enable profile directed type annotation
+        # Check if example_inputs are defined and generate call traces
+        # for the method by running eager mode version of the method with
+        # the provide example inputs. This logs all the traces in type_trace_db
+        type_trace_db = JitTypeTraceStore()
+        if monkeytype_trace:
+            monkeytype_config = JitTypeTraceConfig(type_trace_db)
+            with monkeytype_trace(monkeytype_config):
+                if isinstance(example_inputs, Dict):
+                    # If the obj is an nn.Module or a class, then each method is
+                    # executed with the arguments provided in the example inputs.
+                    # example inputs here will be of type Dict(class.method, (arguments))
+                    # This is used to infer type annotations for those methods
+                    # which are not called directly under the hood of monkeytype.
+                    for module, example_input in example_inputs.items():
+                        for example in example_input:
+                            module(*example)
+                elif isinstance(example_inputs, List):
+                    for examples in example_inputs:
+                        obj(*examples)
+                else:
+                    raise ValueError("Error: Unable to infer types. Please format the inputs to type `List[Tuple]`"
+                                     " or `Dict[Callable, List[Tuple]]` to be run with MonkeyType.")
+        else:
+            warnings.warn("Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType "
+                          "to enable Profile-Directed Typing in TorchScript. Refer to "
+                          "https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. ")
+
     if isinstance(obj, torch.nn.Module):
         obj = call_prepare_scriptable_func(obj)
         return torch.jit._recursive.create_script_module(