[fx2trt] Check input device in TRTModule (#63893)
author Shiyan Deng <dsy842974287@fb.com>
Wed, 25 Aug 2021 17:22:17 +0000 (10:22 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 25 Aug 2021 17:25:34 +0000 (10:25 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63893

Add a check to ensure all the inputs are on a CUDA device.

Test Plan: CI

Reviewed By: kflu, houseroad

Differential Revision: D30525265

fbshipit-source-id: 6e50b70fd535defc1f802d51e8bb991b2dd73741

torch/fx/experimental/fx2trt/fx2trt.py

index 0e7cc24..ede99fd 100644
@@ -86,6 +86,7 @@ class TRTModule(torch.nn.Module):
         bindings: List[Any] = [None] * (len(self.input_names) + len(self.output_names))
 
         for i, input_name in enumerate(self.input_names):
+            assert inputs[i].is_cuda, f"{i}th input is not on cuda device."
             idx = self.engine.get_binding_index(input_name)
             bindings[idx] = contiguous_inputs[i].data_ptr()
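
For context, here is a minimal standalone sketch (not part of this commit) of the behavior the new assert enforces. check_inputs_on_cuda is a hypothetical helper that mirrors the added line; callers are expected to move tensors to a CUDA device (e.g. via .cuda()) before invoking TRTModule.forward:

import torch

def check_inputs_on_cuda(inputs):
    # Hypothetical helper mirroring the assert added in this commit:
    # every positional input to TRTModule.forward must already be on a CUDA device.
    for i, inp in enumerate(inputs):
        assert inp.is_cuda, f"{i}th input is not on cuda device."

# A CPU tensor would trip the assert; moving it with .cuda() satisfies it.
x = torch.randn(1, 3, 224, 224)
if torch.cuda.is_available():
    check_inputs_on_cuda([x.cuda()])   # passes
# check_inputs_on_cuda([x])            # raises AssertionError on the 0th input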