from torch._six import inf, PY2, builtins
from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
- freeze_rng_state, set_rng_seed
+ freeze_rng_state, set_rng_seed, slowTest
from common_nn import module_tests, new_module_tests, criterion_tests
from textwrap import dedent
from functools import wraps
w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c)
self.assertEqual(ys, ybs.examples())
+ @slowTest
def test_greedy_search(self):
def greedy(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
b_i, b_f, b_o, b_c, w_hs, b_s, iter_num):
w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num_batch)
self.assertEqual(ys, ybs.examples())
+ @slowTest
def test_beam_search(self):
def beam(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
b_i, b_f, b_o, b_c, w_hs, b_s, iter_num, idx):
self.assertEqual(m_orig.doit3(input), m_import.doit3(input))
self.assertEqual(m_orig.forward(input), m_import.forward(input))
+ @slowTest
@skipIfNoTorchVision
def test_script_module_trace_resnet18(self):
x = torch.ones(1, 3, 224, 224)
self.assertEqual(output_orig, output_import)
self.assertEqual(grad_orig, grad_import)
+ @slowTest
@skipIfNoTorchVision
def test_script_module_script_resnet(self):
def conv1x1(in_planes, out_planes, stride=1):