Move fast rnn benchmark to pytorch/pytorch
author Wanchao Liang <wanchaol@users.noreply.github.com>
Wed, 27 Mar 2019 21:39:33 +0000 (14:39 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 27 Mar 2019 21:46:09 +0000 (14:46 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18369

Differential Revision: D14652039

Pulled By: wanchaol

fbshipit-source-id: 1177b1f60d96672c3e2c9d527b56ee06ca7c0af1

benchmarks/README.md [new file with mode: 0644]
benchmarks/fastrnns/README.md [new file with mode: 0644]
benchmarks/fastrnns/__init__.py [new file with mode: 0644]
benchmarks/fastrnns/bench.py [new file with mode: 0644]
benchmarks/fastrnns/cells.py [new file with mode: 0644]
benchmarks/fastrnns/custom_lstms.py [new file with mode: 0644]
benchmarks/fastrnns/factory.py [new file with mode: 0644]
benchmarks/fastrnns/profile.py [new file with mode: 0644]
benchmarks/fastrnns/runner.py [new file with mode: 0644]
benchmarks/fastrnns/scratch.py [new file with mode: 0644]
benchmarks/fastrnns/test.py [new file with mode: 0644]

diff --git a/benchmarks/README.md b/benchmarks/README.md
new file mode 100644 (file)
index 0000000..1477f9e
--- /dev/null
@@ -0,0 +1,29 @@
+# PyTorch Benchmarks
+
+NOTE: This folder is currently a work in progress.
+
+This folder contains scripts that produce reproducible timings of various PyTorch features.
+
+It also provides mechanisms to compare PyTorch with other frameworks.
+
+## Setup environment
+Make sure you're on a machine with CUDA available and a PyTorch source checkout (referred to as `$PYTORCH_HOME` below). Install in the following order:
+```
+# Install torchvision. It comes with the pytorch stable release binary
+conda install pytorch torchvision -c pytorch
+
+# Install the latest pytorch master from source.
+# It should supersede the installation from the release binary.
+cd $PYTORCH_HOME
+python setup.py build develop
+
+# Check the pytorch installation version
+python -c "import torch; print(torch.__version__)"
+```
+
+## Benchmark List
+
+Please refer to each subfolder for the individual benchmark suites:
+
+* [Fast RNNs benchmarks](fastrnns/README.md)
+
diff --git a/benchmarks/fastrnns/README.md b/benchmarks/fastrnns/README.md
new file mode 100644 (file)
index 0000000..87f93fa
--- /dev/null
@@ -0,0 +1,42 @@
+# Fast RNN benchmarks
+
+Benchmarks for TorchScript (JIT) RNN implementations, compared against cuDNN and ATen baselines.
+
+For the most stable results, do the following (example commands below):
+- Set the CPU governor to performance mode (as opposed to energy saving)
+- Turn off turbo for all CPUs (assuming Intel CPUs)
+- Shield CPUs via `cset shield` when running benchmarks
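+
+A sketch of what this can look like on a typical Linux box with Intel CPUs
+(tool names and sysfs paths vary by distro and kernel, so treat these commands
+as an example rather than a recipe):
+
+```
+# Set the performance governor on every CPU (requires cpupower / linux-tools)
+sudo cpupower frequency-set --governor performance
+
+# Disable turbo boost (intel_pstate driver)
+echo 1 | sudo tee /sys/devices/system/cpu/intel_pstate/no_turbo
+
+# Shield a few cores and run the benchmark on them (requires cpuset / cset)
+sudo cset shield --cpu 2-5 --kthread=on
+sudo cset shield --exec python -- -m fastrnns.bench
+```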
+
+Some of these scripts accept command line args, but most of them do not yet;
+more flags will probably be added in the future. The default sizes are pretty
+reasonable.
+
+## Test fastrnns (fwd + bwd) correctness
+
+Test the fastrnns benchmarking scripts with the following:
+`python -m fastrnns.test`
+or test a specific RNN implementation on its own:
+`python -m fastrnns.test --rnns jit`
+
+## Run benchmarks
+
+`python -m fastrnns.bench`
+
+should give a good comparison, or you can specify which models to run:
+
+`python -m fastrnns.bench --rnns cudnn aten jit --group rnns`
+
+## Run model profiling (calls nvprof)
+
+`python -m fastrnns.profile`
+
+should generate an nvprof file for each model under the output directory
+(`~/profout/` by default). You can also profile a subset of the models:
+
+`python -m fastrnns.profile --rnns aten jit`
+
+### Caveats
+
+Use Linux for the most accurate timing. A lot of these tests only run
+on CUDA.
+
diff --git a/benchmarks/fastrnns/__init__.py b/benchmarks/fastrnns/__init__.py
new file mode 100644 (file)
index 0000000..f32d4a0
--- /dev/null
@@ -0,0 +1,10 @@
+from .cells import *
+from .factory import *
+from .test import *
+
+# (output, next_state) = cell(input, state)
+seqLength = 100
+numLayers = 2
+inputSize = 512
+hiddenSize = 512
+miniBatch = 64
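+
+# A sketch of the cell calling convention above, using lstm_cell from .cells and
+# the default sizes defined in this module (weight shapes follow the
+# (4 * hiddenSize, ...) layout used by torch.nn.LSTM):
+#
+#   x = torch.randn(miniBatch, inputSize)
+#   hx = torch.randn(miniBatch, hiddenSize)
+#   cx = torch.randn(miniBatch, hiddenSize)
+#   w_ih = torch.randn(4 * hiddenSize, inputSize)
+#   w_hh = torch.randn(4 * hiddenSize, hiddenSize)
+#   b_ih = torch.randn(4 * hiddenSize)
+#   b_hh = torch.randn(4 * hiddenSize)
+#   hy, cy = lstm_cell(x, (hx, cx), w_ih, w_hh, b_ih, b_hh)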
diff --git a/benchmarks/fastrnns/bench.py b/benchmarks/fastrnns/bench.py
new file mode 100644 (file)
index 0000000..71cad4a
--- /dev/null
@@ -0,0 +1,201 @@
+from __future__ import print_function
+import argparse
+from collections import namedtuple
+import torch
+import gc
+import sys
+import json
+import copy
+
+from .runner import get_nn_runners
+
+
+BenchResult = namedtuple('BenchResult', [
+    'name', 'avg_fwd', 'std_fwd', 'avg_bwd', 'std_bwd',
+])
+
+
+def fit_str(string, colwidth=16):
+    if len(string) < colwidth:
+        return (colwidth - len(string)) * ' ' + string
+    else:
+        return string[:colwidth]
+
+
+def to_str(item):
+    if isinstance(item, float):
+        return '%.4g' % item
+    return str(item)
+
+
+def print_header(colwidth=16, sep=' '):
+    items = []
+    for item in BenchResult._fields:
+        items.append(fit_str(item))
+    return sep.join(items)
+
+
+def pretty_print(benchresult, colwidth=16, sep=' '):
+    items = []
+    for thing in benchresult:
+        items.append(fit_str(to_str(thing)))
+    return sep.join(items)
+
+
+def trainbench(name, rnn_creator, nloops=100, warmup=10,
+               seqLength=100, numLayers=1, inputSize=512, hiddenSize=512,
+               miniBatch=64, device='cuda', seed=None):
+    def train_batch(modeldef):
+        # CUDA events for timing
+        fwd_start_event = torch.cuda.Event(enable_timing=True)
+        fwd_end_event = torch.cuda.Event(enable_timing=True)
+        bwd_start_event = torch.cuda.Event(enable_timing=True)
+        bwd_end_event = torch.cuda.Event(enable_timing=True)
+
+        gc.collect()
+
+        fwd_start_event.record()
+        forward_output = modeldef.forward(*modeldef.inputs)
+        fwd_end_event.record()
+
+        # XXX: Uncomment the line below if you need to inspect the optimized graph
+        # print(modeldef.forward.graph_for(*modeldef.inputs))
+
+        if modeldef.backward_setup is not None:
+            backward_input = modeldef.backward_setup(forward_output)
+        else:
+            backward_input = forward_output
+
+        gc.collect()
+
+        bwd_start_event.record()
+        if modeldef.backward is not None:
+            modeldef.backward(*backward_input)
+        bwd_end_event.record()
+
+        if modeldef.backward is not None:
+            for param in modeldef.params:
+                assert param.grad is not None
+                param.grad.data.zero_()
+
+        torch.cuda.synchronize()
+
+        fwd_time = fwd_start_event.elapsed_time(fwd_end_event)
+        bwd_time = bwd_start_event.elapsed_time(bwd_end_event)
+        return fwd_time, bwd_time
+
+    assert device == 'cuda'
+    creator_args = dict(seqLength=seqLength, numLayers=numLayers,
+                        inputSize=inputSize, hiddenSize=hiddenSize,
+                        miniBatch=miniBatch, device=device, seed=seed)
+    modeldef = rnn_creator(**creator_args)
+
+    [train_batch(modeldef) for _ in range(warmup)]
+
+    results = [train_batch(modeldef) for _ in range(nloops)]
+    fwd_times, bwd_times = zip(*results)
+
+    fwd_times = torch.tensor(fwd_times)
+    bwd_times = torch.tensor(bwd_times)
+
+    return BenchResult(name=name,
+                       avg_fwd=fwd_times.mean().item(),
+                       std_fwd=fwd_times.std().item(),
+                       avg_bwd=bwd_times.mean().item(),
+                       std_bwd=bwd_times.std().item())
+
+
+def print_stderr(*args, **kwargs):
+    kwargs['file'] = sys.stderr
+    return print(*args, **kwargs)
+
+
+def bench(rnn_runners, group_name, print_json=False, sep=' ', **params):
+    print_stderr(print_header(sep=sep))
+    results = {}
+    for name, creator, context in rnn_runners:
+        with context():
+            try:
+                result = trainbench(name, creator, **params)
+                print_stderr(pretty_print(result, sep=sep))
+                results[name] = result
+            except Exception:
+                if not print_json:
+                    raise
+
+    return {
+        group_name: {k: v.avg_fwd for k, v in results.items()},
+        group_name + '-backward': {k: v.avg_bwd for k, v in results.items()},
+    }
+
+
+def bench_group(model_list, bench_name, bench_group, bench_args):
+    print_stderr('Benchmarking {}s...'.format(bench_name))
+    nn_results = bench(get_nn_runners(*model_list), bench_group, **bench_args)
+    print_stderr('')
+    return nn_results
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Profile RNNs')
+
+    # groups help control which test group you want to run
+    # if you only want to run one or two benchmarks, run them with e.g.:
+    # python -m fastrnns.bench --rnns jit --group rnns
+    default_groups = ['cnns', 'rnns']
+
+    parser.add_argument('--seqLength', default='100', type=int)
+    parser.add_argument('--numLayers', default='1', type=int)
+    parser.add_argument('--inputSize', default='512', type=int)
+    parser.add_argument('--hiddenSize', default='512', type=int)
+    parser.add_argument('--miniBatch', default='64', type=int)
+    parser.add_argument('--warmup', default='10', type=int)
+    parser.add_argument('--nloops', default='100', type=int)
+    parser.add_argument('--device', default='cuda', type=str)
+    parser.add_argument('--variable_lstms', action='store_true',
+                        help='Also benchmark variable sequence length lstms. '
+                        'Note that some of these run really slowly, '
+                        'and that the `seqLength` flag will be ignored.')
+    parser.add_argument('--sep', default=' ', type=str)
+    parser.add_argument('--print-json', action='store_true')
+    parser.add_argument('--rnns', nargs='*',
+                        help='What to run. cudnn, aten, jit, etc')
+    parser.add_argument('--cnns', nargs='*',
+                        help='What to run. resnet18, resnet18_jit, resnet50, etc')
+    parser.add_argument('--group', nargs='*', default=default_groups, help='Which group to run. cnns, rnns, etc.')
+
+    args = parser.parse_args()
+    rnns = args.rnns or ['cudnn', 'aten', 'jit', 'jit_premul', 'jit_simple',
+                         'jit_multilayer', 'py']
+    cnns = args.cnns or ['resnet18', 'resnet18_jit', 'resnet50', 'resnet50_jit']
+    # TODO: Maybe add a separate section for the layernorm/dropout lstms
+    # 'cudnn_layernorm', 'jit_layernorm', 'jit_layernorm_decom',
+    # 'jit', 'jit_dropout', 'cudnn_dropout'
+    vlrnns = ['vl_cudnn', 'vl_jit', 'vl_py']
+
+    if args.print_json:
+        print_stderr = lambda *args, **kwargs: None    # noqa
+    print_stderr(args)
+
+    bench_args = copy.deepcopy(vars(args))
+    should_bench_varlen_lstms = args.variable_lstms
+    del bench_args['group']
+    del bench_args['rnns']
+    del bench_args['cnns']
+    del bench_args['variable_lstms']
+
+    results = dict()
+    if should_bench_varlen_lstms:
+        if args.nloops + args.warmup > 30:
+            print_stderr(
+                'WARNING: some of the variable sequence length lstms are '
+                'very unoptimized and therefore take forever to run.')
+        results.update(bench_group(vlrnns, 'variable-length sequence LSTM', 'vl_lstm', bench_args))
+
+    if 'rnns' in args.group:
+        results.update(bench_group(rnns, 'LSTM', 'lstm', bench_args))
+    if 'cnns' in args.group:
+        results.update(bench_group(cnns, 'ResNet', 'resnet', bench_args))
+
+    if args.print_json:
+        print(json.dumps(results))
diff --git a/benchmarks/fastrnns/cells.py b/benchmarks/fastrnns/cells.py
new file mode 100644 (file)
index 0000000..7c80c49
--- /dev/null
@@ -0,0 +1,101 @@
+import torch
+
+
+def milstm_cell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias):
+    Wx = x.mm(w_ih.t())
+    Uz = hx.mm(w_hh.t())
+
+    # Section 2.1 in https://arxiv.org/pdf/1606.06630.pdf
+    gates = (alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias)
+
+    # Same as LSTMCell after this point
+    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
+
+    ingate = ingate.sigmoid()
+    forgetgate = forgetgate.sigmoid()
+    cellgate = cellgate.tanh()
+    outgate = outgate.sigmoid()
+
+    cy = (forgetgate * cx) + (ingate * cellgate)
+    hy = outgate * cy.tanh()
+
+    return hy, cy
+
+
+def lstm_cell(input, hidden, w_ih, w_hh, b_ih, b_hh):
+    # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
+    hx, cx = hidden
+    gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh
+
+    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
+
+    ingate = torch.sigmoid(ingate)
+    forgetgate = torch.sigmoid(forgetgate)
+    cellgate = torch.tanh(cellgate)
+    outgate = torch.sigmoid(outgate)
+
+    cy = (forgetgate * cx) + (ingate * cellgate)
+    hy = outgate * torch.tanh(cy)
+
+    return hy, cy
+
+
+def flat_lstm_cell(input, hx, cx, w_ih, w_hh, b_ih, b_hh):
+    # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
+    gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh
+
+    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
+
+    ingate = torch.sigmoid(ingate)
+    forgetgate = torch.sigmoid(forgetgate)
+    cellgate = torch.tanh(cellgate)
+    outgate = torch.sigmoid(outgate)
+
+    cy = (forgetgate * cx) + (ingate * cellgate)
+    hy = outgate * torch.tanh(cy)
+
+    return hy, cy
+
+
+def premul_lstm_cell(igates, hidden, w_hh, b_ih, b_hh):
+    # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
+    hx, cx = hidden
+    gates = igates + torch.mm(hx, w_hh.t()) + b_ih + b_hh
+
+    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
+
+    ingate = torch.sigmoid(ingate)
+    forgetgate = torch.sigmoid(forgetgate)
+    cellgate = torch.tanh(cellgate)
+    outgate = torch.sigmoid(outgate)
+
+    cy = (forgetgate * cx) + (ingate * cellgate)
+    hy = outgate * torch.tanh(cy)
+
+    return hy, cy
+
+
+def gru_cell(input, hidden, w_ih, w_hh, b_ih, b_hh):
+    gi = torch.mm(input, w_ih.t()) + b_ih
+    gh = torch.mm(hidden, w_hh.t()) + b_hh
+    i_r, i_i, i_n = gi.chunk(3, 1)
+    h_r, h_i, h_n = gh.chunk(3, 1)
+
+    resetgate = torch.sigmoid(i_r + h_r)
+    inputgate = torch.sigmoid(i_i + h_i)
+    newgate = torch.tanh(i_n + resetgate * h_n)
+    hy = newgate + inputgate * (hidden - newgate)
+
+    return hy
+
+
+def rnn_relu_cell(input, hidden, w_ih, w_hh, b_ih, b_hh):
+    igates = torch.mm(input, w_ih.t()) + b_ih
+    hgates = torch.mm(hidden, w_hh.t()) + b_hh
+    return torch.relu(igates + hgates)
+
+
+def rnn_tanh_cell(input, hidden, w_ih, w_hh, b_ih, b_hh):
+    igates = torch.mm(input, w_ih.t()) + b_ih
+    hgates = torch.mm(hidden, w_hh.t()) + b_hh
+    return torch.tanh(igates + hgates)
diff --git a/benchmarks/fastrnns/custom_lstms.py b/benchmarks/fastrnns/custom_lstms.py
new file mode 100644 (file)
index 0000000..d835b3e
--- /dev/null
@@ -0,0 +1,461 @@
+import torch
+import torch.nn as nn
+from torch.nn import Parameter
+import torch.jit as jit
+import warnings
+from collections import namedtuple
+from typing import List, Tuple
+from torch import Tensor
+import numbers
+
+'''
+Some helper classes for writing custom TorchScript LSTMs.
+
+Goals:
+- Classes are easy to read, use, and extend
+- Performance of custom LSTMs approaches fused-kernel levels of speed.
+
+A few notes about features we could add to clean up the code below:
+- Support enumerate with nn.ModuleList:
+  https://github.com/pytorch/pytorch/issues/14471
+- Support enumerate/zip with lists:
+  https://github.com/pytorch/pytorch/issues/15952
+- Support overriding of class methods:
+  https://github.com/pytorch/pytorch/issues/10733
+- Support passing around user-defined namedtuple types for readability
+- Support slicing w/ range. It enables reversing lists easily.
+  https://github.com/pytorch/pytorch/issues/10774
+- Multiline type annotations. List[List[Tuple[Tensor,Tensor]]] is verbose
+  https://github.com/pytorch/pytorch/pull/14922
+'''
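+
+# A minimal usage sketch of script_lstm (it mirrors test_script_stacked_rnn at
+# the bottom of this file; the sizes are arbitrary):
+#
+#   rnn = script_lstm(input_size=3, hidden_size=7, num_layers=4)
+#   states = [LSTMState(torch.randn(2, 7), torch.randn(2, 7))
+#             for _ in range(4)]
+#   output, output_states = rnn(torch.randn(5, 2, 3), states)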
+
+
+def script_lstm(input_size, hidden_size, num_layers, bias=True,
+                batch_first=False, dropout=False, bidirectional=False):
+    '''Returns a ScriptModule that mimics a PyTorch native LSTM.'''
+
+    # The following are not implemented.
+    assert bias
+    assert not batch_first
+
+    if bidirectional:
+        stack_type = StackedLSTM2
+        layer_type = BidirLSTMLayer
+        dirs = 2
+    elif dropout:
+        stack_type = StackedLSTMWithDropout
+        layer_type = LSTMLayer
+        dirs = 1
+    else:
+        stack_type = StackedLSTM
+        layer_type = LSTMLayer
+        dirs = 1
+
+    return stack_type(num_layers, layer_type,
+                      first_layer_args=[LSTMCell, input_size, hidden_size],
+                      other_layer_args=[LSTMCell, hidden_size * dirs,
+                                        hidden_size])
+
+
+def script_lnlstm(input_size, hidden_size, num_layers, bias=True,
+                  batch_first=False, dropout=False, bidirectional=False,
+                  decompose_layernorm=False):
+    '''Returns a ScriptModule that mimics a PyTorch native LSTM.'''
+
+    # The following are not implemented.
+    assert bias
+    assert not batch_first
+    assert not dropout
+
+    if bidirectional:
+        stack_type = StackedLSTM2
+        layer_type = BidirLSTMLayer
+        dirs = 2
+    else:
+        stack_type = StackedLSTM
+        layer_type = LSTMLayer
+        dirs = 1
+
+    return stack_type(num_layers, layer_type,
+                      first_layer_args=[LayerNormLSTMCell, input_size, hidden_size,
+                                        decompose_layernorm],
+                      other_layer_args=[LayerNormLSTMCell, hidden_size * dirs,
+                                        hidden_size, decompose_layernorm])
+
+
+LSTMState = namedtuple('LSTMState', ['hx', 'cx'])
+
+
+def reverse(lst):
+    # type: (List[Tensor]) -> List[Tensor]
+    return lst[::-1]
+
+
+class LSTMCell(jit.ScriptModule):
+    def __init__(self, input_size, hidden_size):
+        super(LSTMCell, self).__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
+        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
+        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
+        self.bias_hh = Parameter(torch.randn(4 * hidden_size))
+
+    @jit.script_method
+    def forward(self, input, state):
+        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+        hx, cx = state
+        gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
+                 torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
+        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
+
+        ingate = torch.sigmoid(ingate)
+        forgetgate = torch.sigmoid(forgetgate)
+        cellgate = torch.tanh(cellgate)
+        outgate = torch.sigmoid(outgate)
+
+        cy = (forgetgate * cx) + (ingate * cellgate)
+        hy = outgate * torch.tanh(cy)
+
+        return hy, (hy, cy)
+
+
+class LayerNorm(jit.ScriptModule):
+    def __init__(self, normalized_shape):
+        super(LayerNorm, self).__init__()
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        normalized_shape = torch.Size(normalized_shape)
+
+        # XXX: This is true for our LSTM / NLP use case and helps simplify code
+        assert len(normalized_shape) == 1
+
+        self.weight = Parameter(torch.ones(normalized_shape))
+        self.bias = Parameter(torch.zeros(normalized_shape))
+        self.normalized_shape = normalized_shape
+
+    @jit.script_method
+    def compute_layernorm_stats(self, input):
+        mu = input.mean(-1, keepdim=True)
+        sigma = input.std(-1, keepdim=True, unbiased=False)
+        return mu, sigma
+
+    @jit.script_method
+    def forward(self, input):
+        mu, sigma = self.compute_layernorm_stats(input)
+        return (input - mu) / sigma * self.weight + self.bias
+
+
+class LayerNormLSTMCell(jit.ScriptModule):
+    def __init__(self, input_size, hidden_size, decompose_layernorm=False):
+        super(LayerNormLSTMCell, self).__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
+        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
+        # The layernorms provide learnable biases
+
+        if decompose_layernorm:
+            ln = LayerNorm
+        else:
+            ln = nn.LayerNorm
+
+        self.layernorm_i = ln(4 * hidden_size)
+        self.layernorm_h = ln(4 * hidden_size)
+        self.layernorm_c = ln(hidden_size)
+
+    @jit.script_method
+    def forward(self, input, state):
+        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+        hx, cx = state
+        igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
+        hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
+        gates = igates + hgates
+        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
+
+        ingate = torch.sigmoid(ingate)
+        forgetgate = torch.sigmoid(forgetgate)
+        cellgate = torch.tanh(cellgate)
+        outgate = torch.sigmoid(outgate)
+
+        cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate))
+        hy = outgate * torch.tanh(cy)
+
+        return hy, (hy, cy)
+
+
+class LSTMLayer(jit.ScriptModule):
+    def __init__(self, cell, *cell_args):
+        super(LSTMLayer, self).__init__()
+        self.cell = cell(*cell_args)
+
+    @jit.script_method
+    def forward(self, input, state):
+        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+        inputs = input.unbind(0)
+        outputs = torch.jit.annotate(List[Tensor], [])
+        for i in range(len(inputs)):
+            out, state = self.cell(inputs[i], state)
+            outputs += [out]
+        return torch.stack(outputs), state
+
+
+class ReverseLSTMLayer(jit.ScriptModule):
+    def __init__(self, cell, *cell_args):
+        super(ReverseLSTMLayer, self).__init__()
+        self.cell = cell(*cell_args)
+
+    @jit.script_method
+    def forward(self, input, state):
+        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+        inputs = reverse(input.unbind(0))
+        outputs = jit.annotate(List[Tensor], [])
+        for i in range(len(inputs)):
+            out, state = self.cell(inputs[i], state)
+            outputs += [out]
+        return torch.stack(reverse(outputs)), state
+
+
+class BidirLSTMLayer(jit.ScriptModule):
+    __constants__ = ['directions']
+
+    def __init__(self, cell, *cell_args):
+        super(BidirLSTMLayer, self).__init__()
+        self.directions = nn.ModuleList([
+            LSTMLayer(cell, *cell_args),
+            ReverseLSTMLayer(cell, *cell_args),
+        ])
+
+    @jit.script_method
+    def forward(self, input, states):
+        # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
+        # List[LSTMState]: [forward LSTMState, backward LSTMState]
+        outputs = jit.annotate(List[Tensor], [])
+        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
+        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
+        i = 0
+        for direction in self.directions:
+            state = states[i]
+            out, out_state = direction(input, state)
+            outputs += [out]
+            output_states += [out_state]
+            i += 1
+        return torch.cat(outputs, -1), output_states
+
+
+def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args):
+    layers = [layer(*first_layer_args)] + [layer(*other_layer_args)
+                                           for _ in range(num_layers - 1)]
+    return nn.ModuleList(layers)
+
+
+class StackedLSTM(jit.ScriptModule):
+    __constants__ = ['layers']  # Necessary for iterating through self.layers
+
+    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
+        super(StackedLSTM, self).__init__()
+        self.layers = init_stacked_lstm(num_layers, layer, first_layer_args,
+                                        other_layer_args)
+
+    @jit.script_method
+    def forward(self, input, states):
+        # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
+        # List[LSTMState]: One state per layer
+        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
+        output = input
+        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
+        i = 0
+        for rnn_layer in self.layers:
+            state = states[i]
+            output, out_state = rnn_layer(output, state)
+            output_states += [out_state]
+            i += 1
+        return output, output_states
+
+
+# Differs from StackedLSTM in that its forward method takes
+# List[List[Tuple[Tensor,Tensor]]]. It would be nice to subclass StackedLSTM
+# except we don't support overriding script methods.
+# https://github.com/pytorch/pytorch/issues/10733
+class StackedLSTM2(jit.ScriptModule):
+    __constants__ = ['layers']  # Necessary for iterating through self.layers
+
+    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
+        super(StackedLSTM2, self).__init__()
+        self.layers = init_stacked_lstm(num_layers, layer, first_layer_args,
+                                        other_layer_args)
+
+    @jit.script_method
+    def forward(self, input, states):
+        # type: (Tensor, List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]
+        # List[List[LSTMState]]: The outer list is for layers,
+        #                        inner list is for directions.
+        output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
+        output = input
+        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
+        i = 0
+        for rnn_layer in self.layers:
+            state = states[i]
+            output, out_state = rnn_layer(output, state)
+            output_states += [out_state]
+            i += 1
+        return output, output_states
+
+
+class StackedLSTMWithDropout(jit.ScriptModule):
+    # Necessary for iterating through self.layers and dropout support
+    __constants__ = ['layers', 'num_layers']
+
+    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
+        super(StackedLSTMWithDropout, self).__init__()
+        self.layers = init_stacked_lstm(num_layers, layer, first_layer_args,
+                                        other_layer_args)
+        # Introduces a Dropout layer on the outputs of each LSTM layer except
+        # the last layer, with dropout probability = 0.4.
+        self.num_layers = num_layers
+
+        if num_layers == 1:
+            warnings.warn("dropout lstm adds dropout layers after all but the "
+                          "last recurrent layer; it expects num_layers greater "
+                          "than 1, but got num_layers = 1")
+
+        self.dropout_layer = nn.Dropout(0.4)
+
+    @jit.script_method
+    def forward(self, input, states):
+        # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
+        # List[LSTMState]: One state per layer
+        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
+        output = input
+        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
+        i = 0
+        for rnn_layer in self.layers:
+            state = states[i]
+            output, out_state = rnn_layer(output, state)
+            # Apply the dropout layer on all but the last layer
+            if i < self.num_layers - 1:
+                output = self.dropout_layer(output)
+            output_states += [out_state]
+            i += 1
+        return output, output_states
+
+
+def flatten_states(states):
+    states = list(zip(*states))
+    assert len(states) == 2
+    return [torch.stack(state) for state in states]
+
+
+def double_flatten_states(states):
+    # XXX: Can probably write this in a nicer way
+    states = flatten_states([flatten_states(inner) for inner in states])
+    return [hidden.view([-1] + list(hidden.shape[2:])) for hidden in states]
+
+
+def test_script_rnn_layer(seq_len, batch, input_size, hidden_size):
+    inp = torch.randn(seq_len, batch, input_size)
+    state = LSTMState(torch.randn(batch, hidden_size),
+                      torch.randn(batch, hidden_size))
+    rnn = LSTMLayer(LSTMCell, input_size, hidden_size)
+    out, out_state = rnn(inp, state)
+
+    # Control: pytorch native LSTM
+    lstm = nn.LSTM(input_size, hidden_size, 1)
+    lstm_state = LSTMState(state.hx.unsqueeze(0), state.cx.unsqueeze(0))
+    for lstm_param, custom_param in zip(lstm.all_weights[0], rnn.parameters()):
+        assert lstm_param.shape == custom_param.shape
+        with torch.no_grad():
+            lstm_param.copy_(custom_param)
+    lstm_out, lstm_out_state = lstm(inp, lstm_state)
+
+    assert (out - lstm_out).abs().max() < 1e-5
+    assert (out_state[0] - lstm_out_state[0]).abs().max() < 1e-5
+    assert (out_state[1] - lstm_out_state[1]).abs().max() < 1e-5
+
+
+def test_script_stacked_rnn(seq_len, batch, input_size, hidden_size,
+                            num_layers):
+    inp = torch.randn(seq_len, batch, input_size)
+    states = [LSTMState(torch.randn(batch, hidden_size),
+                        torch.randn(batch, hidden_size))
+              for _ in range(num_layers)]
+    rnn = script_lstm(input_size, hidden_size, num_layers)
+    out, out_state = rnn(inp, states)
+    custom_state = flatten_states(out_state)
+
+    # Control: pytorch native LSTM
+    lstm = nn.LSTM(input_size, hidden_size, num_layers)
+    lstm_state = flatten_states(states)
+    for layer in range(num_layers):
+        custom_params = list(rnn.parameters())[4 * layer: 4 * (layer + 1)]
+        for lstm_param, custom_param in zip(lstm.all_weights[layer],
+                                            custom_params):
+            assert lstm_param.shape == custom_param.shape
+            with torch.no_grad():
+                lstm_param.copy_(custom_param)
+    lstm_out, lstm_out_state = lstm(inp, lstm_state)
+
+    assert (out - lstm_out).abs().max() < 1e-5
+    assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5
+    assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5
+
+
+def test_script_stacked_bidir_rnn(seq_len, batch, input_size, hidden_size,
+                                  num_layers):
+    inp = torch.randn(seq_len, batch, input_size)
+    states = [[LSTMState(torch.randn(batch, hidden_size),
+                         torch.randn(batch, hidden_size))
+               for _ in range(2)]
+              for _ in range(num_layers)]
+    rnn = script_lstm(input_size, hidden_size, num_layers, bidirectional=True)
+    out, out_state = rnn(inp, states)
+    custom_state = double_flatten_states(out_state)
+
+    # Control: pytorch native LSTM
+    lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True)
+    lstm_state = double_flatten_states(states)
+    for layer in range(num_layers):
+        for direct in range(2):
+            index = 2 * layer + direct
+            custom_params = list(rnn.parameters())[4 * index: 4 * index + 4]
+            for lstm_param, custom_param in zip(lstm.all_weights[index],
+                                                custom_params):
+                assert lstm_param.shape == custom_param.shape
+                with torch.no_grad():
+                    lstm_param.copy_(custom_param)
+    lstm_out, lstm_out_state = lstm(inp, lstm_state)
+
+    assert (out - lstm_out).abs().max() < 1e-5
+    assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5
+    assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5
+
+
+def test_script_stacked_lstm_dropout(seq_len, batch, input_size, hidden_size,
+                                     num_layers):
+    inp = torch.randn(seq_len, batch, input_size)
+    states = [LSTMState(torch.randn(batch, hidden_size),
+                        torch.randn(batch, hidden_size))
+              for _ in range(num_layers)]
+    rnn = script_lstm(input_size, hidden_size, num_layers, dropout=True)
+
+    # just a smoke test
+    out, out_state = rnn(inp, states)
+
+
+def test_script_stacked_lnlstm(seq_len, batch, input_size, hidden_size,
+                               num_layers):
+    inp = torch.randn(seq_len, batch, input_size)
+    states = [LSTMState(torch.randn(batch, hidden_size),
+                        torch.randn(batch, hidden_size))
+              for _ in range(num_layers)]
+    rnn = script_lnlstm(input_size, hidden_size, num_layers)
+
+    # just a smoke test
+    out, out_state = rnn(inp, states)
+
+
+test_script_rnn_layer(5, 2, 3, 7)
+test_script_stacked_rnn(5, 2, 3, 7, 4)
+test_script_stacked_bidir_rnn(5, 2, 3, 7, 4)
+test_script_stacked_lstm_dropout(5, 2, 3, 7, 4)
+test_script_stacked_lnlstm(5, 2, 3, 7, 4)
diff --git a/benchmarks/fastrnns/factory.py b/benchmarks/fastrnns/factory.py
new file mode 100644 (file)
index 0000000..90f49bc
--- /dev/null
@@ -0,0 +1,432 @@
+import torch
+
+from collections import namedtuple
+
+from .cells import lstm_cell, premul_lstm_cell, flat_lstm_cell
+
+
+# list[list[T]] -> list[T]
+def flatten_list(lst):
+    result = []
+    for inner in lst:
+        result.extend(inner)
+    return result
+
+
+'''
+Define a creator as a function:
+(options) -> (inputs, params, forward, backward_setup, backward)
+inputs: the inputs to the returned 'forward'. One can call
+    forward(*inputs) directly.
+params: List[Tensor] of all requires_grad=True parameters.
+forward: function / graph executor / module
+    One can call rnn(rnn_inputs) using the outputs of the creator.
+backward_setup: backward_inputs = backward_setup(forward_output)
+    Then, we pass backward_inputs to backward. If None, then it is assumed to
+    be the identity function.
+backward: Given `backward_inputs = backward_setup(forward(*inputs))`, calling
+    backward(*backward_inputs) performs backpropagation. If None, then nothing
+    happens.
+
+fastrnns.bench times the forward and backward invocations.
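+
+A sketch of how a ModelDef produced by a creator is consumed (this mirrors the
+timing loop in fastrnns.bench; `some_creator` below stands in for any creator
+defined in this file):
+
+    modeldef = some_creator(seqLength=100, numLayers=1, inputSize=512,
+                            hiddenSize=512, miniBatch=64, device='cuda')
+    out = modeldef.forward(*modeldef.inputs)
+    if modeldef.backward_setup is not None:
+        out = modeldef.backward_setup(out)
+    if modeldef.backward is not None:
+        modeldef.backward(*out)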
+'''
+
+
+ModelDef = namedtuple('ModelDef', [
+    'inputs', 'params', 'forward', 'backward_setup', 'backward'])
+
+
+def lstm_backward_setup(lstm_outputs, seed=None):
+    hx, _ = lstm_outputs
+    return simple_backward_setup(hx, seed)
+
+
+def simple_backward_setup(output, seed=None):
+    assert isinstance(output, torch.Tensor)
+    if seed:
+        torch.manual_seed(seed)
+    grad_output = torch.randn_like(output)
+    return output, grad_output
+
+
+def simple_backward(output, grad_output):
+    return output.backward(grad_output)
+
+
+def pytorch_lstm_creator(**kwargs):
+    input, hidden, _, module = lstm_inputs(return_module=True, **kwargs)
+    return ModelDef(
+        inputs=[input, hidden],
+        params=flatten_list(module.all_weights),
+        forward=module,
+        backward_setup=lstm_backward_setup,
+        backward=simple_backward)
+
+
+def lstm_creator(script=True, **kwargs):
+    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
+    inputs = [input, hidden] + params[0]
+    return ModelDef(
+        inputs=inputs,
+        params=flatten_list(params),
+        forward=lstm_factory(lstm_cell, script),
+        backward_setup=lstm_backward_setup,
+        backward=simple_backward)
+
+
+def lnlstm_creator(script=True, decompose_layernorm=False, **kwargs):
+    assert script is True
+    from .custom_lstms import script_lnlstm
+    input_size = kwargs['inputSize']
+    hidden_size = kwargs['hiddenSize']
+    seq_len = kwargs['seqLength']
+    batch_size = kwargs['miniBatch']
+    ge = script_lnlstm(input_size, hidden_size, 1,
+                       decompose_layernorm=decompose_layernorm).cuda()
+
+    input = torch.randn(seq_len, batch_size, input_size, device='cuda')
+    states = [(torch.randn(batch_size, hidden_size, device='cuda'),
+               torch.randn(batch_size, hidden_size, device='cuda'))]
+
+    return ModelDef(
+        inputs=[input, states],
+        params=ge.parameters(),
+        forward=ge,
+        backward_setup=lstm_backward_setup,
+        backward=simple_backward)
+
+
+def dropoutlstm_creator(script=True, **kwargs):
+    assert script is True
+    from .custom_lstms import script_lstm, LSTMState
+    input_size = kwargs['inputSize']
+    hidden_size = kwargs['hiddenSize']
+    seq_len = kwargs['seqLength']
+    batch_size = kwargs['miniBatch']
+    num_layers = kwargs['numLayers']
+    ge = script_lstm(input_size, hidden_size, num_layers, dropout=True).cuda()
+
+    input = torch.randn(seq_len, batch_size, input_size, device='cuda')
+    states = [LSTMState(torch.randn(batch_size, hidden_size, device='cuda'),
+                        torch.randn(batch_size, hidden_size, device='cuda'))
+              for _ in range(num_layers)]
+    return ModelDef(
+        inputs=[input, states],
+        params=ge.parameters(),
+        forward=ge,
+        backward_setup=lstm_backward_setup,
+        backward=simple_backward)
+
+
+def lstm_premul_creator(script=True, **kwargs):
+    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
+    inputs = [input, hidden] + params[0]
+    return ModelDef(
+        inputs=inputs,
+        params=flatten_list(params),
+        forward=lstm_factory_premul(premul_lstm_cell, script),
+        backward_setup=lstm_backward_setup,
+        backward=simple_backward)
+
+
+def lstm_simple_creator(script=True, **kwargs):
+    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
+    inputs = [input] + [h[0] for h in hidden] + params[0]
+    return ModelDef(
+        inputs=inputs,
+        params=flatten_list(params),
+        forward=lstm_factory_simple(flat_lstm_cell, script),
+        backward_setup=lstm_backward_setup,
+        backward=simple_backward)
+
+
+def lstm_multilayer_creator(script=True, **kwargs):
+    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
+    inputs = [input, hidden, flatten_list(params)]
+    return ModelDef(
+        inputs=inputs,
+        params=flatten_list(params),
+        forward=lstm_factory_multilayer(lstm_cell, script),
+        backward_setup=lstm_backward_setup,
+        backward=simple_backward)
+
+
+def imagenet_cnn_creator(arch, jit=True):
+    def creator(device='cuda', **kwargs):
+        model = arch().to(device)
+        x = torch.randn(32, 3, 224, 224, device=device)
+        if jit:
+            model = torch.jit.trace(model, x)
+        return ModelDef(
+            inputs=(x,),
+            params=list(model.parameters()),
+            forward=model,
+            backward_setup=simple_backward_setup,
+            backward=simple_backward)
+
+    return creator
+
+
+def varlen_lstm_inputs(minlen=30, maxlen=100,
+                       numLayers=1, inputSize=512, hiddenSize=512,
+                       miniBatch=64, return_module=False, device='cuda',
+                       seed=None, **kwargs):
+    if seed is not None:
+        torch.manual_seed(seed)
+    lengths = torch.randint(
+        low=minlen, high=maxlen, size=[miniBatch],
+        dtype=torch.long, device=device)
+    x = [torch.randn(length, inputSize, device=device)
+         for length in lengths]
+    hx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
+    cx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
+    lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers).to(device)
+
+    if return_module:
+        return x, lengths, (hx, cx), lstm.all_weights, lstm
+    else:
+        # NB: lstm.all_weights format:
+        # wih, whh, bih, bhh = lstm.all_weights[layer]
+        return x, lengths, (hx, cx), lstm.all_weights, None
+
+
+def varlen_lstm_backward_setup(forward_output, seed=None):
+    if seed:
+        torch.manual_seed(seed)
+    rnn_utils = torch.nn.utils.rnn
+    sequences = forward_output[0]
+    padded = rnn_utils.pad_sequence(sequences)
+    grad = torch.randn_like(padded)
+    return padded, grad
+
+
+def varlen_pytorch_lstm_creator(**kwargs):
+    rnn_utils = torch.nn.utils.rnn
+    sequences, _, hidden, _, module = varlen_lstm_inputs(
+        return_module=True, **kwargs)
+
+    def forward(sequences, hidden):
+        packed = rnn_utils.pack_sequence(sequences, enforce_sorted=False)
+        out, new_hidden = module(packed, hidden)
+        padded, lengths = rnn_utils.pad_packed_sequence(out)
+        # XXX: It's more efficient to store the output in its padded form,
+        # but that might not be conducive to loss computation.
+        # Un-padding the output also makes the backward pass 2x slower...
+        # return [padded[:lengths[i], i, :] for i in range(lengths.size(0))]
+        return padded, new_hidden
+
+    return ModelDef(
+        inputs=[sequences, hidden],
+        params=flatten_list(module.all_weights),
+        forward=forward,
+        backward_setup=lstm_backward_setup,
+        backward=simple_backward)
+
+
+def varlen_lstm_factory(cell, script):
+    def dynamic_rnn(sequences, hiddens, wih, whh, bih, bhh):
+        # type: (List[Tensor], Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]]    # noqa
+        hx, cx = hiddens
+        hxs = hx.unbind(1)
+        cxs = cx.unbind(1)
+        # List of: (output, hx, cx)
+        outputs = []
+        hx_outs = []
+        cx_outs = []
+
+        for batch in range(len(sequences)):
+            output = []
+            hy, cy = hxs[batch], cxs[batch]
+            inputs = sequences[batch].unbind(0)
+
+            for seq_idx in range(len(inputs)):
+                hy, cy = cell(
+                    inputs[seq_idx].unsqueeze(0), (hy, cy), wih, whh, bih, bhh)
+                output += [hy]
+            outputs += [torch.stack(output)]
+            hx_outs += [hy.unsqueeze(0)]
+            cx_outs += [cy.unsqueeze(0)]
+
+        return outputs, (hx_outs, cx_outs)
+
+    if script:
+        cell = torch.jit.script(cell)
+        dynamic_rnn = torch.jit.script(dynamic_rnn)
+
+    return dynamic_rnn
+
+
+def varlen_lstm_creator(script=False, **kwargs):
+    sequences, _, hidden, params, _ = varlen_lstm_inputs(
+        return_module=False, **kwargs)
+    inputs = [sequences, hidden] + params[0]
+    return ModelDef(
+        inputs=inputs,
+        params=flatten_list(params),
+        forward=varlen_lstm_factory(lstm_cell, script),
+        backward_setup=varlen_lstm_backward_setup,
+        backward=simple_backward)
+
+
+# cudnn_layernorm_lstm: since cudnn does not have a LayerNorm LSTM, we cannot
+# benchmark the lower bound directly. Instead, we only benchmark the forward pass
+# by mimicking the computation of a cudnn lstm + seq_len * 3 layernorm computations.
+# This should serve as a performance lower bound for the LayerNorm LSTM forward pass
+# (given that LayerNorm itself is invariant). The lower bound of the backward pass
+# is hard to get since we lose the intermediate results; we could still optimize the
+# layernorm implementation to make a faster forward lower bound, though.
+def layernorm_pytorch_lstm_creator(**kwargs):
+    input, hidden, _, module = lstm_inputs(return_module=True, **kwargs)
+    batch_size = kwargs['miniBatch']
+    hidden_size = kwargs['hiddenSize']
+    ln_i = torch.nn.LayerNorm(4 * hidden_size).cuda()
+    ln_h = torch.nn.LayerNorm(4 * hidden_size).cuda()
+    ln_c = torch.nn.LayerNorm(hidden_size).cuda()
+    ln_input1 = torch.randn(batch_size, 4 * hidden_size, device='cuda')
+
+    def forward(input, hidden):
+        out, new_hidden = module(input, hidden)
+        # plus (seq_len * three layernorm cell computations) to mimic the lower
+        # bound of a LayerNorm cudnn LSTM in the forward pass
+        seq_len = len(input.unbind(0))
+        hy, cy = new_hidden
+        for i in range(seq_len):
+            ln_i_output = ln_i(ln_input1)
+            ln_h_output = ln_h(ln_input1)
+            cy = ln_c(cy)
+
+        return out, (hy, cy)
+
+    return ModelDef(
+        inputs=[input, hidden],
+        params=flatten_list(module.all_weights),
+        forward=forward,
+        backward_setup=lstm_backward_setup,
+        backward=None)
+
+
+# input: lstm.all_weights format (wih, whh, bih, bhh = lstm.all_weights[layer])
+# output: packed_weights with format
+# packed_weights[0] is wih with size (layer, 4*hiddenSize, inputSize)
+# packed_weights[1] is whh with size (layer, 4*hiddenSize, hiddenSize)
+# packed_weights[2] is bih with size (layer, 4*hiddenSize)
+# packed_weights[3] is bhh with size (layer, 4*hiddenSize)
+def stack_weights(weights):
+    def unzip_columns(mat):
+        assert isinstance(mat, list)
+        assert isinstance(mat[0], list)
+        layers = len(mat)
+        columns = len(mat[0])
+        return [[mat[layer][col] for layer in range(layers)]
+                for col in range(columns)]
+
+    # XXX: script fns have problems indexing multidim lists, so we try to
+    # avoid them by stacking tensors
+    all_weights = weights
+    packed_weights = [torch.stack(param)
+                      for param in unzip_columns(all_weights)]
+    return packed_weights
+
+
+# returns: x, (hx, cx), all_weights, lstm module with all_weights as params
+def lstm_inputs(seqLength=100, numLayers=1, inputSize=512, hiddenSize=512,
+                miniBatch=64, dropout=0.0, return_module=False, device='cuda', seed=None):
+    if seed is not None:
+        torch.manual_seed(seed)
+    x = torch.randn(seqLength, miniBatch, inputSize, device=device)
+    hx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
+    cx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
+    lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers, dropout=dropout)
+    if 'cuda' in device:
+        lstm = lstm.cuda()
+
+    if return_module:
+        return x, (hx, cx), lstm.all_weights, lstm
+    else:
+        # NB: lstm.all_weights format:
+        # wih, whh, bih, bhh = lstm.all_weights[layer]
+        return x, (hx, cx), lstm.all_weights, None
+
+
+def lstm_factory(cell, script):
+    def dynamic_rnn(input, hidden, wih, whh, bih, bhh):
+        # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+        hx, cx = hidden
+        outputs = []
+        inputs = input.unbind(0)
+        hy, cy = hx[0], cx[0]
+        for seq_idx in range(len(inputs)):
+            hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh)
+            outputs += [hy]
+        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))
+
+    if script:
+        cell = torch.jit.script(cell)
+        dynamic_rnn = torch.jit.script(dynamic_rnn)
+
+    return dynamic_rnn
+
+
+# premul: we're going to premultiply the inputs & weights
+def lstm_factory_premul(premul_cell, script):
+    def dynamic_rnn(input, hidden, wih, whh, bih, bhh):
+        # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+        hx, cx = hidden
+        outputs = []
+        inputs = torch.matmul(input, wih.t()).unbind(0)
+        hy, cy = hx[0], cx[0]
+        for seq_idx in range(len(inputs)):
+            hy, cy = premul_cell(inputs[seq_idx], (hy, cy), whh, bih, bhh)
+            outputs += [hy]
+        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))
+
+    if script:
+        premul_cell = torch.jit.script(premul_cell)
+        dynamic_rnn = torch.jit.script(dynamic_rnn)
+
+    return dynamic_rnn
+
+
+# simple: flat inputs (no tuples), no list to accumulate outputs
+#         useful mostly for benchmarking older JIT versions
+def lstm_factory_simple(cell, script):
+    def dynamic_rnn(input, hx, cx, wih, whh, bih, bhh):
+        hy = hx  # for scoping
+        cy = cx  # for scoping
+        inputs = input.unbind(0)
+        for seq_idx in range(len(inputs)):
+            hy, cy = cell(inputs[seq_idx], hy, cy, wih, whh, bih, bhh)
+        return hy, cy
+
+    if script:
+        cell = torch.jit.script(cell)
+        dynamic_rnn = torch.jit.script(dynamic_rnn)
+
+    return dynamic_rnn
+
+
+def lstm_factory_multilayer(cell, script):
+    def dynamic_rnn(input, hidden, params):
+        # type: (Tensor, Tuple[Tensor, Tensor], List[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+        params_stride = 4  # NB: this assumes that biases are there
+        hx, cx = hidden
+        hy, cy = hidden  # for scoping...
+        inputs, outputs = input.unbind(0), []
+        for layer in range(hx.size(0)):
+            hy = hx[layer]
+            cy = cx[layer]
+            base_idx = layer * params_stride
+            wih = params[base_idx]
+            whh = params[base_idx + 1]
+            bih = params[base_idx + 2]
+            bhh = params[base_idx + 3]
+            for seq_idx in range(len(inputs)):
+                hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh)
+                outputs += [hy]
+            inputs, outputs = outputs, []
+        return torch.stack(inputs), (hy.unsqueeze(0), cy.unsqueeze(0))
+
+    if script:
+        cell = torch.jit.script(cell)
+        dynamic_rnn = torch.jit.script(dynamic_rnn)
+
+    return dynamic_rnn
diff --git a/benchmarks/fastrnns/profile.py b/benchmarks/fastrnns/profile.py
new file mode 100644 (file)
index 0000000..3b0ec2b
--- /dev/null
@@ -0,0 +1,138 @@
+import argparse
+import os
+import subprocess
+import sys
+import time
+import torch
+import datetime
+
+from .runner import get_rnn_runners
+
+PY3 = sys.version_info >= (3, 0)
+
+
+def run_rnn(name, rnn_creator, nloops=5,
+            seqLength=100, numLayers=1, inputSize=512, hiddenSize=512,
+            miniBatch=64, device='cuda', seed=None):
+    def run_iter(modeldef):
+        # Forward
+        forward_output = modeldef.forward(*modeldef.inputs)
+
+        # "loss computation" and backward
+        if modeldef.backward_setup is not None:
+            backward_input = modeldef.backward_setup(forward_output)
+        else:
+            backward_input = forward_output
+        if modeldef.backward is not None:
+            modeldef.backward(*backward_input)
+
+        # "Update" parameters
+        if modeldef.backward is not None:
+            for param in modeldef.params:
+                param.grad.data.zero_()
+        torch.cuda.synchronize()
+
+    assert device == 'cuda'
+    creator_args = dict(seqLength=seqLength, numLayers=numLayers,
+                        inputSize=inputSize, hiddenSize=hiddenSize,
+                        miniBatch=miniBatch, device=device, seed=seed)
+    modeldef = rnn_creator(**creator_args)
+
+    [run_iter(modeldef) for _ in range(nloops)]
+
+
+def profile(rnns, sleep_between_seconds=1, nloops=5,
+            internal_run=True,  # Unused, get rid of this TODO
+            seqLength=100, numLayers=1, inputSize=512, hiddenSize=512,
+            miniBatch=64, device='cuda', seed=None):
+    params = dict(seqLength=seqLength, numLayers=numLayers,
+                  inputSize=inputSize, hiddenSize=hiddenSize,
+                  miniBatch=miniBatch, device=device, seed=seed)
+    for name, creator, context in get_rnn_runners(*rnns):
+        with context():
+            run_rnn(name, creator, nloops, **params)
+            time.sleep(sleep_between_seconds)
+
+
+def system(command):
+    """Returns (return-code, stdout, stderr)"""
+    print('[system] {}'.format(command))
+    p = subprocess.Popen(command, stdout=subprocess.PIPE,
+                         stderr=subprocess.PIPE, shell=True)
+    output, err = p.communicate()
+    rc = p.returncode
+    if PY3:
+        output = output.decode("ascii")
+        err = err.decode("ascii")
+    return rc, output, err
+
+
+def describe_sizes(**sizes):
+    # seqLength, numLayers, inputSize, hiddenSize, miniBatch
+    return 's{}-l{}-i{}-h{}-b{}'.format(
+        sizes['seqLength'],
+        sizes['numLayers'],
+        sizes['inputSize'],
+        sizes['hiddenSize'],
+        sizes['miniBatch'],
+    )
+
+
+OUTPUT_DIR = '~/profout/'
+
+
+def nvprof_output_filename(rnns, **params):
+    rnn_tag = '-'.join(rnns)
+    size_tag = describe_sizes(**params)
+    date_tag = datetime.datetime.now().strftime("%m%d%y-%H%M")
+    return '{}prof_{}_{}_{}.nvvp'.format(OUTPUT_DIR, rnn_tag,
+                                         size_tag, date_tag)
+
+
+def nvprof(cmd, outpath):
+    return system('nvprof -o {} {}'.format(outpath, cmd))
+
+
+def full_profile(rnns, **args):
+    args['internal_run'] = True
+    profile_args = []
+    for k, v in args.items():
+        profile_args.append('--{}={}'.format(k, v))
+    profile_args.append('--rnns {}'.format(' '.join(rnns)))
+
+    outpath = nvprof_output_filename(rnns, **args)
+
+    cmd = '{} -m fastrnns.profile {}'.format(
+        sys.executable, ' '.join(profile_args))
+    rc, stdout, stderr = nvprof(cmd, outpath)
+    if rc != 0:
+        raise RuntimeError('stderr: {}\nstdout: {}'.format(stderr, stdout))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Profile RNNs')
+
+    parser.add_argument('--seqLength', default='100', type=int)
+    parser.add_argument('--numLayers', default='1', type=int)
+    parser.add_argument('--inputSize', default='512', type=int)
+    parser.add_argument('--hiddenSize', default='512', type=int)
+    parser.add_argument('--miniBatch', default='64', type=int)
+    parser.add_argument('--sleep_between_seconds', default='1', type=int)
+    parser.add_argument('--nloops', default='5', type=int)
+
+    parser.add_argument('--rnns', nargs='*',
+                        help='What to run. cudnn, aten, jit, etc')
+
+    # if internal_run, we actually run the rnns.
+    # if not internal_run, we shell out to nvprof with internal_run=True
+    parser.add_argument('--internal_run', default=False, type=bool,
+                        help='Don\'t use this')
+    args = parser.parse_args()
+    if args.rnns is None:
+        args.rnns = ['cudnn', 'aten', 'jit']
+    print(args)
+
+    if args.internal_run:
+        profile(**vars(args))
+    else:
+        full_profile(**vars(args))
diff --git a/benchmarks/fastrnns/runner.py b/benchmarks/fastrnns/runner.py
new file mode 100644 (file)
index 0000000..51c850e
--- /dev/null
@@ -0,0 +1,67 @@
+from collections import namedtuple
+from functools import partial
+import torch
+import torchvision.models as cnn
+
+from .factory import *
+
+
+class DisableCuDNN():
+    def __enter__(self):
+        self.saved = torch.backends.cudnn.enabled
+        torch.backends.cudnn.enabled = False
+
+    def __exit__(self, *args, **kwargs):
+        torch.backends.cudnn.enabled = self.saved
+
+
+class DummyContext():
+    def __enter__(self):
+        pass
+
+    def __exit__(self, *args, **kwargs):
+        pass
+
+
+class AssertNoJIT():
+    def __enter__(self):
+        import os
+        enabled = int(os.environ.get('PYTORCH_JIT', 1))
+        assert not enabled
+
+    def __exit__(self, *args, **kwargs):
+        pass
+
+
+RNNRunner = namedtuple('RNNRunner', [
+    'name', 'creator', 'context',
+])
+
+
+def get_nn_runners(*names):
+    return [nn_runners[name] for name in names]
+
+
+nn_runners = {
+    'cudnn': RNNRunner('cudnn', pytorch_lstm_creator, DummyContext),
+    'cudnn_dropout': RNNRunner('cudnn_dropout', partial(pytorch_lstm_creator, dropout=0.4), DummyContext),
+    'cudnn_layernorm': RNNRunner('cudnn_layernorm', layernorm_pytorch_lstm_creator, DummyContext),
+    'vl_cudnn': RNNRunner('vl_cudnn', varlen_pytorch_lstm_creator, DummyContext),
+    'vl_jit': RNNRunner('vl_jit', partial(varlen_lstm_creator, script=True), DummyContext),
+    'vl_py': RNNRunner('vl_py', varlen_lstm_creator, DummyContext),
+    'aten': RNNRunner('aten', pytorch_lstm_creator, DisableCuDNN),
+    'jit': RNNRunner('jit', lstm_creator, DummyContext),
+    'jit_premul': RNNRunner('jit_premul', lstm_premul_creator, DummyContext),
+    'jit_simple': RNNRunner('jit_simple', lstm_simple_creator, DummyContext),
+    'jit_multilayer': RNNRunner('jit_multilayer', lstm_multilayer_creator, DummyContext),
+    'jit_layernorm': RNNRunner('jit_layernorm', lnlstm_creator, DummyContext),
+    'jit_layernorm_decom': RNNRunner('jit_layernorm_decom',
+                                     partial(lnlstm_creator, decompose_layernorm=True),
+                                     DummyContext),
+    'jit_dropout': RNNRunner('jit_dropout', dropoutlstm_creator, DummyContext),
+    'py': RNNRunner('py', partial(lstm_creator, script=False), DummyContext),
+    'resnet18': RNNRunner('resnet18', imagenet_cnn_creator(cnn.resnet18, jit=False), DummyContext),
+    'resnet18_jit': RNNRunner('resnet18_jit', imagenet_cnn_creator(cnn.resnet18), DummyContext),
+    'resnet50': RNNRunner('resnet50', imagenet_cnn_creator(cnn.resnet50, jit=False), DummyContext),
+    'resnet50_jit': RNNRunner('resnet50_jit', imagenet_cnn_creator(cnn.resnet50), DummyContext),
+}
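+
+# Hedged usage sketch (illustrative only, not part of the benchmark harness):
+# each RNNRunner is consumed as a (name, creator, context) triple, e.g.
+#
+#     name, creator, context = get_nn_runners('jit')[0]
+#     with context():
+#         model = creator(seqLength=100, numLayers=1, inputSize=512,
+#                         hiddenSize=512, miniBatch=64, device='cuda', seed=17)
+#         outputs = model.forward(*model.inputs)
+#
+# The keyword arguments mirror the defaults used by fastrnns.test; the object
+# returned by a creator exposes inputs, params, forward, backward_setup and
+# backward.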
diff --git a/benchmarks/fastrnns/scratch.py b/benchmarks/fastrnns/scratch.py
new file mode 100644 (file)
index 0000000..c51d716
--- /dev/null
@@ -0,0 +1,51 @@
+# Scratch snippets used for manual experiments while developing these
+# benchmarks. Each block below is independent and can be run on its own.
+import torch
+
+
+@torch.jit.script
+def fn(x, scale, shift):
+    return scale * x / shift
+
+
+@torch.jit.script
+def recurrent(x, scale, shift):
+    y = x
+    for i in range(100):
+        y = fn(y, scale, shift)
+    return y
+
+
+x = torch.randn(2, 2, device='cuda')
+scale = torch.randn(2, 2, device='cuda', requires_grad=True)
+shift = torch.randn(2, 2, device='cuda', requires_grad=True)
+inputs = [x, scale, shift]
+
+
+out = recurrent(x, scale, shift)
+recurrent.graph_for(x, scale, shift)
+
+
+import torch
+
+
+@torch.jit.script
+def recurrent_scaleshift(x, scale, shift):
+    y = x
+    for i in range(64):
+        y = scale * y + shift
+    return y
+
+
+x = torch.randn(2, 2, device='cuda')
+scale = torch.randn(2, 2, device='cuda', requires_grad=True)
+shift = torch.randn(2, 2, device='cuda', requires_grad=True)
+inputs = [x, scale, shift]
+out = recurrent_scaleshift(x, scale, shift)
+recurrent_scaleshift.graph_for(x, scale, shift)
+
+
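+# Sanity check: backward through mean() of an empty tensor should not raise,
+# on CPU or on CUDA.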
+import torch
+x = torch.tensor([])
+x.requires_grad = True
+x.mean().backward()  # no error triggered
+x = x.cuda()
+x.mean().backward()
diff --git a/benchmarks/fastrnns/test.py b/benchmarks/fastrnns/test.py
new file mode 100644 (file)
index 0000000..6c8e4e1
--- /dev/null
@@ -0,0 +1,159 @@
+import argparse
+import torch
+import torch.nn as nn
+
+from .cells import lstm_cell
+from .factory import pytorch_lstm_creator, varlen_pytorch_lstm_creator
+from .runner import get_nn_runners
+
+
+def barf():
+    # Drop into the debugger so a failing comparison can be inspected in place.
+    import pdb
+    pdb.set_trace()
+
+
+def assertEqual(tensor, expected, threshold=0.001):
+    if isinstance(tensor, (list, tuple)):
+        for t, e in zip(tensor, expected):
+            assertEqual(t, e)
+    else:
+        if (tensor - expected).abs().max() > threshold:
+            barf()
+
+
+def filter_requires_grad(tensors):
+    return [t for t in tensors if t.requires_grad]
+
+
+def test_rnns(experim_creator, control_creator, check_grad=True, verbose=False,
+              seqLength=100, numLayers=1, inputSize=512, hiddenSize=512,
+              miniBatch=64, device='cuda', seed=17):
+    creator_args = dict(seqLength=seqLength, numLayers=numLayers,
+                        inputSize=inputSize, hiddenSize=hiddenSize,
+                        miniBatch=miniBatch, device=device, seed=seed)
+
+    print("Setting up...")
+    control = control_creator(**creator_args)
+    experim = experim_creator(**creator_args)
+
+    # Precondition
+    assertEqual(experim.inputs, control.inputs)
+    assertEqual(experim.params, control.params)
+
+    print("Checking outputs...")
+    control_outputs = control.forward(*control.inputs)
+    experim_outputs = experim.forward(*experim.inputs)
+    assertEqual(experim_outputs, control_outputs)
+
+    print("Checking grads...")
+    assert control.backward_setup is not None
+    assert experim.backward_setup is not None
+    assert control.backward is not None
+    assert experim.backward is not None
+    control_backward_inputs = control.backward_setup(control_outputs, seed)
+    experim_backward_inputs = experim.backward_setup(experim_outputs, seed)
+
+    control.backward(*control_backward_inputs)
+    experim.backward(*experim_backward_inputs)
+
+    control_grads = [p.grad for p in control.params]
+    experim_grads = [p.grad for p in experim.params]
+    assertEqual(experim_grads, control_grads)
+
+    if verbose:
+        print(experim.forward.graph_for(*experim.inputs))
+    print('')
+
+
+def test_vl_py(**test_args):
+    # XXX: This compares vl_py against the cuDNN variable-length LSTM
+    # (varlen_pytorch_lstm_creator) by hand rather than through test_rnns,
+    # because the two return outputs in different formats (a list of
+    # per-sequence tensors vs. a padded tensor), so the outputs below are
+    # massaged before comparison and the check isn't apples-to-apples yet.
+    control_creator = varlen_pytorch_lstm_creator
+    name, experim_creator, context = get_nn_runners('vl_py')[0]
+    with context():
+        print('testing {}...'.format(name))
+        creator_keys = [
+            'seqLength', 'numLayers', 'inputSize',
+            'hiddenSize', 'miniBatch', 'device', 'seed'
+        ]
+        creator_args = {key: test_args[key] for key in creator_keys}
+
+        print("Setting up...")
+        control = control_creator(**creator_args)
+        experim = experim_creator(**creator_args)
+
+        # Precondition
+        assertEqual(experim.inputs, control.inputs[:2])
+        assertEqual(experim.params, control.params)
+
+        print("Checking outputs...")
+        control_out, control_hiddens = control.forward(*control.inputs)
+        control_hx, control_cx = control_hiddens
+        experim_out, experim_hiddens = experim.forward(*experim.inputs)
+        experim_hx, experim_cx = experim_hiddens
+
+        experim_padded = nn.utils.rnn.pad_sequence(experim_out).squeeze(-2)
+        assertEqual(experim_padded, control_out)
+        assertEqual(torch.cat(experim_hx, dim=1), control_hx)
+        assertEqual(torch.cat(experim_cx, dim=1), control_cx)
+
+        print("Checking grads...")
+        assert control.backward_setup is not None
+        assert experim.backward_setup is not None
+        assert control.backward is not None
+        assert experim.backward is not None
+        control_backward_inputs = control.backward_setup(
+            (control_out, control_hiddens), test_args['seed'])
+        experim_backward_inputs = experim.backward_setup(
+            (experim_out, experim_hiddens), test_args['seed'])
+
+        control.backward(*control_backward_inputs)
+        experim.backward(*experim_backward_inputs)
+
+        control_grads = [p.grad for p in control.params]
+        experim_grads = [p.grad for p in experim.params]
+        assertEqual(experim_grads, control_grads)
+
+        if test_args['verbose']:
+            print(experim.forward.graph_for(*experim.inputs))
+        print('')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Test lstm correctness')
+
+    parser.add_argument('--seqLength', default=100, type=int)
+    parser.add_argument('--numLayers', default=1, type=int)
+    parser.add_argument('--inputSize', default=512, type=int)
+    parser.add_argument('--hiddenSize', default=512, type=int)
+    parser.add_argument('--miniBatch', default=64, type=int)
+    parser.add_argument('--device', default='cuda', type=str)
+    # NOTE: with type=bool, any non-empty value (including "False") parses
+    # as True; leave the flag unset to keep the default.
+    parser.add_argument('--check_grad', default=True, type=bool)
+    parser.add_argument('--variable_lstms', action='store_true')
+    parser.add_argument('--seed', default=17, type=int)
+    parser.add_argument('--verbose', action='store_true')
+    parser.add_argument('--rnns', nargs='*',
+                        help='What to run. jit_premul, jit, etc')
+    args = parser.parse_args()
+    if args.rnns is None:
+        args.rnns = ['jit_premul', 'jit']
+    print(args)
+
+    if 'cuda' in args.device:
+        assert torch.cuda.is_available()
+
+    rnn_runners = get_nn_runners(*args.rnns)
+
+    should_test_varlen_lstms = args.variable_lstms
+    test_args = vars(args)
+    del test_args['rnns']
+    del test_args['variable_lstms']
+
+    if should_test_varlen_lstms:
+        test_vl_py(**test_args)
+
+    for name, creator, context in rnn_runners:
+        with context():
+            print('testing {}...'.format(name))
+            test_rnns(creator, pytorch_lstm_creator, **test_args)