From aa017db59c712778089cdda05ee54c1adc7a3777 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Tue, 9 Apr 2019 11:53:23 -0700 Subject: [PATCH] make test_jit_fuser runnable Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19036 Differential Revision: D14839800 Pulled By: wanchaol fbshipit-source-id: b52c131b58e1b42a8c3da5d1117217c3dc2e5f5b --- test/test_jit_fuser.py | 15 +++++++++++---- test/test_namedtuple_return_api.py | 1 - torch/csrc/jit/passes/shape_analysis.cpp | 2 +- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py index c435f8f..9f45947 100644 --- a/test/test_jit_fuser.py +++ b/test/test_jit_fuser.py @@ -5,13 +5,17 @@ from __future__ import unicode_literals import unittest import torch +import torch.nn as nn +import torch.nn.functional as F from torch import Tensor +from torch.testing import FileCheck -from common_utils import IS_WINDOWS, \ - skipIfRocm, IS_SANDCASTLE +from common_utils import run_tests, IS_WINDOWS, skipIfRocm, IS_SANDCASTLE +from textwrap import dedent +from itertools import product, permutations from test_jit import JitTestCase, enable_cpu_fuser, RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, \ - backward_graph + backward_graph, get_lstm_inputs, get_milstm_inputs, LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell class TestFuser(JitTestCase): @@ -663,7 +667,7 @@ class TestFuser(JitTestCase): ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) return ingate * forgetgate * cellgate * outgate ''') - for permutation in itertools.permutations(choices, len(choices)): + for permutation in permutations(choices, len(choices)): code = template.format(*permutation) scope = {} exec(code, globals(), scope) @@ -876,3 +880,6 @@ class TestFuser(JitTestCase): ge = self.checkScript(scaleshift, inputs) self.assertGraphContainsExactly( ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True) + +if __name__ == '__main__': + run_tests() diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py index 65572d3..b547176 100644 --- a/test/test_namedtuple_return_api.py +++ b/test/test_namedtuple_return_api.py @@ -5,7 +5,6 @@ import unittest import textwrap import torch from collections import namedtuple -import itertools path = os.path.dirname(os.path.realpath(__file__)) diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 75ff7fa..b1495d0 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -778,7 +778,7 @@ class ShapePropagator { [this](Node* node) -> type_vec_t { if (auto maybe_tensor_types = gatherTensorTypes(node)) { - AT_ASSERT(maybe_tensor_types->size() == 2); + AT_ASSERT(maybe_tensor_types->size() >= 2); auto first_scalar_type = (*maybe_tensor_types)[0]->scalarType(); auto second_scalar_type = (*maybe_tensor_types)[1]->scalarType(); size_t arg_for_type = 0; -- 2.7.4