import unittest
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.testing import FileCheck
from common_utils import run_tests, IS_WINDOWS, skipIfRocm, IS_SANDCASTLE
from textwrap import dedent
from itertools import product, permutations
from test_jit import JitTestCase, enable_cpu_fuser, RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, \
    backward_graph, get_lstm_inputs, get_milstm_inputs, LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell
class TestFuser(JitTestCase):
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
return ingate * forgetgate * cellgate * outgate
''')
for permutation in permutations(choices, len(choices)):
code = template.format(*permutation)
scope = {}
exec(code, globals(), scope)
ge = self.checkScript(scaleshift, inputs)
self.assertGraphContainsExactly(
ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)

if __name__ == '__main__':
    run_tests()
[this](Node* node) -> type_vec_t {
if (auto maybe_tensor_types =
gatherTensorTypes<DimensionedTensorType>(node)) {
AT_ASSERT(maybe_tensor_types->size() >= 2);
auto first_scalar_type = (*maybe_tensor_types)[0]->scalarType();
auto second_scalar_type = (*maybe_tensor_types)[1]->scalarType();
size_t arg_for_type = 0;