From 6684ef3f23f1aeb5f5034a982d2669011f3a5179 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 27 Mar 2019 14:39:33 -0700 Subject: [PATCH] Move fast rnn benchmark to pytorch/pytorch Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18369 Differential Revision: D14652039 Pulled By: wanchaol fbshipit-source-id: 1177b1f60d96672c3e2c9d527b56ee06ca7c0af1 --- benchmarks/README.md | 29 +++ benchmarks/fastrnns/README.md | 42 ++++ benchmarks/fastrnns/__init__.py | 10 + benchmarks/fastrnns/bench.py | 201 ++++++++++++++++ benchmarks/fastrnns/cells.py | 101 ++++++++ benchmarks/fastrnns/custom_lstms.py | 461 ++++++++++++++++++++++++++++++++++++ benchmarks/fastrnns/factory.py | 432 +++++++++++++++++++++++++++++++++ benchmarks/fastrnns/profile.py | 138 +++++++++++ benchmarks/fastrnns/runner.py | 67 ++++++ benchmarks/fastrnns/scratch.py | 51 ++++ benchmarks/fastrnns/test.py | 159 +++++++++++++ 11 files changed, 1691 insertions(+) create mode 100644 benchmarks/README.md create mode 100644 benchmarks/fastrnns/README.md create mode 100644 benchmarks/fastrnns/__init__.py create mode 100644 benchmarks/fastrnns/bench.py create mode 100644 benchmarks/fastrnns/cells.py create mode 100644 benchmarks/fastrnns/custom_lstms.py create mode 100644 benchmarks/fastrnns/factory.py create mode 100644 benchmarks/fastrnns/profile.py create mode 100644 benchmarks/fastrnns/runner.py create mode 100644 benchmarks/fastrnns/scratch.py create mode 100644 benchmarks/fastrnns/test.py diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 0000000..1477f9e --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,29 @@ +# PyTorch Benchmarks + +NOTE: This folder is currently work in progress. + +This folder contains scripts that produce reproducible timings of various PyTorch features. + +It also provides mechanisms to compare PyTorch with other frameworks. + +## Setup environment +Make sure you're on a machine with CUDA, torchvision, and pytorch installed. Install in the following order: +``` +# Install torchvision. It comes with the pytorch stable release binary +conda install pytorch torchvision -c pytorch + +# Install the latest pytorch master from source. +# It should supercede the installation from the release binary. +cd $PYTORCH_HOME +python setup.py build develop + +# Check the pytorch installation version +python -c "import torch; print(torch.__version__)" +``` + +## Benchmark List + +Please refer to each subfolder to discover each benchmark suite + +* [Fast RNNs benchmarks](fastrnns/README.md) + diff --git a/benchmarks/fastrnns/README.md b/benchmarks/fastrnns/README.md new file mode 100644 index 0000000..87f93fa --- /dev/null +++ b/benchmarks/fastrnns/README.md @@ -0,0 +1,42 @@ +# Fast RNN benchmarks + +Benchmarks for TorchScript models + +For most stable results, do the following: +- Set CPU Governor to performance mode (as opposed to energy save) +- Turn off turbo for all CPUs (assuming Intel CPUs) +- Shield cpus via `cset shield` when running benchmarks. + +Some of these scripts accept command line args but most of them do not because +I was lazy. They will probably be added sometime in the future, but the default +sizes are pretty reasonable. 
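+
+The sizes can also be set programmatically rather than via flags. Below is a
+minimal sketch (not part of the suite) of driving `trainbench` directly from
+Python; it assumes a CUDA machine and that you run it from the `benchmarks/`
+directory, and the runner name and sizes chosen here are only illustrative:
+
+```
+# Minimal sketch: drive the benchmark from Python instead of the CLI.
+from fastrnns.bench import trainbench
+from fastrnns.runner import get_nn_runners
+
+# 'jit' is one of the runner names registered in fastrnns/runner.py
+name, creator, context = get_nn_runners('jit')[0]
+with context():
+    result = trainbench(name, creator, nloops=20, warmup=5,
+                        seqLength=50, hiddenSize=256, miniBatch=32)
+# BenchResult(name, avg_fwd, std_fwd, avg_bwd, std_bwd); timings are in
+# milliseconds, measured with CUDA events.
+print(result)
+```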
+ +## Test fastrnns (fwd + bwd) correctness + +Test the fastrnns benchmarking scripts with the following: +`python -m fastrnns.test` +or run the test independently: +`python -m fastrnns.test --rnns jit` + +## Run benchmarks + +`python -m fastrnns.bench` + +should give a good comparision, or you can specify the type of model to run + +`python -m fastrnns.bench --rnns cudnn aten jit --group rnns` + +## Run model profiling, calls nvprof + +`python -m fastrnns.profile` + +should generate nvprof file for all models somewhere. +you can also specify the models to generate nvprof files separately: + +`python -m fastrnns.profile --rnns aten jit` + +### Caveats + +Use Linux for the most accurate timing. A lot of these tests only run +on CUDA. + diff --git a/benchmarks/fastrnns/__init__.py b/benchmarks/fastrnns/__init__.py new file mode 100644 index 0000000..f32d4a0 --- /dev/null +++ b/benchmarks/fastrnns/__init__.py @@ -0,0 +1,10 @@ +from .cells import * +from .factory import * +from .test import * + +# (output, next_state) = cell(input, state) +seqLength = 100 +numLayers = 2 +inputSize = 512 +hiddenSize = 512 +miniBatch = 64 diff --git a/benchmarks/fastrnns/bench.py b/benchmarks/fastrnns/bench.py new file mode 100644 index 0000000..71cad4a --- /dev/null +++ b/benchmarks/fastrnns/bench.py @@ -0,0 +1,201 @@ +from __future__ import print_function +import argparse +from collections import namedtuple +import torch +import gc +import sys +import json +import copy + +from .runner import get_nn_runners + + +BenchResult = namedtuple('BenchResult', [ + 'name', 'avg_fwd', 'std_fwd', 'avg_bwd', 'std_bwd', +]) + + +def fit_str(string, colwidth=16): + if len(string) < colwidth: + return (colwidth - len(string)) * ' ' + string + else: + return string[:colwidth] + + +def to_str(item): + if isinstance(item, float): + return '%.4g' % item + return str(item) + + +def print_header(colwidth=16, sep=' '): + items = [] + for item in BenchResult._fields: + items.append(fit_str(item)) + return sep.join(items) + + +def pretty_print(benchresult, colwidth=16, sep=' '): + items = [] + for thing in benchresult: + items.append(fit_str(to_str(thing))) + return sep.join(items) + + +def trainbench(name, rnn_creator, nloops=100, warmup=10, + seqLength=100, numLayers=1, inputSize=512, hiddenSize=512, + miniBatch=64, device='cuda', seed=None): + def train_batch(modeldef): + # CUDA events for timing + fwd_start_event = torch.cuda.Event(enable_timing=True) + fwd_end_event = torch.cuda.Event(enable_timing=True) + bwd_start_event = torch.cuda.Event(enable_timing=True) + bwd_end_event = torch.cuda.Event(enable_timing=True) + + gc.collect() + + fwd_start_event.record() + forward_output = modeldef.forward(*modeldef.inputs) + fwd_end_event.record() + + # XXX: Use if need to print something + # print(modeldef.forward.graph_for(*modeldef.inputs)) + + if modeldef.backward_setup is not None: + backward_input = modeldef.backward_setup(forward_output) + else: + backward_input = forward_output + + gc.collect() + + bwd_start_event.record() + if modeldef.backward is not None: + modeldef.backward(*backward_input) + bwd_end_event.record() + + if modeldef.backward is not None: + for param in modeldef.params: + assert param.grad is not None + param.grad.data.zero_() + + torch.cuda.synchronize() + + fwd_time = fwd_start_event.elapsed_time(fwd_end_event) + bwd_time = bwd_start_event.elapsed_time(bwd_end_event) + return fwd_time, bwd_time + + assert device == 'cuda' + creator_args = dict(seqLength=seqLength, numLayers=numLayers, + inputSize=inputSize, 
hiddenSize=hiddenSize, + miniBatch=miniBatch, device=device, seed=seed) + modeldef = rnn_creator(**creator_args) + + [train_batch(modeldef) for _ in range(warmup)] + + results = [train_batch(modeldef) for _ in range(nloops)] + fwd_times, bwd_times = zip(*results) + + fwd_times = torch.tensor(fwd_times) + bwd_times = torch.tensor(bwd_times) + + return BenchResult(name=name, + avg_fwd=fwd_times.mean().item(), + std_fwd=fwd_times.std().item(), + avg_bwd=bwd_times.mean().item(), + std_bwd=bwd_times.std().item()) + + +def print_stderr(*args, **kwargs): + kwargs['file'] = sys.stderr + return print(*args, **kwargs) + + +def bench(rnn_runners, group_name, print_json=False, sep=' ', **params): + print_stderr(print_header(sep=sep)) + results = {} + for name, creator, context in rnn_runners: + with context(): + try: + result = trainbench(name, creator, **params) + print_stderr(pretty_print(result, sep=sep)) + results[name] = result + except Exception as e: + if not print_json: + raise + + return { + group_name: {k: v.avg_fwd for k, v in results.items()}, + group_name + '-backward': {k: v.avg_bwd for k, v in results.items()}, + } + + +def bench_group(model_list, bench_name, bench_group, bench_args): + print_stderr('Benchmarking {}s...'.format(bench_name)) + nn_results = bench(get_nn_runners(*model_list), bench_group, **bench_args) + print_stderr('') + return nn_results + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Profile RNNs') + + # groups help control which test group you want to run + # if you only want to run one/two benchmark, run it with + # e.g: python -m fastrnns.bench --rnns jit and --group rnns + default_groups = ['cnns', 'rnns'] + + parser.add_argument('--seqLength', default='100', type=int) + parser.add_argument('--numLayers', default='1', type=int) + parser.add_argument('--inputSize', default='512', type=int) + parser.add_argument('--hiddenSize', default='512', type=int) + parser.add_argument('--miniBatch', default='64', type=int) + parser.add_argument('--warmup', default='10', type=int) + parser.add_argument('--nloops', default='100', type=int) + parser.add_argument('--device', default='cuda', type=str) + parser.add_argument('--variable_lstms', action='store_true', + help='Also benchmark variable sequence length lstms ' + 'Note that some of these run really slowly ' + 'and that the `seqLength` flag will be ignored.') + parser.add_argument('--sep', default=' ', type=str) + parser.add_argument('--print-json', action='store_true') + parser.add_argument('--rnns', nargs='*', + help='What to run. cudnn, aten, jit, etc') + parser.add_argument('--cnns', nargs='*', + help='What to run. resnet18, resnet18_jit, resnet50, etc') + parser.add_argument('--group', nargs='*', default=default_groups, help='Which group to run. 
cnns, rnns, etc.') + + args = parser.parse_args() + rnns = args.rnns or ['cudnn', 'aten', 'jit', 'jit_premul', 'jit_simple', + 'jit_multilayer', 'py'] + cnns = args.cnns or ['resnet18', 'resnet18_jit', 'resnet50', 'resnet50_jit'] + # TODO: Maybe add a separate section for the layernorm/dropout lstms + # 'cudnn_layernorm', jit_layernorm', 'jit_layernom_decom', + # 'jit', 'jit_dropout', 'cudnn_dropout' + vlrnns = ['vl_cudnn', 'vl_jit', 'vl_py'] + + if args.print_json: + print_stderr = lambda *args, **kwargs: None # noqa + print_stderr(args) + + bench_args = copy.deepcopy(vars(args)) + should_bench_varlen_lstms = args.variable_lstms + del bench_args['group'] + del bench_args['rnns'] + del bench_args['cnns'] + del bench_args['variable_lstms'] + + results = dict() + if should_bench_varlen_lstms: + if args.nloops + args.warmup > 30: + print_stderr( + 'WARNING: some of the variable sequence length lstms are ' + 'very unoptimized and therefore take forever to run.') + results.update(bench_group(vlrnns, 'variable-length sequence LSTM', 'vl_lstm', bench_args)) + + if 'rnns' in args.group: + results.update(bench_group(rnns, 'LSTM', 'lstm', bench_args)) + if 'cnns' in args.group: + results.update(bench_group(cnns, 'ResNet', 'resnet', bench_args)) + + if args.print_json: + print(json.dumps(results)) diff --git a/benchmarks/fastrnns/cells.py b/benchmarks/fastrnns/cells.py new file mode 100644 index 0000000..7c80c49 --- /dev/null +++ b/benchmarks/fastrnns/cells.py @@ -0,0 +1,101 @@ +import torch + + +def milstm_cell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias): + Wx = x.mm(w_ih.t()) + Uz = hx.mm(w_hh.t()) + + # Section 2.1 in https://arxiv.org/pdf/1606.06630.pdf + gates = (alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias) + + # Same as LSTMCell after this point + ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) + + ingate = ingate.sigmoid() + forgetgate = forgetgate.sigmoid() + cellgate = cellgate.tanh() + outgate = outgate.sigmoid() + + cy = (forgetgate * cx) + (ingate * cellgate) + hy = outgate * cy.tanh() + + return hy, cy + + +def lstm_cell(input, hidden, w_ih, w_hh, b_ih, b_hh): + # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor] + hx, cx = hidden + gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh + + ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) + + ingate = torch.sigmoid(ingate) + forgetgate = torch.sigmoid(forgetgate) + cellgate = torch.tanh(cellgate) + outgate = torch.sigmoid(outgate) + + cy = (forgetgate * cx) + (ingate * cellgate) + hy = outgate * torch.tanh(cy) + + return hy, cy + + +def flat_lstm_cell(input, hx, cx, w_ih, w_hh, b_ih, b_hh): + # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor] + gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh + + ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) + + ingate = torch.sigmoid(ingate) + forgetgate = torch.sigmoid(forgetgate) + cellgate = torch.tanh(cellgate) + outgate = torch.sigmoid(outgate) + + cy = (forgetgate * cx) + (ingate * cellgate) + hy = outgate * torch.tanh(cy) + + return hy, cy + + +def premul_lstm_cell(igates, hidden, w_hh, b_ih, b_hh): + # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor] + hx, cx = hidden + gates = igates + torch.mm(hx, w_hh.t()) + b_ih + b_hh + + ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) + + ingate = torch.sigmoid(ingate) + forgetgate = torch.sigmoid(forgetgate) + cellgate = 
torch.tanh(cellgate) + outgate = torch.sigmoid(outgate) + + cy = (forgetgate * cx) + (ingate * cellgate) + hy = outgate * torch.tanh(cy) + + return hy, cy + + +def gru_cell(input, hidden, w_ih, w_hh, b_ih, b_hh): + gi = torch.mm(input, w_ih.t()) + b_ih + gh = torch.mm(hidden, w_hh.t()) + b_hh + i_r, i_i, i_n = gi.chunk(3, 1) + h_r, h_i, h_n = gh.chunk(3, 1) + + resetgate = torch.sigmoid(i_r + h_r) + inputgate = torch.sigmoid(i_i + h_i) + newgate = torch.tanh(i_n + resetgate * h_n) + hy = newgate + inputgate * (hidden - newgate) + + return hy + + +def rnn_relu_cell(input, hidden, w_ih, w_hh, b_ih, b_hh): + igates = torch.mm(input, w_ih.t()) + b_ih + hgates = torch.mm(hidden, w_hh.t()) + b_hh + return torch.relu(igates + hgates) + + +def rnn_tanh_cell(input, hidden, w_ih, w_hh, b_ih, b_hh): + igates = torch.mm(input, w_ih.t()) + b_ih + hgates = torch.mm(hidden, w_hh.t()) + b_hh + return torch.tanh(igates + hgates) diff --git a/benchmarks/fastrnns/custom_lstms.py b/benchmarks/fastrnns/custom_lstms.py new file mode 100644 index 0000000..d835b3e --- /dev/null +++ b/benchmarks/fastrnns/custom_lstms.py @@ -0,0 +1,461 @@ +import torch +import torch.nn as nn +from torch.nn import Parameter +import torch.jit as jit +import warnings +from collections import namedtuple +from typing import List, Tuple +from torch import Tensor +import numbers + +''' +Some helper classes for writing custom TorchScript LSTMs. + +Goals: +- Classes are easy to read, use, and extend +- Performance of custom LSTMs approach fused-kernel-levels of speed. + +A few notes about features we could add to clean up the below code: +- Support enumerate with nn.ModuleList: + https://github.com/pytorch/pytorch/issues/14471 +- Support enumerate/zip with lists: + https://github.com/pytorch/pytorch/issues/15952 +- Support overriding of class methods: + https://github.com/pytorch/pytorch/issues/10733 +- Support passing around user-defined namedtuple types for readability +- Support slicing w/ range. It enables reversing lists easily. + https://github.com/pytorch/pytorch/issues/10774 +- Multiline type annotations. List[List[Tuple[Tensor,Tensor]]] is verbose + https://github.com/pytorch/pytorch/pull/14922 +''' + + +def script_lstm(input_size, hidden_size, num_layers, bias=True, + batch_first=False, dropout=False, bidirectional=False): + '''Returns a ScriptModule that mimics a PyTorch native LSTM.''' + + # The following are not implemented. + assert bias + assert not batch_first + + if bidirectional: + stack_type = StackedLSTM2 + layer_type = BidirLSTMLayer + dirs = 2 + elif dropout: + stack_type = StackedLSTMWithDropout + layer_type = LSTMLayer + dirs = 1 + else: + stack_type = StackedLSTM + layer_type = LSTMLayer + dirs = 1 + + return stack_type(num_layers, layer_type, + first_layer_args=[LSTMCell, input_size, hidden_size], + other_layer_args=[LSTMCell, hidden_size * dirs, + hidden_size]) + + +def script_lnlstm(input_size, hidden_size, num_layers, bias=True, + batch_first=False, dropout=False, bidirectional=False, + decompose_layernorm=False): + '''Returns a ScriptModule that mimics a PyTorch native LSTM.''' + + # The following are not implemented. 
+ assert bias + assert not batch_first + assert not dropout + + if bidirectional: + stack_type = StackedLSTM2 + layer_type = BidirLSTMLayer + dirs = 2 + else: + stack_type = StackedLSTM + layer_type = LSTMLayer + dirs = 1 + + return stack_type(num_layers, layer_type, + first_layer_args=[LayerNormLSTMCell, input_size, hidden_size, + decompose_layernorm], + other_layer_args=[LayerNormLSTMCell, hidden_size * dirs, + hidden_size, decompose_layernorm]) + + +LSTMState = namedtuple('LSTMState', ['hx', 'cx']) + + +def reverse(lst): + # type: (List[Tensor]) -> List[Tensor] + return lst[::-1] + + +class LSTMCell(jit.ScriptModule): + def __init__(self, input_size, hidden_size): + super(LSTMCell, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size)) + self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size)) + self.bias_ih = Parameter(torch.randn(4 * hidden_size)) + self.bias_hh = Parameter(torch.randn(4 * hidden_size)) + + @jit.script_method + def forward(self, input, state): + # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + hx, cx = state + gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih + + torch.mm(hx, self.weight_hh.t()) + self.bias_hh) + ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) + + ingate = torch.sigmoid(ingate) + forgetgate = torch.sigmoid(forgetgate) + cellgate = torch.tanh(cellgate) + outgate = torch.sigmoid(outgate) + + cy = (forgetgate * cx) + (ingate * cellgate) + hy = outgate * torch.tanh(cy) + + return hy, (hy, cy) + + +class LayerNorm(jit.ScriptModule): + def __init__(self, normalized_shape): + super(LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + # XXX: This is true for our LSTM / NLP use case and helps simplify code + assert len(normalized_shape) == 1 + + self.weight = Parameter(torch.ones(normalized_shape)) + self.bias = Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + @jit.script_method + def compute_layernorm_stats(self, input): + mu = input.mean(-1, keepdim=True) + sigma = input.std(-1, keepdim=True, unbiased=False) + return mu, sigma + + @jit.script_method + def forward(self, input): + mu, sigma = self.compute_layernorm_stats(input) + return (input - mu) / sigma * self.weight + self.bias + + +class LayerNormLSTMCell(jit.ScriptModule): + def __init__(self, input_size, hidden_size, decompose_layernorm=False): + super(LayerNormLSTMCell, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size)) + self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size)) + # The layernorms provide learnable biases + + if decompose_layernorm: + ln = LayerNorm + else: + ln = nn.LayerNorm + + self.layernorm_i = ln(4 * hidden_size) + self.layernorm_h = ln(4 * hidden_size) + self.layernorm_c = ln(hidden_size) + + @jit.script_method + def forward(self, input, state): + # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + hx, cx = state + igates = self.layernorm_i(torch.mm(input, self.weight_ih.t())) + hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t())) + gates = igates + hgates + ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) + + ingate = torch.sigmoid(ingate) + forgetgate = torch.sigmoid(forgetgate) + cellgate = 
torch.tanh(cellgate) + outgate = torch.sigmoid(outgate) + + cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate)) + hy = outgate * torch.tanh(cy) + + return hy, (hy, cy) + + +class LSTMLayer(jit.ScriptModule): + def __init__(self, cell, *cell_args): + super(LSTMLayer, self).__init__() + self.cell = cell(*cell_args) + + @jit.script_method + def forward(self, input, state): + # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + inputs = input.unbind(0) + outputs = torch.jit.annotate(List[Tensor], []) + for i in range(len(inputs)): + out, state = self.cell(inputs[i], state) + outputs += [out] + return torch.stack(outputs), state + + +class ReverseLSTMLayer(jit.ScriptModule): + def __init__(self, cell, *cell_args): + super(ReverseLSTMLayer, self).__init__() + self.cell = cell(*cell_args) + + @jit.script_method + def forward(self, input, state): + # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + inputs = reverse(input.unbind(0)) + outputs = jit.annotate(List[Tensor], []) + for i in range(len(inputs)): + out, state = self.cell(inputs[i], state) + outputs += [out] + return torch.stack(reverse(outputs)), state + + +class BidirLSTMLayer(jit.ScriptModule): + __constants__ = ['directions'] + + def __init__(self, cell, *cell_args): + super(BidirLSTMLayer, self).__init__() + self.directions = nn.ModuleList([ + LSTMLayer(cell, *cell_args), + ReverseLSTMLayer(cell, *cell_args), + ]) + + @jit.script_method + def forward(self, input, states): + # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]] + # List[LSTMState]: [forward LSTMState, backward LSTMState] + outputs = jit.annotate(List[Tensor], []) + output_states = jit.annotate(List[Tuple[Tensor, Tensor]], []) + # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471 + i = 0 + for direction in self.directions: + state = states[i] + out, out_state = direction(input, state) + outputs += [out] + output_states += [out_state] + i += 1 + return torch.cat(outputs, -1), output_states + + +def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args): + layers = [layer(*first_layer_args)] + [layer(*other_layer_args) + for _ in range(num_layers - 1)] + return nn.ModuleList(layers) + + +class StackedLSTM(jit.ScriptModule): + __constants__ = ['layers'] # Necessary for iterating through self.layers + + def __init__(self, num_layers, layer, first_layer_args, other_layer_args): + super(StackedLSTM, self).__init__() + self.layers = init_stacked_lstm(num_layers, layer, first_layer_args, + other_layer_args) + + @jit.script_method + def forward(self, input, states): + # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]] + # List[LSTMState]: One state per layer + output_states = jit.annotate(List[Tuple[Tensor, Tensor]], []) + output = input + # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471 + i = 0 + for rnn_layer in self.layers: + state = states[i] + output, out_state = rnn_layer(output, state) + output_states += [out_state] + i += 1 + return output, output_states + + +# Differs from StackedLSTM in that its forward method takes +# List[List[Tuple[Tensor,Tensor]]]. It would be nice to subclass StackedLSTM +# except we don't support overriding script methods. 
+# https://github.com/pytorch/pytorch/issues/10733 +class StackedLSTM2(jit.ScriptModule): + __constants__ = ['layers'] # Necessary for iterating through self.layers + + def __init__(self, num_layers, layer, first_layer_args, other_layer_args): + super(StackedLSTM2, self).__init__() + self.layers = init_stacked_lstm(num_layers, layer, first_layer_args, + other_layer_args) + + @jit.script_method + def forward(self, input, states): + # type: (Tensor, List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]] + # List[List[LSTMState]]: The outer list is for layers, + # inner list is for directions. + output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], []) + output = input + # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471 + i = 0 + for rnn_layer in self.layers: + state = states[i] + output, out_state = rnn_layer(output, state) + output_states += [out_state] + i += 1 + return output, output_states + + +class StackedLSTMWithDropout(jit.ScriptModule): + # Necessary for iterating through self.layers and dropout support + __constants__ = ['layers', 'num_layers'] + + def __init__(self, num_layers, layer, first_layer_args, other_layer_args): + super(StackedLSTMWithDropout, self).__init__() + self.layers = init_stacked_lstm(num_layers, layer, first_layer_args, + other_layer_args) + # Introduces a Dropout layer on the outputs of each LSTM layer except + # the last layer, with dropout probability = 0.4. + self.num_layers = num_layers + + if (num_layers == 1): + warnings.warn("dropout lstm adds dropout layers after all but last " + "recurrent layer, it expects num_layers greater than " + "1, but got num_layers = 1") + + self.dropout_layer = nn.Dropout(0.4) + + @jit.script_method + def forward(self, input, states): + # type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]] + # List[LSTMState]: One state per layer + output_states = jit.annotate(List[Tuple[Tensor, Tensor]], []) + output = input + # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471 + i = 0 + for rnn_layer in self.layers: + state = states[i] + output, out_state = rnn_layer(output, state) + # Apply the dropout layer except the last layer + if i < self.num_layers - 1: + output = self.dropout_layer(output) + output_states += [out_state] + i += 1 + return output, output_states + + +def flatten_states(states): + states = list(zip(*states)) + assert len(states) == 2 + return [torch.stack(state) for state in states] + + +def double_flatten_states(states): + # XXX: Can probably write this in a nicer way + states = flatten_states([flatten_states(inner) for inner in states]) + return [hidden.view([-1] + list(hidden.shape[2:])) for hidden in states] + + +def test_script_rnn_layer(seq_len, batch, input_size, hidden_size): + inp = torch.randn(seq_len, batch, input_size) + state = LSTMState(torch.randn(batch, hidden_size), + torch.randn(batch, hidden_size)) + rnn = LSTMLayer(LSTMCell, input_size, hidden_size) + out, out_state = rnn(inp, state) + + # Control: pytorch native LSTM + lstm = nn.LSTM(input_size, hidden_size, 1) + lstm_state = LSTMState(state.hx.unsqueeze(0), state.cx.unsqueeze(0)) + for lstm_param, custom_param in zip(lstm.all_weights[0], rnn.parameters()): + assert lstm_param.shape == custom_param.shape + with torch.no_grad(): + lstm_param.copy_(custom_param) + lstm_out, lstm_out_state = lstm(inp, lstm_state) + + assert (out - lstm_out).abs().max() < 1e-5 + assert (out_state[0] - lstm_out_state[0]).abs().max() < 1e-5 + assert (out_state[1] - 
lstm_out_state[1]).abs().max() < 1e-5 + + +def test_script_stacked_rnn(seq_len, batch, input_size, hidden_size, + num_layers): + inp = torch.randn(seq_len, batch, input_size) + states = [LSTMState(torch.randn(batch, hidden_size), + torch.randn(batch, hidden_size)) + for _ in range(num_layers)] + rnn = script_lstm(input_size, hidden_size, num_layers) + out, out_state = rnn(inp, states) + custom_state = flatten_states(out_state) + + # Control: pytorch native LSTM + lstm = nn.LSTM(input_size, hidden_size, num_layers) + lstm_state = flatten_states(states) + for layer in range(num_layers): + custom_params = list(rnn.parameters())[4 * layer: 4 * (layer + 1)] + for lstm_param, custom_param in zip(lstm.all_weights[layer], + custom_params): + assert lstm_param.shape == custom_param.shape + with torch.no_grad(): + lstm_param.copy_(custom_param) + lstm_out, lstm_out_state = lstm(inp, lstm_state) + + assert (out - lstm_out).abs().max() < 1e-5 + assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5 + assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5 + + +def test_script_stacked_bidir_rnn(seq_len, batch, input_size, hidden_size, + num_layers): + inp = torch.randn(seq_len, batch, input_size) + states = [[LSTMState(torch.randn(batch, hidden_size), + torch.randn(batch, hidden_size)) + for _ in range(2)] + for _ in range(num_layers)] + rnn = script_lstm(input_size, hidden_size, num_layers, bidirectional=True) + out, out_state = rnn(inp, states) + custom_state = double_flatten_states(out_state) + + # Control: pytorch native LSTM + lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True) + lstm_state = double_flatten_states(states) + for layer in range(num_layers): + for direct in range(2): + index = 2 * layer + direct + custom_params = list(rnn.parameters())[4 * index: 4 * index + 4] + for lstm_param, custom_param in zip(lstm.all_weights[index], + custom_params): + assert lstm_param.shape == custom_param.shape + with torch.no_grad(): + lstm_param.copy_(custom_param) + lstm_out, lstm_out_state = lstm(inp, lstm_state) + + assert (out - lstm_out).abs().max() < 1e-5 + assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5 + assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5 + + +def test_script_stacked_lstm_dropout(seq_len, batch, input_size, hidden_size, + num_layers): + inp = torch.randn(seq_len, batch, input_size) + states = [LSTMState(torch.randn(batch, hidden_size), + torch.randn(batch, hidden_size)) + for _ in range(num_layers)] + rnn = script_lstm(input_size, hidden_size, num_layers, dropout=True) + + # just a smoke test + out, out_state = rnn(inp, states) + + +def test_script_stacked_lnlstm(seq_len, batch, input_size, hidden_size, + num_layers): + inp = torch.randn(seq_len, batch, input_size) + states = [LSTMState(torch.randn(batch, hidden_size), + torch.randn(batch, hidden_size)) + for _ in range(num_layers)] + rnn = script_lnlstm(input_size, hidden_size, num_layers) + + # just a smoke test + out, out_state = rnn(inp, states) + + +test_script_rnn_layer(5, 2, 3, 7) +test_script_stacked_rnn(5, 2, 3, 7, 4) +test_script_stacked_bidir_rnn(5, 2, 3, 7, 4) +test_script_stacked_lstm_dropout(5, 2, 3, 7, 4) +test_script_stacked_lnlstm(5, 2, 3, 7, 4) diff --git a/benchmarks/fastrnns/factory.py b/benchmarks/fastrnns/factory.py new file mode 100644 index 0000000..90f49bc --- /dev/null +++ b/benchmarks/fastrnns/factory.py @@ -0,0 +1,432 @@ +import torch + +from collections import namedtuple + +from .cells import lstm_cell, premul_lstm_cell, flat_lstm_cell + 
+ +# list[list[T]] -> list[T] +def flatten_list(lst): + result = [] + for inner in lst: + result.extend(inner) + return result + + +''' +Define a creator as a function: +(options) -> (inputs, params, forward, backward_setup, backward) +inputs: the inputs to the returned 'forward'. One can call + forward(*inputs) directly. +params: List[Tensor] all requires_grad=True parameters. +forward: function / graph executor / module + One can call rnn(rnn_inputs) using the outputs of the creator. +backward_setup: backward_inputs = backward_setup(*outputs) + Then, we pass backward_inputs to backward. If None, then it is assumed to + be the identity function. +backward: Given `output = backward_setup(*forward(*inputs))`, performs + backpropagation. If None, then nothing happens. + +fastrnns.bench times the forward and backward invocations. +''' + + +ModelDef = namedtuple('ModelDef', [ + 'inputs', 'params', 'forward', 'backward_setup', 'backward']) + + +def lstm_backward_setup(lstm_outputs, seed=None): + hx, _ = lstm_outputs + return simple_backward_setup(hx, seed) + + +def simple_backward_setup(output, seed=None): + assert isinstance(output, torch.Tensor) + if seed: + torch.manual_seed(seed) + grad_output = torch.randn_like(output) + return output, grad_output + + +def simple_backward(output, grad_output): + return output.backward(grad_output) + + +def pytorch_lstm_creator(**kwargs): + input, hidden, _, module = lstm_inputs(return_module=True, **kwargs) + return ModelDef( + inputs=[input, hidden], + params=flatten_list(module.all_weights), + forward=module, + backward_setup=lstm_backward_setup, + backward=simple_backward) + + +def lstm_creator(script=True, **kwargs): + input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs) + inputs = [input, hidden] + params[0] + return ModelDef( + inputs=inputs, + params=flatten_list(params), + forward=lstm_factory(lstm_cell, script), + backward_setup=lstm_backward_setup, + backward=simple_backward) + + +def lnlstm_creator(script=True, decompose_layernorm=False, **kwargs): + assert script is True + from .custom_lstms import script_lnlstm + input_size = kwargs['inputSize'] + hidden_size = kwargs['hiddenSize'] + seq_len = kwargs['seqLength'] + batch_size = kwargs['miniBatch'] + ge = script_lnlstm(input_size, hidden_size, 1, + decompose_layernorm=decompose_layernorm).cuda() + + input = torch.randn(seq_len, batch_size, input_size, device='cuda') + states = [(torch.randn(batch_size, hidden_size, device='cuda'), + torch.randn(batch_size, hidden_size, device='cuda'))] + + return ModelDef( + inputs=[input, states], + params=ge.parameters(), + forward=ge, + backward_setup=lstm_backward_setup, + backward=simple_backward) + + +def dropoutlstm_creator(script=True, **kwargs): + assert script is True + from .custom_lstms import script_lstm, LSTMState + input_size = kwargs['inputSize'] + hidden_size = kwargs['hiddenSize'] + seq_len = kwargs['seqLength'] + batch_size = kwargs['miniBatch'] + num_layers = kwargs['numLayers'] + ge = script_lstm(input_size, hidden_size, num_layers, dropout=True).cuda() + + input = torch.randn(seq_len, batch_size, input_size, device='cuda') + states = [LSTMState(torch.randn(batch_size, hidden_size, device='cuda'), + torch.randn(batch_size, hidden_size, device='cuda')) + for _ in range(num_layers)] + return ModelDef( + inputs=[input, states], + params=ge.parameters(), + forward=ge, + backward_setup=lstm_backward_setup, + backward=simple_backward) + + +def lstm_premul_creator(script=True, **kwargs): + input, hidden, params, _ = 
lstm_inputs(return_module=False, **kwargs) + inputs = [input, hidden] + params[0] + return ModelDef( + inputs=inputs, + params=flatten_list(params), + forward=lstm_factory_premul(premul_lstm_cell, script), + backward_setup=lstm_backward_setup, + backward=simple_backward) + + +def lstm_simple_creator(script=True, **kwargs): + input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs) + inputs = [input] + [h[0] for h in hidden] + params[0] + return ModelDef( + inputs=inputs, + params=flatten_list(params), + forward=lstm_factory_simple(flat_lstm_cell, script), + backward_setup=lstm_backward_setup, + backward=simple_backward) + + +def lstm_multilayer_creator(script=True, **kwargs): + input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs) + inputs = [input, hidden, flatten_list(params)] + return ModelDef( + inputs=inputs, + params=flatten_list(params), + forward=lstm_factory_multilayer(lstm_cell, script), + backward_setup=lstm_backward_setup, + backward=simple_backward) + + +def imagenet_cnn_creator(arch, jit=True): + def creator(device='cuda', **kwargs): + model = arch().to(device) + x = torch.randn(32, 3, 224, 224, device=device) + if jit: + model = torch.jit.trace(model, x) + return ModelDef( + inputs=(x,), + params=list(model.parameters()), + forward=model, + backward_setup=simple_backward_setup, + backward=simple_backward) + + return creator + + +def varlen_lstm_inputs(minlen=30, maxlen=100, + numLayers=1, inputSize=512, hiddenSize=512, + miniBatch=64, return_module=False, device='cuda', + seed=None, **kwargs): + if seed is not None: + torch.manual_seed(seed) + lengths = torch.randint( + low=minlen, high=maxlen, size=[miniBatch], + dtype=torch.long, device=device) + x = [torch.randn(length, inputSize, device=device) + for length in lengths] + hx = torch.randn(numLayers, miniBatch, hiddenSize, device=device) + cx = torch.randn(numLayers, miniBatch, hiddenSize, device=device) + lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers).to(device) + + if return_module: + return x, lengths, (hx, cx), lstm.all_weights, lstm + else: + # NB: lstm.all_weights format: + # wih, whh, bih, bhh = lstm.all_weights[layer] + return x, lengths, (hx, cx), lstm.all_weights, None + + +def varlen_lstm_backward_setup(forward_output, seed=None): + if seed: + torch.manual_seed(seed) + rnn_utils = torch.nn.utils.rnn + sequences = forward_output[0] + padded = rnn_utils.pad_sequence(sequences) + grad = torch.randn_like(padded) + return padded, grad + + +def varlen_pytorch_lstm_creator(**kwargs): + rnn_utils = torch.nn.utils.rnn + sequences, _, hidden, _, module = varlen_lstm_inputs( + return_module=True, **kwargs) + + def forward(sequences, hidden): + packed = rnn_utils.pack_sequence(sequences, enforce_sorted=False) + out, new_hidden = module(packed, hidden) + padded, lengths = rnn_utils.pad_packed_sequence(out) + # XXX: It's more efficient to store the output in its padded form, + # but that might not be conducive to loss computation. + # Un-padding the output also makes the backward pass 2x slower... 
+ # return [padded[:lengths[i], i, :] for i in range(lengths.size(0))] + return padded, new_hidden + + return ModelDef( + inputs=[sequences, hidden], + params=flatten_list(module.all_weights), + forward=forward, + backward_setup=lstm_backward_setup, + backward=simple_backward) + + +def varlen_lstm_factory(cell, script): + def dynamic_rnn(sequences, hiddens, wih, whh, bih, bhh): + # type: (List[Tensor], Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]] # noqa + hx, cx = hiddens + hxs = hx.unbind(1) + cxs = cx.unbind(1) + # List of: (output, hx, cx) + outputs = [] + hx_outs = [] + cx_outs = [] + + for batch in range(len(sequences)): + output = [] + hy, cy = hxs[batch], cxs[batch] + inputs = sequences[batch].unbind(0) + + for seq_idx in range(len(inputs)): + hy, cy = cell( + inputs[seq_idx].unsqueeze(0), (hy, cy), wih, whh, bih, bhh) + output += [hy] + outputs += [torch.stack(output)] + hx_outs += [hy.unsqueeze(0)] + cx_outs += [cy.unsqueeze(0)] + + return outputs, (hx_outs, cx_outs) + + if script: + cell = torch.jit.script(cell) + dynamic_rnn = torch.jit.script(dynamic_rnn) + + return dynamic_rnn + + +def varlen_lstm_creator(script=False, **kwargs): + sequences, _, hidden, params, _ = varlen_lstm_inputs( + return_module=False, **kwargs) + inputs = [sequences, hidden] + params[0] + return ModelDef( + inputs=inputs, + params=flatten_list(params), + forward=varlen_lstm_factory(lstm_cell, script), + backward_setup=varlen_lstm_backward_setup, + backward=simple_backward) + + +# cudnn_layernorm_lstm: since cudnn does not have Layernorm LSTM, we cannot benchmark +# the lowerbound directly. Instead, we only benchmark the foward pass by mimicing the +# computation of a cudnn lstm + seq_len * 3 layernorm computation. This should serve +# as a perf lowerbound for the Layernorm LSTM forward pass(given that Layernorm itself +# is invariant), the lowerbound of backward pass is hard to get since we lose the +# intermediate results, we can still optimize the layernorm implementation to make +# a faster foward lowerbound though. 
+def layernorm_pytorch_lstm_creator(**kwargs): + input, hidden, _, module = lstm_inputs(return_module=True, **kwargs) + batch_size = kwargs['miniBatch'] + hidden_size = kwargs['hiddenSize'] + ln_i = torch.nn.LayerNorm(4 * hidden_size).cuda() + ln_h = torch.nn.LayerNorm(4 * hidden_size).cuda() + ln_c = torch.nn.LayerNorm(hidden_size).cuda() + ln_input1 = torch.randn(batch_size, 4 * hidden_size, device='cuda') + + def forward(input, hidden): + out, new_hidden = module(input, hidden) + # plus (seq_len * three laynorm cell computation) to mimic the lower bound of + # Layernorm cudnn LSTM in the forward pass + seq_len = len(input.unbind(0)) + hy, cy = new_hidden + for i in range(seq_len): + ln_i_output = ln_i(ln_input1) + ln_h_output = ln_h(ln_input1) + cy = ln_c(cy) + + return out, (hy, cy) + + return ModelDef( + inputs=[input, hidden], + params=flatten_list(module.all_weights), + forward=forward, + backward_setup=lstm_backward_setup, + backward=None) + + +# input: lstm.all_weights format (wih, whh, bih, bhh = lstm.all_weights[layer]) +# output: packed_weights with format +# packed_weights[0] is wih with size (layer, 4*hiddenSize, inputSize) +# packed_weights[1] is whh with size (layer, 4*hiddenSize, hiddenSize) +# packed_weights[2] is bih with size (layer, 4*hiddenSize) +# packed_weights[3] is bhh with size (layer, 4*hiddenSize) +def stack_weights(weights): + def unzip_columns(mat): + assert isinstance(mat, list) + assert isinstance(mat[0], list) + layers = len(mat) + columns = len(mat[0]) + return [[mat[layer][col] for layer in range(layers)] + for col in range(columns)] + + # XXX: script fns have problems indexing multidim lists, so we try to + # avoid them by stacking tensors + all_weights = weights + packed_weights = [torch.stack(param) + for param in unzip_columns(all_weights)] + return packed_weights + + +# returns: x, (hx, cx), all_weights, lstm module with all_weights as params +def lstm_inputs(seqLength=100, numLayers=1, inputSize=512, hiddenSize=512, + miniBatch=64, dropout=0.0, return_module=False, device='cuda', seed=None): + if seed is not None: + torch.manual_seed(seed) + x = torch.randn(seqLength, miniBatch, inputSize, device=device) + hx = torch.randn(numLayers, miniBatch, hiddenSize, device=device) + cx = torch.randn(numLayers, miniBatch, hiddenSize, device=device) + lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers, dropout=dropout) + if 'cuda' in device: + lstm = lstm.cuda() + + if return_module: + return x, (hx, cx), lstm.all_weights, lstm + else: + # NB: lstm.all_weights format: + # wih, whh, bih, bhh = lstm.all_weights[layer] + return x, (hx, cx), lstm.all_weights, None + + +def lstm_factory(cell, script): + def dynamic_rnn(input, hidden, wih, whh, bih, bhh): + # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + hx, cx = hidden + outputs = [] + inputs = input.unbind(0) + hy, cy = hx[0], cx[0] + for seq_idx in range(len(inputs)): + hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh) + outputs += [hy] + return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0)) + + if script: + cell = torch.jit.script(cell) + dynamic_rnn = torch.jit.script(dynamic_rnn) + + return dynamic_rnn + + +# premul: we're going to premultiply the inputs & weights +def lstm_factory_premul(premul_cell, script): + def dynamic_rnn(input, hidden, wih, whh, bih, bhh): + # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + hx, cx = hidden + outputs = [] + inputs = 
torch.matmul(input, wih.t()).unbind(0) + hy, cy = hx[0], cx[0] + for seq_idx in range(len(inputs)): + hy, cy = premul_cell(inputs[seq_idx], (hy, cy), whh, bih, bhh) + outputs += [hy] + return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0)) + + if script: + premul_cell = torch.jit.script(premul_cell) + dynamic_rnn = torch.jit.script(dynamic_rnn) + + return dynamic_rnn + + +# simple: flat inputs (no tuples), no list to accumulate outputs +# useful mostly for benchmarking older JIT versions +def lstm_factory_simple(cell, script): + def dynamic_rnn(input, hx, cx, wih, whh, bih, bhh): + hy = hx # for scoping + cy = cx # for scoping + inputs = input.unbind(0) + for seq_idx in range(len(inputs)): + hy, cy = cell(inputs[seq_idx], hy, cy, wih, whh, bih, bhh) + return hy, cy + + if script: + cell = torch.jit.script(cell) + dynamic_rnn = torch.jit.script(dynamic_rnn) + + return dynamic_rnn + + +def lstm_factory_multilayer(cell, script): + def dynamic_rnn(input, hidden, params): + # type: (Tensor, Tuple[Tensor, Tensor], List[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] + params_stride = 4 # NB: this assumes that biases are there + hx, cx = hidden + hy, cy = hidden # for scoping... + inputs, outputs = input.unbind(0), [] + for layer in range(hx.size(0)): + hy = hx[layer] + cy = cx[layer] + base_idx = layer * params_stride + wih = params[base_idx] + whh = params[base_idx + 1] + bih = params[base_idx + 2] + bhh = params[base_idx + 3] + for seq_idx in range(len(inputs)): + hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh) + outputs += [hy] + inputs, outputs = outputs, [] + return torch.stack(inputs), (hy.unsqueeze(0), cy.unsqueeze(0)) + + if script: + cell = torch.jit.script(cell) + dynamic_rnn = torch.jit.script(dynamic_rnn) + + return dynamic_rnn diff --git a/benchmarks/fastrnns/profile.py b/benchmarks/fastrnns/profile.py new file mode 100644 index 0000000..3b0ec2b --- /dev/null +++ b/benchmarks/fastrnns/profile.py @@ -0,0 +1,138 @@ +import argparse +import os +import subprocess +import sys +import time +import torch +import datetime + +from .runner import get_rnn_runners + +PY3 = sys.version_info >= (3, 0) + + +def run_rnn(name, rnn_creator, nloops=5, + seqLength=100, numLayers=1, inputSize=512, hiddenSize=512, + miniBatch=64, device='cuda', seed=None): + def run_iter(modeldef): + # Forward + forward_output = modeldef.forward(*modeldef.inputs) + + # "loss computation" and backward + if modeldef.backward_setup is not None: + backward_input = modeldef.backward_setup(forward_output) + else: + backward_input = forward_output + if modeldef.backward is not None: + modeldef.backward(*backward_input) + + # "Update" parameters + if modeldef.backward is not None: + for param in modeldef.params: + param.grad.data.zero_() + torch.cuda.synchronize() + + assert device == 'cuda' + creator_args = dict(seqLength=seqLength, numLayers=numLayers, + inputSize=inputSize, hiddenSize=hiddenSize, + miniBatch=miniBatch, device=device, seed=seed) + modeldef = rnn_creator(**creator_args) + + [run_iter(modeldef) for _ in range(nloops)] + + +def profile(rnns, sleep_between_seconds=1, nloops=5, + internal_run=True, # Unused, get rid of this TODO + seqLength=100, numLayers=1, inputSize=512, hiddenSize=512, + miniBatch=64, device='cuda', seed=None): + params = dict(seqLength=seqLength, numLayers=numLayers, + inputSize=inputSize, hiddenSize=hiddenSize, + miniBatch=miniBatch, device=device, seed=seed) + for name, creator, context in get_rnn_runners(*rnns): + with context(): + run_rnn(name, creator, nloops, 
**params) + time.sleep(sleep_between_seconds) + + +def system(command): + """Returns (return-code, stdout, stderr)""" + print('[system] {}'.format(command)) + p = subprocess.Popen(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=True) + output, err = p.communicate() + rc = p.returncode + if PY3: + output = output.decode("ascii") + err = err.decode("ascii") + return rc, output, err + + +def describe_sizes(**sizes): + # seqLength, numLayers, inputSize, hiddenSize, miniBatch + return 's{}-l{}-i{}-h{}-b{}'.format( + sizes['seqLength'], + sizes['numLayers'], + sizes['inputSize'], + sizes['hiddenSize'], + sizes['miniBatch'], + ) + + +OUTPUT_DIR = '~/profout/' + + +def nvprof_output_filename(rnns, **params): + rnn_tag = '-'.join(rnns) + size_tag = describe_sizes(**params) + date_tag = datetime.datetime.now().strftime("%m%d%y-%H%M") + return '{}prof_{}_{}_{}.nvvp'.format(OUTPUT_DIR, rnn_tag, + size_tag, date_tag) + + +def nvprof(cmd, outpath): + return system('nvprof -o {} {}'.format(outpath, cmd)) + + +def full_profile(rnns, **args): + args['internal_run'] = True + profile_args = [] + for k, v in args.items(): + profile_args.append('--{}={}'.format(k, v)) + profile_args.append('--rnns {}'.format(' '.join(rnns))) + + outpath = nvprof_output_filename(rnns, **args) + + cmd = '{} -m fastrnns.profile {}'.format( + sys.executable, ' '.join(profile_args)) + rc, stdout, stderr = nvprof(cmd, outpath) + if rc != 0: + raise RuntimeError('stderr: {}\nstdout: {}'.format(stderr, stdout)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Profile RNNs') + + parser.add_argument('--seqLength', default='100', type=int) + parser.add_argument('--numLayers', default='1', type=int) + parser.add_argument('--inputSize', default='512', type=int) + parser.add_argument('--hiddenSize', default='512', type=int) + parser.add_argument('--miniBatch', default='64', type=int) + parser.add_argument('--sleep_between_seconds', default='1', type=int) + parser.add_argument('--nloops', default='5', type=int) + + parser.add_argument('--rnns', nargs='*', + help='What to run. cudnn, aten, jit, etc') + + # if internal_run, we actually run the rnns. 
+ # if not internal_run, we shell out to nvprof with internal_run=T + parser.add_argument('--internal_run', default=False, type=bool, + help='Don\'t use this') + args = parser.parse_args() + if args.rnns is None: + args.rnns = ['cudnn', 'aten', 'jit'] + print(args) + + if args.internal_run: + profile(**vars(args)) + else: + full_profile(**vars(args)) diff --git a/benchmarks/fastrnns/runner.py b/benchmarks/fastrnns/runner.py new file mode 100644 index 0000000..51c850e --- /dev/null +++ b/benchmarks/fastrnns/runner.py @@ -0,0 +1,67 @@ +from collections import namedtuple +from functools import partial +import torch +import torchvision.models as cnn + +from .factory import * + + +class DisableCuDNN(): + def __enter__(self): + self.saved = torch.backends.cudnn.enabled + torch.backends.cudnn.enabled = False + + def __exit__(self, *args, **kwargs): + torch.backends.cudnn.enabled = self.saved + + +class DummyContext(): + def __enter__(self): + pass + + def __exit__(self, *args, **kwargs): + pass + + +class AssertNoJIT(): + def __enter__(self): + import os + enabled = os.environ.get('PYTORCH_JIT', 1) + assert not enabled + + def __exit__(self, *args, **kwargs): + pass + + +RNNRunner = namedtuple('RNNRunner', [ + 'name', 'creator', 'context', +]) + + +def get_nn_runners(*names): + return [nn_runners[name] for name in names] + + +nn_runners = { + 'cudnn': RNNRunner('cudnn', pytorch_lstm_creator, DummyContext), + 'cudnn_dropout': RNNRunner('cudnn_dropout', partial(pytorch_lstm_creator, dropout=0.4), DummyContext), + 'cudnn_layernorm': RNNRunner('cudnn_layernorm', layernorm_pytorch_lstm_creator, DummyContext), + 'vl_cudnn': RNNRunner('vl_cudnn', varlen_pytorch_lstm_creator, DummyContext), + 'vl_jit': RNNRunner('vl_jit', partial(varlen_lstm_creator, script=True), DummyContext), + 'vl_py': RNNRunner('vl_py', varlen_lstm_creator, DummyContext), + 'aten': RNNRunner('aten', pytorch_lstm_creator, DisableCuDNN), + 'jit': RNNRunner('jit', lstm_creator, DummyContext), + 'jit_premul': RNNRunner('jit_premul', lstm_premul_creator, DummyContext), + 'jit_simple': RNNRunner('jit_simple', lstm_simple_creator, DummyContext), + 'jit_multilayer': RNNRunner('jit_multilayer', lstm_multilayer_creator, DummyContext), + 'jit_layernorm': RNNRunner('jit_layernorm', lnlstm_creator, DummyContext), + 'jit_layernorm_decom': RNNRunner('jit_layernorm_decom', + partial(lnlstm_creator, decompose_layernorm=True), + DummyContext), + 'jit_dropout': RNNRunner('jit_dropout', dropoutlstm_creator, DummyContext), + 'py': RNNRunner('py', partial(lstm_creator, script=False), DummyContext), + 'resnet18': RNNRunner('resnet18', imagenet_cnn_creator(cnn.resnet18, jit=False), DummyContext), + 'resnet18_jit': RNNRunner('resnet18_jit', imagenet_cnn_creator(cnn.resnet18), DummyContext), + 'resnet50': RNNRunner('resnet50', imagenet_cnn_creator(cnn.resnet50, jit=False), DummyContext), + 'resnet50_jit': RNNRunner('resnet50_jit', imagenet_cnn_creator(cnn.resnet50), DummyContext), +} diff --git a/benchmarks/fastrnns/scratch.py b/benchmarks/fastrnns/scratch.py new file mode 100644 index 0000000..c51d716 --- /dev/null +++ b/benchmarks/fastrnns/scratch.py @@ -0,0 +1,51 @@ +import torch + + +@torch.jit.script +def fn(x, scale, shift): + return scale * x / shift + + +@torch.jit.script +def recurrent(x, scale, shift): + y = x + for i in range(100): + y = fn(y, scale, shift) + return y + + +x = torch.randn(2, 2, device='cuda') +scale = torch.randn(2, 2, device='cuda', requires_grad=True) +shift = torch.randn(2, 2, device='cuda', requires_grad=True) +inputs = [x, 
scale, shift] + + +out = recurrent(x, scale, shift) +recurrent.graph_for(x, scale, shift) + + +import torch + + +@torch.jit.script +def recurrent_scaleshift(x, scale, shift): + y = x + for i in range(64): + y = scale * y + shift + return y + + +x = torch.randn(2, 2, device='cuda') +scale = torch.randn(2, 2, device='cuda', requires_grad=True) +shift = torch.randn(2, 2, device='cuda', requires_grad=True) +inputs = [x, scale, shift] +out = recurrent_scaleshift(x, scale, shift) +recurrent_scaleshift.graph_for(x, scale, shift) + + +import torch +x = torch.tensor([]) +x.requires_grad = True +x.mean().backward() # no error triggered +x = x.cuda() +x.mean().backward() diff --git a/benchmarks/fastrnns/test.py b/benchmarks/fastrnns/test.py new file mode 100644 index 0000000..6c8e4e1 --- /dev/null +++ b/benchmarks/fastrnns/test.py @@ -0,0 +1,159 @@ +import argparse +import torch +import torch.nn as nn + +from .cells import lstm_cell +from .factory import pytorch_lstm_creator, varlen_pytorch_lstm_creator +from .runner import get_nn_runners + + +def barf(): + import pdb + pdb.set_trace() + + +def assertEqual(tensor, expected, threshold=0.001): + if isinstance(tensor, list) or isinstance(tensor, tuple): + for t, e in zip(tensor, expected): + assertEqual(t, e) + else: + if (tensor - expected).abs().max() > threshold: + barf() + + +def filter_requires_grad(tensors): + return [t for t in tensors if t.requires_grad] + + +def test_rnns(experim_creator, control_creator, check_grad=True, verbose=False, + seqLength=100, numLayers=1, inputSize=512, hiddenSize=512, + miniBatch=64, device='cuda', seed=17): + creator_args = dict(seqLength=seqLength, numLayers=numLayers, + inputSize=inputSize, hiddenSize=hiddenSize, + miniBatch=miniBatch, device=device, seed=seed) + + print("Setting up...") + control = control_creator(**creator_args) + experim = experim_creator(**creator_args) + + # Precondition + assertEqual(experim.inputs, control.inputs) + assertEqual(experim.params, control.params) + + print("Checking outputs...") + control_outputs = control.forward(*control.inputs) + experim_outputs = experim.forward(*experim.inputs) + assertEqual(experim_outputs, control_outputs) + + print("Checking grads...") + assert control.backward_setup is not None + assert experim.backward_setup is not None + assert control.backward is not None + assert experim.backward is not None + control_backward_inputs = control.backward_setup(control_outputs, seed) + experim_backward_inputs = experim.backward_setup(experim_outputs, seed) + + control.backward(*control_backward_inputs) + experim.backward(*experim_backward_inputs) + + control_grads = [p.grad for p in control.params] + experim_grads = [p.grad for p in experim.params] + assertEqual(experim_grads, control_grads) + + if verbose: + print(experim.forward.graph_for(*experim.inputs)) + print('') + + +def test_vl_py(**test_args): + # XXX: This compares vl_py with vl_lstm. + # It's done this way because those two don't give the same outputs so + # the result isn't an apples-to-apples comparison right now. 
+ control_creator = varlen_pytorch_lstm_creator + name, experim_creator, context = get_nn_runners('vl_py')[0] + with context(): + print('testing {}...'.format(name)) + creator_keys = [ + 'seqLength', 'numLayers', 'inputSize', + 'hiddenSize', 'miniBatch', 'device', 'seed' + ] + creator_args = {key: test_args[key] for key in creator_keys} + + print("Setting up...") + control = control_creator(**creator_args) + experim = experim_creator(**creator_args) + + # Precondition + assertEqual(experim.inputs, control.inputs[:2]) + assertEqual(experim.params, control.params) + + print("Checking outputs...") + control_out, control_hiddens = control.forward(*control.inputs) + control_hx, control_cx = control_hiddens + experim_out, experim_hiddens = experim.forward(*experim.inputs) + experim_hx, experim_cx = experim_hiddens + + experim_padded = nn.utils.rnn.pad_sequence(experim_out).squeeze(-2) + assertEqual(experim_padded, control_out) + assertEqual(torch.cat(experim_hx, dim=1), control_hx) + assertEqual(torch.cat(experim_cx, dim=1), control_cx) + + print("Checking grads...") + assert control.backward_setup is not None + assert experim.backward_setup is not None + assert control.backward is not None + assert experim.backward is not None + control_backward_inputs = control.backward_setup( + (control_out, control_hiddens), test_args['seed']) + experim_backward_inputs = experim.backward_setup( + (experim_out, experim_hiddens), test_args['seed']) + + control.backward(*control_backward_inputs) + experim.backward(*experim_backward_inputs) + + control_grads = [p.grad for p in control.params] + experim_grads = [p.grad for p in experim.params] + assertEqual(experim_grads, control_grads) + + if test_args['verbose']: + print(experim.forward.graph_for(*experim.inputs)) + print('') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Test lstm correctness') + + parser.add_argument('--seqLength', default='100', type=int) + parser.add_argument('--numLayers', default='1', type=int) + parser.add_argument('--inputSize', default='512', type=int) + parser.add_argument('--hiddenSize', default='512', type=int) + parser.add_argument('--miniBatch', default='64', type=int) + parser.add_argument('--device', default='cuda', type=str) + parser.add_argument('--check_grad', default='True', type=bool) + parser.add_argument('--variable_lstms', action='store_true') + parser.add_argument('--seed', default='17', type=int) + parser.add_argument('--verbose', action='store_true') + parser.add_argument('--rnns', nargs='*', + help='What to run. jit_premul, jit, etc') + args = parser.parse_args() + if args.rnns is None: + args.rnns = ['jit_premul', 'jit'] + print(args) + + if 'cuda' in args.device: + assert torch.cuda.is_available() + + rnn_runners = get_nn_runners(*args.rnns) + + should_test_varlen_lstms = args.variable_lstms + test_args = vars(args) + del test_args['rnns'] + del test_args['variable_lstms'] + + if should_test_varlen_lstms: + test_vl_py(**test_args) + + for name, creator, context in rnn_runners: + with context(): + print('testing {}...'.format(name)) + test_rnns(creator, pytorch_lstm_creator, **test_args) -- 2.7.4