make test_jit_fuser runnable
authorWanchao Liang <wanchaol@users.noreply.github.com>
Tue, 9 Apr 2019 18:53:23 +0000 (11:53 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 9 Apr 2019 19:36:25 +0000 (12:36 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19036

Differential Revision: D14839800

Pulled By: wanchaol

fbshipit-source-id: b52c131b58e1b42a8c3da5d1117217c3dc2e5f5b

test/test_jit_fuser.py
test/test_namedtuple_return_api.py
torch/csrc/jit/passes/shape_analysis.cpp

index c435f8f..9f45947 100644 (file)
@@ -5,13 +5,17 @@ from __future__ import unicode_literals
 
 import unittest
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 from torch import Tensor
+from torch.testing import FileCheck
 
-from common_utils import IS_WINDOWS, \
-    skipIfRocm, IS_SANDCASTLE
+from common_utils import run_tests, IS_WINDOWS, skipIfRocm, IS_SANDCASTLE
+from textwrap import dedent
+from itertools import product, permutations
 
 from test_jit import JitTestCase, enable_cpu_fuser, RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, \
-    backward_graph
+    backward_graph, get_lstm_inputs, get_milstm_inputs, LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell
 
 
 class TestFuser(JitTestCase):
@@ -663,7 +667,7 @@ class TestFuser(JitTestCase):
             ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
             return ingate * forgetgate * cellgate * outgate
         ''')
-        for permutation in itertools.permutations(choices, len(choices)):
+        for permutation in permutations(choices, len(choices)):
             code = template.format(*permutation)
             scope = {}
             exec(code, globals(), scope)
@@ -876,3 +880,6 @@ class TestFuser(JitTestCase):
         ge = self.checkScript(scaleshift, inputs)
         self.assertGraphContainsExactly(
             ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)
+
+if __name__ == '__main__':
+    run_tests()
index 65572d3..b547176 100644 (file)
@@ -5,7 +5,6 @@ import unittest
 import textwrap
 import torch
 from collections import namedtuple
-import itertools
 
 
 path = os.path.dirname(os.path.realpath(__file__))
index 75ff7fa..b1495d0 100644 (file)
@@ -778,7 +778,7 @@ class ShapePropagator {
         [this](Node* node) -> type_vec_t {
           if (auto maybe_tensor_types =
                   gatherTensorTypes<DimensionedTensorType>(node)) {
-            AT_ASSERT(maybe_tensor_types->size() == 2);
+            AT_ASSERT(maybe_tensor_types->size() >= 2);
             auto first_scalar_type = (*maybe_tensor_types)[0]->scalarType();
             auto second_scalar_type = (*maybe_tensor_types)[1]->scalarType();
             size_t arg_for_type = 0;