Fix TRTOperatorSupport (#64873)
authorShirong Wu <shirong@fb.com>
Mon, 13 Sep 2021 22:53:20 +0000 (15:53 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 13 Sep 2021 22:55:15 +0000 (15:55 -0700)
commit3710edc86b6a68a5b77c977f33c8584325accf1f
tree2323c8971dbab8ae087da2a0467aa01e8aba7bac
parent914e3a861ac441b17b3f143c3076fa7a424c4738
Fix TRTOperatorSupport (#64873)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64873

Fix TRTOperatorSupport's key naming to match the key generated by torch.fx.passes.tools_common.get_node_target. The get_node_target is used by splitter_base for comparing whether operator is supported by name.

Test Plan:
print out the supported operator dict and check name.
Run TRTSplitter with lrm_split_model_generator and verify split result is correct with all supported operators printed.
current split result:
````
Supported node types in the model:
acc_ops.size: ((), {'input': torch.float32})
acc_ops.getitem: ((), {'input': torch.float32})
acc_ops.getitem: ((), {'input': None})
acc_ops.reshape: ((), {'input': torch.float32})
acc_ops.unsqueeze: ((), {'input': torch.float32})
acc_ops.linear: ((), {'input': torch.float32, 'weight': torch.float32})
acc_ops.linear: ((), {'input': torch.float32, 'weight': torch.float32, 'bias': torch.float32})
acc_ops.mul: ((), {'input': torch.float32, 'other': torch.float32})
acc_ops.cat: ((), {})
acc_ops.add: ((), {'input': torch.float32, 'other': torch.float32})
acc_ops.add: ((), {'input': torch.float32})
acc_ops.tanh: ((), {'input': torch.float32})
acc_ops.transpose: ((), {'input': torch.float32})
acc_ops.matmul: ((), {'input': torch.float32, 'other': torch.float32})
acc_ops.div: ((), {'input': torch.float32, 'other': torch.float32})
acc_ops.squeeze: ((), {'input': torch.float32})
acc_ops.noop: ((), {'input': torch.float32})
acc_ops.layer_norm: ((), {'input': torch.float32, 'weight': torch.float32, 'bias': torch.float32})
acc_ops.permute: ((), {'input': torch.float32})
acc_ops.sigmoid: ((), {'input': torch.float32})
acc_ops.flatten: ((), {'input': torch.float32})
acc_ops.softmax: ((), {'input': torch.float32})
acc_ops.sum: ((), {'input': torch.float32})

Unsupported node types in the model:
torch.ops.fb.pad_sequence_embeddings: ((), {'embeddings': torch.float32, 'offsets': torch.int32})
acc_ops.linalg_norm: ((), {'input': torch
```

Reviewed By: yinghai

Differential Revision: D30884463

fbshipit-source-id: 22442aa6a69cd148ce9bc8be8f62157dd6d19954
torch/fx/experimental/fx2trt/tools/trt_splitter.py