Add permute021 fx2trt converter (#63238)
authorShiyan Deng <dsy842974287@fb.com>
Thu, 19 Aug 2021 17:16:26 +0000 (10:16 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 19 Aug 2021 17:17:48 +0000 (10:17 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63238

Reviewed By: yinghai

Differential Revision: D30295373

fbshipit-source-id: 2a189fe485edaa978fd03e4b8d8582edb34ec648

torch/fx/experimental/fx2trt/fx2trt.py

index 9879fd7e0952e89d0378dfb919241b3f54e837f8..0e7cc24c18be5c1514302b72c32fc62a5002b133 100644 (file)
@@ -302,8 +302,9 @@ class TRTInterpreter(torch.fx.Interpreter):
                 missing_converter.add(f"{node.op} {node.target}")
             elif node.op == "call_module":
                 submod = self.fetch_attr(node.target)
-                if not CONVERTERS.get(type(submod)):
-                    missing_converter.add(f"{node.op} {type(submod)}")
+                submod_type = getattr(submod, "_base_class_origin", type(submod))
+                if not CONVERTERS.get(submod_type):
+                    missing_converter.add(f"{node.op} {submod_type}")
 
         return missing_converter
 
@@ -373,12 +374,11 @@ class TRTInterpreter(torch.fx.Interpreter):
     def call_module(self, target, args, kwargs):
         assert isinstance(target, str)
         submod = self.fetch_attr(target)
-        converter = CONVERTERS.get(type(submod))
+        submod_type = getattr(submod, "_base_class_origin", type(submod))
+        converter = CONVERTERS.get(submod_type)
 
         if not converter:
-            raise RuntimeError(
-                f"Conversion of module of type {type(submod)} not currently supported!"
-            )
+            raise RuntimeError(f'Conversion of module of type {submod_type} not currently supported!')
 
         return converter(self.network, submod, args, kwargs, self._cur_node_name)