Inference LSTM integration test (#18559)
author Ahmed Aly <ahhegazy@fb.com>
Thu, 28 Mar 2019 18:23:22 +0000 (11:23 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 28 Mar 2019 18:31:06 +0000 (11:31 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18559

Adding an integration test for the InferenceLSTM operator.
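
The test checks that the Caffe2 InferenceLSTM operator gives the same results whether it is run through the Caffe2 workspace or through its torch custom-op binding, torch.ops._caffe2.InferenceLSTM. As a minimal sketch of the torch-side call, mirroring the test in the diff below (the sizes and flag values here are illustrative assumptions, not taken from the diff, and a build with the Caffe2 operators registered is assumed):

    import torch

    # Hypothetical sizes, for illustration only.
    bsz, seq_len, emb_len, hidden_size, num_layers = 2, 4, 8, 5, 1
    lstm = torch.nn.LSTM(emb_len, hidden_size, num_layers=num_layers)

    inputs = torch.randn(seq_len, bsz, emb_len)     # (seq, batch, emb): batch_first=False
    hx = torch.zeros(num_layers, bsz, hidden_size)  # num_directions == 1 (unidirectional)

    # Operator inputs: [input, hidden state, cell state] + the LSTM's flat weights.
    lstm_in = [inputs, hx, hx] + [p.detach() for p in lstm._flat_weights]
    output, hidden, cell = torch.ops._caffe2.InferenceLSTM(
        lstm_in,
        num_layers,
        True,   # has_biases
        False,  # batch_first
        False,  # bidirectional
    )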

Reviewed By: houseroad

Differential Revision: D14656698

fbshipit-source-id: 80fb2a72be30fcb695f4471b72bf9d6e3965bf81

caffe2/python/operator_test/torch_integration_test.py

index 2d9aeaf..c9aa64d 100644
@@ -153,6 +153,92 @@ class TorchIntegration(hu.HypothesisTestCase):
         torch.testing.assert_allclose(rois, a)
         torch.testing.assert_allclose(rois_probs, b)
 
+    @given(
+        bsz=st.integers(1, 5),
+        seq_lens=st.integers(1, 6),
+        emb_lens=st.integers(5, 10),
+        hidden_size=st.integers(3, 7),
+        num_layers=st.integers(1, 4),
+        has_biases=st.booleans(),
+        is_bidirectional=st.booleans(),
+        batch_first=st.booleans(),
+    )
+    def test_inference_lstm(
+        self,
+        bsz,
+        seq_lens,
+        emb_lens,
+        hidden_size,
+        num_layers,
+        has_biases,
+        is_bidirectional,
+        batch_first,
+    ):
+        num_directions = 2 if is_bidirectional else 1
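+        # Zero initial hidden and cell states, shaped
+        # (num_layers * num_directions, batch, hidden_size).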
+        hx = np.zeros((num_layers * num_directions, bsz, hidden_size), dtype=np.float32)
+
+        if batch_first:
+            inputs = np.random.randn(bsz, seq_lens, emb_lens).astype(np.float32)
+        else:
+            inputs = np.random.randn(seq_lens, bsz, emb_lens).astype(np.float32)
+
+        torch_lstm = torch.nn.LSTM(
+            emb_lens,
+            hidden_size,
+            batch_first=batch_first,
+            bidirectional=is_bidirectional,
+            bias=has_biases,
+            num_layers=num_layers,
+        )
+
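+        # Reference path: run the Caffe2 InferenceLSTM operator directly
+        # through the workspace, feeding the torch LSTM's flat weights.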
+        def inference_lstm_ref():
+            input_names = ["inputs", "hidden_0", "hidden_1"]
+            workspace.FeedBlob("inputs", inputs)
+            workspace.FeedBlob("hidden_0", hx)
+            workspace.FeedBlob("hidden_1", hx)
+            for i, param in enumerate(torch_lstm._flat_weights):
+                input_names.append("param_{}".format(i))
+                workspace.FeedBlob("param_{}".format(i), param.detach().numpy())
+
+            ref_op = core.CreateOperator(
+                "InferenceLSTM",
+                input_names,
+                ["output", "hidden", "cell"],
+                num_layers=num_layers,
+                has_biases=has_biases,
+                batch_first=batch_first,
+                bidirectional=is_bidirectional,
+            )
+            workspace.RunOperatorOnce(ref_op)
+            return (
+                workspace.FetchBlob("output"),
+                workspace.FetchBlob("hidden"),
+                workspace.FetchBlob("cell")
+            )
+
+        output, hidden, cell = inference_lstm_ref()
+        output = torch.tensor(output)
+        hidden = torch.tensor(hidden)
+        cell = torch.tensor(cell)
+        lstm_in = [
+            torch.from_numpy(inputs),
+            torch.from_numpy(hx),
+            torch.from_numpy(hx),
+        ] + [param.detach() for param in torch_lstm._flat_weights]
+
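+        # Same operator through its torch custom-op binding; the two paths
+        # must agree on output, hidden state, and cell state.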
+        a, b, c = torch.ops._caffe2.InferenceLSTM(
+            lstm_in, num_layers, has_biases, batch_first, is_bidirectional
+        )
+        torch.testing.assert_allclose(output, a)
+        torch.testing.assert_allclose(hidden, b)
+        torch.testing.assert_allclose(cell, c)
+
     # Test case is using workspace.has_cuda_support and not workspace.has_gpu_support
     # to exclude it from HIP because tensor interop doesn't work for HIP tensors yet
     @unittest.skipIf(not workspace.has_cuda_support, "No cuda support")