From ff4b6d1a49695d8bf3b3ef01630a56d701d71ec3 Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Thu, 28 Mar 2019 23:07:45 -0700
Subject: [PATCH] Delete batch tensor (#18575)

Summary:
Deleting batch tensor: we are no longer maintaining the project, and keeping it functional is blocking other improvements.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18575

Differential Revision: D14671126

Pulled By: eellison

fbshipit-source-id: b42d5b699c4d12171ed95e6d3a977532167f0d2c
---
 test/test_docs_coverage.py             |   3 +-
 test/test_jit.py                       | 543 -----------------------------
 torch/CMakeLists.txt                   |   2 -
 torch/csrc/jit/batched/BatchTensor.cpp |  98 ------
 torch/csrc/jit/batched/BatchTensor.h   |  57 ---
 torch/csrc/jit/init.cpp                |   4 -
 torch/csrc/jit/passes/to_batch.cpp     | 610 ---------------------------------
 torch/csrc/jit/passes/to_batch.h       |  47 ---
 torch/csrc/jit/script/init.cpp         |   1 -
 torch/jit/__init__.py                  |  33 --
 torch/jit/batchop.py                   | 528 ----------------------------
 11 files changed, 1 insertion(+), 1925 deletions(-)
 delete mode 100644 torch/csrc/jit/batched/BatchTensor.cpp
 delete mode 100644 torch/csrc/jit/batched/BatchTensor.h
 delete mode 100644 torch/csrc/jit/passes/to_batch.cpp
 delete mode 100644 torch/csrc/jit/passes/to_batch.h
 delete mode 100644 torch/jit/batchop.py

diff --git a/test/test_docs_coverage.py b/test/test_docs_coverage.py
index 811221a..3b565c3 100644
--- a/test/test_docs_coverage.py
+++ b/test/test_docs_coverage.py
@@ -36,8 +36,7 @@ class TestDocCoverage(unittest.TestCase):
         whitelist = {
             # below are some jit functions
             'wait', 'fork', 'parse_type_comment', 'import_ir_module',
-            'to_batch_graph', 'import_ir_module_from_buffer',
-            'register_batch_operator', 'merge_type_from_type_comment',
+            'import_ir_module_from_buffer', 'merge_type_from_type_comment',

             # below are symbols mistakenly bound to torch.*, but should
             # go to torch.nn.functional.* instead
diff --git a/test/test_jit.py b/test/test_jit.py
index b1b5c98..02fe570 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -48,7 +48,6 @@ from copy import deepcopy
 import random
 from typing import List, Dict, Optional, Tuple
 from torch.jit.frontend import NotSupportedError
-from torch.jit import BatchTensor
 from torch import Tensor
 from torch.jit.annotations import BroadcastingList2, BroadcastingList3

@@ -2465,548 +2464,6 @@ class TestJit(JitTestCase):
         self.assertTrue(tested_blocks)


-class TestBatched(TestCase):
-    # generate random examples and create a batchtensor with them
-    def rand_batch(self, *dims):
-        dims = [dim for dim in dims if dim != ()]
-        xs = [torch.rand(1, *(random.randint(1, size) if b else size for b, size in dims[1:]),
-                         requires_grad=True) for i in range(dims[0])]
-        xb = BatchTensor(xs, torch.tensor([b for b, d in dims[1:]]).byte())
-        return xs, xb
-
-    def test_create_batchtensor(self):
-        # create from tensorlist
-        xs, batch = self.rand_batch(4, (True, 3), (False, 2), (True, 5))
-        self.assertEqual(xs, batch.examples())
-        # create from data, mask, dims
-        batch2 = BatchTensor(batch.get_data(), batch.get_mask(), batch.get_dims())
-        self.assertEqual(xs, batch2.examples())
-        # expand a tensor to a batchtensor given batch_size
-        xs = torch.rand(3, 4, 5)
-        batch3 = BatchTensor(xs, 2)
-        xs = xs.unsqueeze(0)
-        self.assertEqual([xs, xs], batch3.examples())
-
-    def test_batch_elementwise_unary(self):
-        @torch.jit.batch(batch_size=4)
-        def tanh(a):
-            return torch.tanh(a)
-
-        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
-        res_batch = tanh(batch)
-        res = [torch.tanh(xs[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
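For context: rand_batch above builds the (data, mask, dims) triple that this patch removes. A minimal sketch of the same construction in plain torch, assuming no BatchTensor and that static dimensions have the same size in every example (to_masked_batch is an illustrative name, not an API):

    import torch

    def to_masked_batch(xs, dims):
        # xs:   list of example tensors, each of shape (1, d1, ..., dn)
        # dims: one bool per non-batch dimension; True marks a dynamic
        #       (variable-size) dimension, False a static one
        bs = len(xs)
        sizes = [bs] + [max(x.size(i + 1) for x in xs) for i in range(len(dims))]
        mask_sizes = [bs] + [s if d else 1 for s, d in zip(sizes[1:], dims)]
        data = torch.zeros(sizes)
        mask = torch.zeros(mask_sizes, dtype=torch.uint8)
        for i, x in enumerate(xs):
            idx = (i,) + tuple(slice(0, x.size(j + 1)) for j in range(len(dims)))
            data[idx] = x[0]
            midx = (i,) + tuple(slice(0, x.size(j + 1)) if d else slice(0, 1)
                                for j, d in enumerate(dims))
            mask[midx] = 1
        return data, mask, torch.tensor(dims, dtype=torch.uint8)

For example, two examples of shapes (1, 2, 5) and (1, 3, 5) with dims [True, False] yield data of shape (2, 3, 5) and mask of shape (2, 3, 1).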
-
-    def test_batch_elementwise_binary(self):
-        @torch.jit.batch(batch_size=4)
-        def add(a, b):
-            return a + b
-
-        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
-        xs2, batch2 = xs, batch
-        res_batch = add(batch, batch2)
-        res = [torch.add(xs[j], xs2[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        # test broadcast
-        xs, batch = self.rand_batch(4, (False, 3), (False, 2))
-        b = torch.rand(3, 2)
-        res_batch = add(batch, b)
-        res = [torch.add(xs[j], b) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_mm(self):
-        @torch.jit.batch(batch_size=4)
-        def mm(a, b):
-            return torch.mm(a, b)
-
-        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
-        xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
-        res_batch = mm(batch, batch2)
-        res = [torch.mm(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        # test broadcast
-        b = torch.rand(2, 4)
-        res_batch = mm(batch, b)
-        res = [torch.mm(xs[j].squeeze(0), b).unsqueeze(0) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_matmul(self):
-        @torch.jit.batch(batch_size=4)
-        def matmul(a, b):
-            return torch.matmul(a, b)
-
-        def matmul_test(xs, batch, xs2, batch2):
-            ys = [torch.matmul(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0) for j in range(4)]
-            ybs = matmul(batch, batch2)
-            self.assertEqual(ys, ybs.examples())
-
-        # 1 dimension * 1 dimension
-        xs, batch = self.rand_batch(4, (False, 2))
-        xs2, batch2 = self.rand_batch(4, (False, 2))
-        matmul_test(xs, batch, xs2, batch2)
-        # 1 dimension * 2 dimension
-        xs, batch = self.rand_batch(4, (False, 2))
-        xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
-        matmul_test(xs, batch, xs2, batch2)
-        # 2 dimension * 1 dimension
-        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
-        xs2, batch2 = self.rand_batch(4, (False, 2))
-        matmul_test(xs, batch, xs2, batch2)
-        # 2 dimension * 2 dimension
-        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
-        xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
-        matmul_test(xs, batch, xs2, batch2)
-
-    def test_batch_select(self):
-        @torch.jit.batch(batch_size=4)
-        def select(x):
-            return torch.select(x, 1, 0)
-
-        xs, batch = self.rand_batch(4, (True, 3), (True, 2))
-        res_batch = select(batch)
-        res = [torch.select(xs[j], 1, 0) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        xs, batch = self.rand_batch(4, (False, 3), (True, 2))
-        res_batch = select(batch)
-        res = [torch.select(xs[j], 1, 0) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_index_select(self):
-        @torch.jit.batch(batch_size=4)
-        def index_select(x, ind):
-            return x.index_select(1, ind)
-
-        xs, batch = self.rand_batch(4, (False, 5), (True, 2))
-        ind = [torch.randint(0, 4, (1,), dtype=torch.long) for i in range(4)]
-        ind_batch = BatchTensor(ind, torch.tensor([]).byte())
-        res_batch = index_select(batch, ind_batch)
-        res = [torch.index_select(xs[j], 1, ind[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_where(self):
-        @torch.jit.batch(batch_size=4)
-        def where(c, a, b):
-            return torch.where(c, a, b)
-
-        xs, batch = self.rand_batch(4, (False, 3), (False, 2))
-        xs2, batch2 = self.rand_batch(4, (False, 3), (False, 2))
-
-        dims = [4, (False, 3), (False, 2)]
-        xs_cond = [torch.rand(1, 3, 2).byte() for i in range(dims[0])]
-        batch_cond = BatchTensor(xs_cond, torch.tensor([b for b, d in dims[1:]]))
-
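The binary ops exercised by these tests all follow one rule in the deleted batchop.py further down: apply the op to the padded data, multiply the masks, and OR the dims. A compressed restatement, assuming both operands are already expanded (data, mask, dims) triples:

    def batch_elementwise(op, a, b):
        # a, b: (data, mask, dims) triples with the same batch size
        (data1, mask1, dims1), (data2, mask2, dims2) = a, b
        # op on the padded data; an entry is valid only where both operands
        # are valid; a dimension is dynamic if dynamic in either operand
        return op(data1, data2), mask1 * mask2, dims1 | dims2

For instance, batch_elementwise(torch.add, t1, t2) behaves like batch_add below with alpha fixed to 1; where is the same idea extended to three operands.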
res_batch = where(batch_cond, batch, batch2) - res = [torch.where(xs_cond[j], xs[j], xs2[j]) for j in range(4)] - self.assertEqual(res, res_batch.examples()) - - def test_batch_argmax(self): - @torch.jit.batch(batch_size=4) - def argmax(a): - return torch.argmax(a, 1) - - xs, batch = self.rand_batch(4, (True, 5), (True, 6)) - res_batch = argmax(batch) - res = [torch.argmax(xs[j], 1) for j in range(4)] - self.assertEqual(res, res_batch.examples()) - - @torch.jit.batch(batch_size=4) - def argmax(a): - return torch.argmax(a, 1, False) - - res_batch = argmax(batch) - res = [torch.argmax(xs[j], 1, False) for j in range(4)] - self.assertEqual(res, res_batch.examples()) - - def test_batch_topk(self): - @torch.jit.batch(batch_size=4) - def topk(a): - return torch.topk(a, 3, 1) - - xs, batch = self.rand_batch(4, (False, 5), (True, 6)) - - # along static dim - res_batch = topk(batch) - res = [torch.topk(xs[j], 3, 1)[0] for j in range(4)] - res_idx = [torch.topk(xs[j], 3, 1)[1] for j in range(4)] - self.assertEqual(res, res_batch[0].examples()) - self.assertEqual(res_idx, res_batch[1].examples()) - - @torch.jit.batch(batch_size=4) - def topk(a): - return torch.topk(a, 1, 2) - - # along dynamic dim - res_batch = topk(batch) - res = [torch.topk(xs[j], 1, 2)[0] for j in range(4)] - res_idx = [torch.topk(xs[j], 1, 2)[1] for j in range(4)] - self.assertEqual(res, res_batch[0].examples()) - self.assertEqual(res_idx, res_batch[1].examples()) - - def test_batch_softmax(self): - @torch.jit.batch(batch_size=4) - def softmax(a): - return torch.softmax(a, 1) - - xs, batch = self.rand_batch(4, (False, 5), (True, 6)) - - # along static dim - res_batch = softmax(batch) - res = [torch.softmax(xs[j], 1) for j in range(4)] - self.assertEqual(res, res_batch.examples()) - - @torch.jit.batch(batch_size=4) - def softmax(a): - return torch.softmax(a, 2) - - # along dynamic dim - res_batch = softmax(batch) - res = [torch.softmax(xs[j], 2) for j in range(4)] - self.assertEqual(res, res_batch.examples()) - - def test_batch_view(self): - @torch.jit.batch(batch_size=4) - def view(a): - return a.view([4, -1, 3]) - - xs, batch = self.rand_batch(4, (True, 5), (False, 3)) - res_batch = view(batch) - res = [xs[j].view([1, -1, 3]) for j in range(4)] - self.assertEqual(res, res_batch.examples()) - - def test_batch_cat(self): - @torch.jit.batch(batch_size=4) - def cat2(a, b): - return torch.cat([a, b], 2) - - xs, batch = self.rand_batch(4, (True, 5), (False, 3)) - xs2, batch2 = xs, batch - res_batch = cat2(batch, batch2) - res = [torch.cat([xs[j], xs2[j]], 2) for j in range(4)] - self.assertEqual(res, res_batch.examples()) - - def test_batch_sum(self): - @torch.jit.batch(batch_size=4) - def batch_sum(a): - return a.sum() - - xs, batch = self.rand_batch(4, (True, 5), (False, 3)) - res_batch = batch_sum(batch) - res = [xs[j].sum().unsqueeze(0) for j in range(4)] - self.assertEqual(res, res_batch.examples()) - - def test_if_else(self): - def single_if(a, b): - if bool(a > b): - a = a + b - else: - a = a - b - return a - - batch_if = torch.jit.batch(batch_size=4)(single_if) - - a, batch_a = self.rand_batch(4, ()) - b, batch_b = self.rand_batch(4, ()) - res_batch = batch_if(batch_a, batch_b) - res = [single_if(a[j], b[j]) for j in range(4)] - self.assertEqual(res, res_batch.examples()) - - script_if = torch.jit.script(single_if) - torch.to_batch_graph(script_if.graph) - - def test_if_else_with_scalar(self): - def single_if(a, b): - if bool(a > 0.1): - a = a + b - else: - a = a - b - return a - - batch_if = 
torch.jit.batch(batch_size=4)(single_if) - - a, batch_a = self.rand_batch(4, ()) - b, batch_b = self.rand_batch(4, ()) - res_batch = batch_if(batch_a, batch_b) - res = [single_if(a[j], b[j]) for j in range(4)] - self.assertEqual(res, res_batch.examples()) - - script_if = torch.jit.script(single_if) - torch.to_batch_graph(script_if.graph) - - def test_if_noelse(self): - def single_if(a, b): - if bool(a > b): - a = a + b - return a - - batch_if = torch.jit.batch(batch_size=4)(single_if) - - a, batch_a = self.rand_batch(4, ()) - b, batch_b = self.rand_batch(4, ()) - res_batch = batch_if(batch_a, batch_b) - res = [single_if(a[j], b[j]) for j in range(4)] - self.assertEqual(res, res_batch.examples()) - - script_if = torch.jit.script(single_if) - torch.to_batch_graph(script_if.graph) - - def test_if_noelse_with_scalar(self): - def single_if(a, b): - if bool(a > 0.1): - a = a + b - return a - - batch_if = torch.jit.batch(batch_size=4)(single_if) - - a, batch_a = self.rand_batch(4, ()) - b, batch_b = self.rand_batch(4, ()) - res_batch = batch_if(batch_a, batch_b) - res = [single_if(a[j], b[j]) for j in range(4)] - self.assertEqual(res, res_batch.examples()) - - script_if = torch.jit.script(single_if) - torch.to_batch_graph(script_if.graph) - - def test_while(self): - def single_while(a, b): - while bool(a > b): - a = a - b - return a - - batch_while = torch.jit.batch(batch_size=4)(single_while) - - a, batch_a = self.rand_batch(4, ()) - b = [torch.abs(torch.rand(1)) for i in range(4)] - batch_b = BatchTensor(b, torch.tensor([]).byte()) - res_batch = batch_while(batch_a, batch_b) - res = [single_while(a[j], b[j]) for j in range(4)] - self.assertEqual(res, res_batch.examples()) - - script_while = torch.jit.script(single_while) - torch.to_batch_graph(script_while.graph) - - def test_for(self): - def single_for(x, y): - for _ in range(10): - x = x + y - return x - - batch_for = torch.jit.batch(batch_size=4)(single_for) - - a, batch_a = self.rand_batch(4, ()) - b, batch_b = self.rand_batch(4, ()) - res_batch = batch_for(batch_a, batch_b) - res = [single_for(a[j], b[j]) for j in range(4)] - self.assertEqual(res, res_batch.examples()) - - script_for = torch.jit.script(single_for) - torch.to_batch_graph(script_for.graph) - - def test_lstm(self): - def LSTM(x_all, h, c, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c): - for i in range(x_all.size(1)): - x = x_all.select(1, i) - i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i - f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f - o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o - # activations - i_t = torch.sigmoid(i_t) - f_t = torch.sigmoid(f_t) - o_t = torch.sigmoid(o_t) - # cell computations - c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c - c_t = torch.tanh(c_t) - c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t) - h_t = torch.mul(o_t, torch.tanh(c_t)) - h = h_t - c = c_t - return h - - LSTM_batch = torch.jit.batch(batch_size=4)(LSTM) - - batch_size, input_size, hidden_size = 4, 3, 2 - xs, batch = self.rand_batch(batch_size, (True, 4), (False, input_size)) - hx, h_batch = self.rand_batch(batch_size, (False, hidden_size)) - cx, c_batch = self.rand_batch(batch_size, (False, hidden_size)) - - # input to hidden weights - w_xi = torch.rand(input_size, hidden_size) - w_xf = torch.rand(input_size, hidden_size) - w_xo = torch.rand(input_size, hidden_size) - w_xc = torch.rand(input_size, hidden_size) - # hidden to hidden weights - w_hi = torch.rand(hidden_size, hidden_size) - w_hf = torch.rand(hidden_size, 
hidden_size) - w_ho = torch.rand(hidden_size, hidden_size) - w_hc = torch.rand(hidden_size, hidden_size) - # bias terms - b_i = torch.rand(hidden_size) - b_f = torch.rand(hidden_size) - b_o = torch.rand(hidden_size) - b_c = torch.rand(hidden_size) - - ys = [LSTM(xs[j], hx[j], cx[j], w_xi, w_xf, w_xo, w_xc, - w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c) for j in range(batch_size)] - ybs = LSTM_batch(batch, h_batch, c_batch, w_xi, w_xf, w_xo, w_xc, - w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c) - self.assertEqual(ys, ybs.examples()) - - @slowTest - def test_greedy_search(self): - def greedy(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, - b_i, b_f, b_o, b_c, w_hs, b_s, iter_num): - iter_count = torch.zeros_like(iter_num) - while bool(iter_count < iter_num): - iter_count = iter_count + 1 - # LSTM Cell - i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i - f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f - o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o - # activations - i_t = torch.sigmoid(i_t) - f_t = torch.sigmoid(f_t) - o_t = torch.sigmoid(o_t) - # cell computations - c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c - c_t = torch.tanh(c_t) - c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t) - h_t = torch.mul(o_t, torch.tanh(c_t)) - h = h_t - c = c_t - # calculate feature with max probability - s_t = torch.matmul(h_t, w_hs) + b_s - p_t = torch.softmax(s_t, 1) - i_t = torch.argmax(p_t, 1) - x = embed.index_select(1, i_t).squeeze(1) - return h - - greedy_batch = torch.jit.batch(batch_size=4)(greedy) - - batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7 - xs, batch = self.rand_batch(batch_size, (False, input_size)) - hx, h_batch = self.rand_batch(batch_size, (False, hidden_size)) - cx, c_batch = self.rand_batch(batch_size, (False, hidden_size)) - embed, embed_batch = self.rand_batch(batch_size, (False, vocab_size), (False, input_size)) - iter_num = [torch.randint(2, 5, (1,)) for i in range(batch_size)] - iter_num_batch = BatchTensor(iter_num, torch.tensor([]).byte()) - - # input to hidden weights - w_xi = torch.rand(input_size, hidden_size) - w_xf = torch.rand(input_size, hidden_size) - w_xo = torch.rand(input_size, hidden_size) - w_xc = torch.rand(input_size, hidden_size) - # hidden to hidden weights - w_hi = torch.rand(hidden_size, hidden_size) - w_hf = torch.rand(hidden_size, hidden_size) - w_ho = torch.rand(hidden_size, hidden_size) - w_hc = torch.rand(hidden_size, hidden_size) - # bias terms - b_i = torch.rand(hidden_size) - b_f = torch.rand(hidden_size) - b_o = torch.rand(hidden_size) - b_c = torch.rand(hidden_size) - # hidden to vocab weights, bias - w_hs = torch.rand(hidden_size, vocab_size) - b_s = torch.rand(vocab_size) - - ys = [greedy(xs[j], hx[j], cx[j], embed[j], w_xi, w_xf, w_xo, w_xc, - w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num[j]) for j in range(batch_size)] - ybs = greedy_batch(batch, h_batch, c_batch, embed_batch, w_xi, w_xf, w_xo, w_xc, - w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num_batch) - self.assertEqual(ys, ybs.examples()) - - @slowTest - def test_beam_search(self): - def beam(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, - b_i, b_f, b_o, b_c, w_hs, b_s, iter_num, idx): - k = 5 - vocab_size = embed.size(1) - iter_count = torch.zeros_like(iter_num) - max_len = idx.size(2) - while bool(iter_count < iter_num): - iter_count = iter_count + 1 - # LSTM Cell - i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i - f_t = torch.matmul(x, w_xf) + 
torch.matmul(h, w_hf) + b_f - o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o - # activations - i_t = torch.sigmoid(i_t) - f_t = torch.sigmoid(f_t) - o_t = torch.sigmoid(o_t) - # cell computations - c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c - c_t = torch.tanh(c_t) - c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t) - h_t = torch.mul(o_t, torch.tanh(c_t)) - h = h_t - c = c_t - # calculate features with max probability - s_t = torch.matmul(h_t, w_hs) + b_s - s_t = s_t.view([1, s_t.size(1) * s_t.size(2)]) - p_t = torch.softmax(s_t, 1) - prob_t, idx_t = torch.topk(p_t, k, 1) - if(int(idx_t.dim()) > 1): - idx_t_tmp = idx_t.squeeze(0) - else: - idx_t_tmp = idx_t - new_y = torch.fmod(idx_t_tmp, vocab_size) - pre_y = idx_t_tmp / vocab_size - x = embed.index_select(1, new_y) - h = h_t.index_select(1, pre_y) - c = c_t.index_select(1, pre_y) - iter = int(iter_count[0]) - idx = torch.cat([idx.narrow(2, 0, iter).index_select(1, pre_y), - torch.fmod(idx_t, vocab_size).unsqueeze(-1), - idx.narrow(2, iter, max_len - iter)], 2) - idx = idx.narrow(2, 0, max_len) - return idx - - beam_batch = torch.jit.batch(batch_size=4)(beam) - - k = 5 - batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7 - max_len = 5 - xs, batch = self.rand_batch(batch_size, (False, 1), (False, input_size)) - hx, h_batch = self.rand_batch(batch_size, (False, 1), (False, hidden_size)) - cx, c_batch = self.rand_batch(batch_size, (False, 1), (False, hidden_size)) - embed, embed_batch = self.rand_batch(batch_size, (False, vocab_size), (False, input_size)) - iter_num = [torch.randint(2, max_len + 1, (1,)) for i in range(batch_size)] - iter_num_batch = BatchTensor(iter_num, torch.tensor([]).byte()) - - # input to hidden weights - w_xi = torch.rand(input_size, hidden_size) - w_xf = torch.rand(input_size, hidden_size) - w_xo = torch.rand(input_size, hidden_size) - w_xc = torch.rand(input_size, hidden_size) - # hidden to hidden weights - w_hi = torch.rand(hidden_size, hidden_size) - w_hf = torch.rand(hidden_size, hidden_size) - w_ho = torch.rand(hidden_size, hidden_size) - w_hc = torch.rand(hidden_size, hidden_size) - # bias terms - b_i = torch.rand(1, hidden_size) - b_f = torch.rand(1, hidden_size) - b_o = torch.rand(1, hidden_size) - b_c = torch.rand(1, hidden_size) - # hidden to vocab weights, bias - w_hs = torch.rand(hidden_size, vocab_size) - b_s = torch.rand(1, vocab_size) - - idx_batch = torch.jit.BatchTensor(torch.zeros([batch_size, k, max_len], dtype=torch.long), - torch.zeros([batch_size, 1, max_len]).byte(), - torch.tensor([0, 1]).byte()) - idx = [torch.zeros([1, k, max_len], dtype=torch.long) for _ in range(batch_size)] - - ys = [beam(xs[j], hx[j], cx[j], embed[j], w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, - b_i, b_f, b_o, b_c, w_hs, b_s, iter_num[j], idx[j]).narrow(2, 0, int(iter_num[j])) - for j in range(batch_size)] - ybs = beam_batch(batch, h_batch, c_batch, embed_batch, w_xi, w_xf, w_xo, w_xc, - w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num_batch, idx_batch) - self.assertEqual(ys, ybs.examples()) - - def execWrapper(code, glob, loc): if PY2: exec(code) in glob, loc diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 5db5260..4e6fe9a 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -520,13 +520,11 @@ if (BUILD_PYTHON) ${TORCH_SRC_DIR}/csrc/autograd/python_variable.cpp ${TORCH_SRC_DIR}/csrc/autograd/python_variable_indexing.cpp ${TORCH_SRC_DIR}/csrc/byte_order.cpp - ${TORCH_SRC_DIR}/csrc/jit/batched/BatchTensor.cpp ${TORCH_SRC_DIR}/csrc/jit/init.cpp 
${TORCH_SRC_DIR}/csrc/jit/passes/onnx.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/fixup_onnx_loop.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/peephole.cpp - ${TORCH_SRC_DIR}/csrc/jit/passes/to_batch.cpp ${TORCH_SRC_DIR}/csrc/jit/python_arg_flatten.cpp ${TORCH_SRC_DIR}/csrc/jit/python_interpreter.cpp ${TORCH_SRC_DIR}/csrc/jit/python_ir.cpp diff --git a/torch/csrc/jit/batched/BatchTensor.cpp b/torch/csrc/jit/batched/BatchTensor.cpp deleted file mode 100644 index 7d709a6..0000000 --- a/torch/csrc/jit/batched/BatchTensor.cpp +++ /dev/null @@ -1,98 +0,0 @@ -#include - -namespace torch { -namespace jit { - -BatchTensor::BatchTensor(at::Tensor data, at::Tensor mask, at::Tensor dims) { - if (data.dim() != mask.dim() || mask.dim() != dims.size(0) + 1) { - throw std::runtime_error( - "malformed MaskedBatch with data.dim(): " + std::to_string(data.dim()) + - ", mask.dim(): " + std::to_string(mask.dim()) + - ", dims.size(0): " + std::to_string(dims.size(0))); - } - this->data = std::move(data); - this->mask = std::move(mask); - this->dims = std::move(dims); -} - -BatchTensor::BatchTensor(const at::Tensor& data, int64_t batch_size) { - dims = at::empty(data.dim(), data.options().dtype(at::kByte)); - dims.fill_(0); - std::vector sizes(data.dim() + 1, -1); - sizes[0] = batch_size; - this->data = data.unsqueeze(0).expand(sizes); - std::vector mask_sizes(data.dim() + 1, 1); - mask_sizes[0] = batch_size; - mask = at::empty(mask_sizes, data.options().dtype(at::kByte)); - mask.fill_(1); -} - -BatchTensor::BatchTensor( - const std::vector& datalist, - at::Tensor dims) { - auto bs = datalist.size(); - std::vector sizes(dims.size(0) + 1, 0), - mask_sizes(dims.size(0) + 1, 0); - sizes[0] = bs; - mask_sizes[0] = bs; - for (int64_t i = 1; i < dims.size(0) + 1; i++) { - for (const auto& x : datalist) { - sizes[i] = std::max(sizes[i], x.size(i)); - } - mask_sizes[i] = *dims[i - 1].data() ? 
sizes[i] : 1;
-  }
-  data = at::empty(sizes, datalist[0].options());
-  data.fill_(0);
-  mask = at::empty(mask_sizes, datalist[0].options().dtype(at::kByte));
-  mask.fill_(0);
-  for (std::size_t i = 0; i < datalist.size(); i++) {
-    auto data_item = data.narrow(0, i, 1);
-    auto mask_item = mask.narrow(0, i, 1);
-    for (int64_t j = 0; j < dims.size(0); j++) {
-      if (*dims[j].data()) {
-        data_item = data_item.narrow(j + 1, 0, datalist[i].size(j + 1));
-        mask_item = mask_item.narrow(j + 1, 0, datalist[i].size(j + 1));
-      }
-    }
-    data_item += datalist[i];
-    mask_item.fill_(1);
-  }
-  this->dims = std::move(dims);
-}
-
-std::vector BatchTensor::examples() {
-  std::vector result;
-  // calculate number of valid entries in dth dimension of data
-  auto mask_sum = [](at::Tensor data, int d) -> int64_t {
-    data = data.sum(d, /*keepdim=*/true);
-    while (data.dim() >= 1)
-      data = data[0];
-    return *data.data();
-  };
-  for (int64_t i = 0; i < data.size(0); i++) {
-    auto data_tmp = data.narrow(0, i, 1);
-    for (int64_t d = 0; d < dims.size(0); d++) {
-      if (*dims[d].data()) {
-        data_tmp = data_tmp.narrow(d + 1, 0, mask_sum(mask[i], d));
-      }
-    }
-    result.push_back(data_tmp);
-  }
-  return result;
-}
-
-void initBatchTensorBindings(PyObject* module) {
-  auto m = py::handle(module).cast();
-  auto jit = m.def_submodule("_jit");
-  py::class_(jit, "BatchTensor")
-      .def(py::init())
-      .def(py::init())
-      .def(py::init, at::Tensor>())
-      .def("examples", &BatchTensor::examples)
-      .def("get_data", &BatchTensor::get_data)
-      .def("get_mask", &BatchTensor::get_mask)
-      .def("get_dims", &BatchTensor::get_dims);
-}
-
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/batched/BatchTensor.h b/torch/csrc/jit/batched/BatchTensor.h
deleted file mode 100644
index b1b49e8..0000000
--- a/torch/csrc/jit/batched/BatchTensor.h
+++ /dev/null
@@ -1,57 +0,0 @@
-#pragma once
-#include
-#include
-#include
-#include
-#include
-
-namespace torch {
-namespace jit {
-struct BatchTensor {
- public:
-  BatchTensor(at::Tensor data, at::Tensor mask, at::Tensor dims);
-  // expand a tensor to a batchtensor given batch_size
-  BatchTensor(const at::Tensor& data, int64_t batch_size);
-  BatchTensor(const std::vector& datalist, at::Tensor dims);
-  const char* toString() const {
-    return "BatchTensor";
-  }
-  at::IntArrayRef sizes() const {
-    return data.sizes();
-  }
-  int64_t dim() const {
-    return data.dim();
-  }
-  std::vector examples();
-  at::Tensor get_data() {
-    return data;
-  }
-  at::Tensor get_mask() {
-    return mask;
-  }
-  at::Tensor get_dims() {
-    return dims;
-  }
-
- public:
-  // data is a Tensor whose size is the batch size in the batch dimension,
-  // the size of all examples in static dimensions,
-  // and at least as large as the largest example in the batch in dynamic
-  // dimensions.
-  at::Tensor data;
-  // mask is a Tensor whose size is the batch size in the batch dimension,
-  // one in static dimensions,
-  // and at least as large as the largest example in the batch in dynamic
-  // dimensions. Each entry in the mask corresponds to one or more entries in
-  // the data array (singleton, i.e., static, dimensions are broadcasted), with
-  // a one in the mask denoting that the corresponding data entries represent
-  // valid, meaningful data and a zero denoting that they do not.
-  at::Tensor mask;
-  // dims is a 1-dimensional tensor with a bool for each non-batch dimension,
-  // representing whether that dimension is static (False) or dynamic (True).
-  at::Tensor dims;
-};
-
-void initBatchTensorBindings(PyObject* module);
-} // namespace jit
-} // namespace torch
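To make the documented mask semantics concrete, here is a rough Python rendering of BatchTensor::examples() from the deleted .cpp above (illustrative only, not an API that survives this patch):

    import torch

    def examples(data, mask, dims):
        # recover the list of per-example tensors from the padded batch
        def mask_sum(m, d):
            # number of valid entries along dimension d (mirrors the C++ lambda)
            s = m.sum(d, keepdim=True)
            while s.dim() >= 1:
                s = s[0]
            return int(s)

        result = []
        for i in range(data.size(0)):
            x = data.narrow(0, i, 1)
            for d in range(dims.size(0)):
                if bool(dims[d]):  # dynamic dim: trim the padding using the mask
                    x = x.narrow(d + 1, 0, mask_sum(mask[i], d))
            result.append(x)
        return result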
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index f46dd9f..bd66058 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -2,7 +2,6 @@
 #include
 #include
-#include
 #include
 #include
 #include
@@ -31,7 +30,6 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
@@ -483,8 +481,6 @@ void initJITBindings(PyObject* module) {
   tracer::initPythonTracerBindings(module);
   script::initTreeViewBindings(module);
   script::initJitScriptBindings(module);
-  initBatchTensorBindings(module);
-  initRegisterBatchOpsBindings(module);
 }

 } // namespace jit
diff --git a/torch/csrc/jit/passes/to_batch.cpp b/torch/csrc/jit/passes/to_batch.cpp
deleted file mode 100644
index 445e02b..0000000
--- a/torch/csrc/jit/passes/to_batch.cpp
+++ /dev/null
@@ -1,610 +0,0 @@
-#include
-#include
-#include
-
-namespace torch {
-namespace jit {
-
-std::unordered_map>>
-    ToBatch::batch_operator_table;
-
-std::shared_ptr ToBatch::getBatchOperator(
-    const std::string& name,
-    int64_t num_inputs) {
-  if (batch_operator_table.find(name) == batch_operator_table.end()) {
-    throw std::runtime_error(
-        "function " + name + " is not supported in batched tensor yet");
-  }
-  auto ops = batch_operator_table.at(name);
-  if (num_inputs == -1) // default function
-    return ops[0];
-  for (auto op : ops) {
-    if (size_t(num_inputs) == op->inputs().size())
-      return op;
-  }
-  throw std::runtime_error(
-      "function " + name + " with " + std::to_string(num_inputs) +
-      " inputs is not supported in batched tensor yet");
-}
-
-std::vector inlineUnpackedCallTo(
-    Graph& g,
-    Graph& callee,
-    ArrayRef inputs) {
-  return inlineCallTo(g, callee, inputs, /*unpack_outputs=*/true);
-}
-
-// replace aten operator node with BatchTensor operator graph
-void ToBatch::visitAten(Node* n, Block* block, Block* res_block) {
-  auto res_graph = res_block->owningGraph();
-  auto func_name = std::string(n->kind().toUnqualString());
-  std::vector new_inputs;
-  for (Value* input : n->inputs()) {
-    if (rn_env.find(input) == rn_env.end()) { // non-tensor input
-      auto new_input = batch_map.at(input);
-      new_inputs.insert(new_inputs.end(), new_input.begin(), new_input.end());
-    } else { // batched tensor input
-      new_inputs.push_back(rn_env.at(input));
-    }
-  }
-
-  // transform scalar to tensor before passing to batch operator script
-  for (auto& input : new_inputs) {
-    if (input->type() == IntType::get() || input->type() == FloatType::get() ||
-        input->type() == BoolType::get()) {
-      auto to_tensor_node = res_graph->createNumToTensor(input);
-      res_graph->insertNode(to_tensor_node);
-      input = to_tensor_node->output();
-    }
-  }
-
-  auto batch_graph = getBatchOperator(func_name, new_inputs.size());
-  auto outputs =
-      inlineUnpackedCallTo(*res_block->owningGraph(), *batch_graph, new_inputs);
-
-  // Assume all outputs from inlined operator implementation are in the triple
-  // form batched tensor or just a single non-tensor.
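Schematically, the dispatch that visitAten performs looks like the following Python sketch (hypothetical names; the real pass operates on JIT IR values, not Python objects):

    def visit_aten(op_name, inputs, batch_operator_table):
        # inputs: a mix of (data, mask, dims) triples and plain scalars
        flat = []
        for arg in inputs:
            if isinstance(arg, tuple):   # batched tensor: splice in its triple
                flat.extend(arg)
            else:                        # scalar: passed through (NumToTensor in IR)
                flat.append(arg)
        batch_op = batch_operator_table[op_name]   # e.g. 'tanh' -> batch_tanh
        return batch_op(*flat)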
- if (outputs.size() == 1) { - // if previous output is scalar, transform new output back to scalar from - // dynamic - TypePtr orig_type = n->outputs()[0]->type(); - if (!orig_type->isSubtypeOf(outputs[0]->type())) { - Symbol op; - if (orig_type == IntType::get()) { - op = prim::Int; - } else if (orig_type == FloatType::get()) { - op = prim::Float; - } else if (orig_type == BoolType::get()) { - op = prim::Bool; - } else { - throw std::runtime_error( - "NYI: scalar types other than int, float, and bool are not supported yet"); - } - rn_env[n->outputs()[0]] = res_graph->insert(op, {outputs[0]}); - } else { - rn_env[n->outputs()[0]] = outputs[0]; - } - } else { - for (size_t i = 0; i < n->outputs().size(); i++) { - auto output = n->outputs()[i]; - batch_map[output] = std::vector( - outputs.begin() + i * EXP_BTENSOR_SIZE, - outputs.begin() + i * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE); - } - } -} - -// clone prim::Constant to new graph -// batching transformation is applied to the output of prim::NumToTensor. -// If there is a prim::NumToTensor following prim::Constant, it will be finally -// transformed to BatchTensor. -void ToBatch::visitConstant(Node* n, Block* block, Block* res_block) { - auto res_graph = res_block->owningGraph(); - auto* r_node = res_graph->createClone(n, rn_fn); - res_block->appendNode(r_node); - rn_env[n->output()] = r_node->output(); -} - -// change return tensor to expanded batched tensor, eg: {data, mask, dims} -void ToBatch::visitNumToTensor(Node* n, Block* block, Block* res_block) { - auto res_graph = res_block->owningGraph(); - auto* r_node = res_graph->createClone(n, rn_fn); - res_block->appendNode(r_node); - auto outputs = inlineUnpackedCallTo( - *res_block->owningGraph(), - *getBatchOperator("batch_from_scalar_tensor"), - r_node->outputs()); - batch_map[n->output()] = outputs; -} - -// clone prim::TensorToNum to new graph -void ToBatch::visitTensorToNum(Node* n, Block* block, Block* res_block) { - auto res_graph = res_block->owningGraph(); - if (rn_env.find(n->input()) == rn_env.end()) { - rn_env[n->input()] = batch_map.at(n->input())[0]; - } - auto* r_node = res_graph->createClone(n, rn_fn); - res_block->appendNode(r_node); - rn_env[n->output()] = r_node->output(); - batch_map[n->output()] = batch_map.at(n->input()); -} - -// clone prim::ListConstruct to new graph -void ToBatch::visitListConstruct(Node* n, Block* block, Block* res_block) { - auto res_graph = res_block->owningGraph(); - if (n->inputs()[0]->type() == - TensorType::get()) { // TensorList: expand directly - std::vector inputs; - for (Value* input : n->inputs()) { - auto res = batch_map.at(input); - inputs.insert(inputs.end(), res.begin(), res.end()); - } - batch_map[n->output()] = inputs; - } else { // ScalarList: transform to tensor, then transform back - for (Value* input : n->inputs()) { - if (rn_env.find(input) == rn_env.end()) { - rn_env[input] = batch_map.at(input)[0]; - } - } - auto* r_node = res_graph->createClone(n, rn_fn); - res_block->appendNode(r_node); - // transform int[] to tensor - auto to_tensor_node = - res_graph->create(Symbol::fromQualString("aten::_list_to_tensor")); - to_tensor_node->addInput(r_node->output()); - res_block->appendNode(to_tensor_node); - rn_env[n->output()] = to_tensor_node->output(); - } -} - -// clang-format off -// prim::If transformation: -// elif is not supported -// -// transformation example: -// @torch.jit.batch(batch_size=4) -// def batch_if(a, b): -// if a > b: -// a += b -// else: -// a -= b -// return a -// -// original graph: -// graph(%a.1 : 
Dynamic -// %b : Dynamic) { -// %2 : Dynamic = aten::gt(%a.1, %b) -// %a : Dynamic = prim::If(%2) -// block0() { -// %a.2 : Dynamic = aten::add[alpha={1}](%a.1, %b) -// -> (%a.2) -// } -// block1() { -// %a.3 : Dynamic = aten::sub[alpha={1}](%a.1, %b) -// -> (%a.3) -// } -// return (%a); -// } -// -// transformed graph: -// graph(%a.1_data : Dynamic -// %a.1_mask : Dynamic -// %a.1_dims : Dynamic -// %b_data : Dynamic -// %b_mask : Dynamic -// %b_dims : Dynamic) { -// %6 : Dynamic = aten::gt(%a.1_data, %b_data) // calculate condition -// %7 : Dynamic = aten::mul(%a.1_mask, %b_mask) -// %8 : Dynamic = aten::__or__(%a.1_dims, %b_dims) -// %9 : int = prim::TensorToNum(%6) -// %10 : Long() = prim::Constant[value={1}]() // if_block -// %alpha.1 : float = prim::TensorToNum(%10) -// %data.1 : Dynamic = aten::add(%a.1_data, %b_data, %alpha.1) -// %mask.1 : Dynamic = aten::mul(%a.1_mask, %b_mask) -// %dims.1 : Dynamic = aten::__or__(%a.1_dims, %b_dims) -// %15 : Long() = prim::Constant[value={1}]() // else_block -// %alpha : float = prim::TensorToNum(%15) -// %data.4 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha) -// %mask : Dynamic = aten::mul(%a.1_mask, %b_mask) -// %dims : Dynamic = aten::__or__(%a.1_dims, %b_dims) -// %20 : Dynamic = aten::type_as(%7, %6) // combine two outputs (batch_where) -// %cond_mask.1 : Dynamic = aten::mul(%6, %20) -// %22 : int = aten::dim(%cond_mask.1) -// %23 : int = prim::Constant[value=1]() -// %24 : int = aten::eq(%22, %23) -// %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%24) -// block0() { -// %28 : int = aten::dim(%data.1) -// %29 : int = prim::Constant[value=1]() -// %30 : int = aten::sub(%28, %29) -// %31 : int = prim::Constant[value=1]() -// %data.3 : Dynamic = prim::Loop(%30, %31, %cond_mask.1) -// block0(%_ : int, %34 : Dynamic) { -// %35 : int = prim::Constant[value=1]() -// %36 : int = aten::neg(%35) -// %data.2 : Dynamic = aten::unsqueeze(%34, %36) -// %38 : int = prim::Constant[value=1]() -// -> (%38, %data.2) -// } -// %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1) -// %cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask.1) -// -> (%cond_data.1, %cond_mask.2, %data.3) -// } -// block1() { -// -> (%cond_mask.1, %cond_mask.1, %cond_mask.1) -// } -// %res_data : Dynamic = aten::where(%cond_data, %data.1, %data.4) -// %res_mask : Dynamic = aten::where(%cond_mask, %mask.1, %mask) -// %res_dims : Dynamic = aten::__or__(%dims.1, %dims) -// return (%res_data, %res_mask, %res_dims); -// } -// clang-format on -void ToBatch::visitIf(Node* n, Block* block, Block* res_block) { - toBatch(n->blocks()[0], res_block); - toBatch(n->blocks()[1], res_block); - - // combine results from two if paths - for (size_t i = 0; i < n->outputs().size(); i++) { - std::vector inputs; - if (batch_map.find(n->input()) == batch_map.end()) { // cond is scalar - inputs.push_back(rn_env.at(n->input())); - } else { // cond is tensor - auto cond = batch_map.at(n->input()); - inputs.insert(inputs.end(), cond.begin(), cond.end()); - } - auto if_output = batch_map.at(n->blocks()[0]->outputs()[i]); - inputs.insert(inputs.end(), if_output.begin(), if_output.end()); - auto else_output = batch_map.at(n->blocks()[1]->outputs()[i]); - inputs.insert(inputs.end(), else_output.begin(), else_output.end()); - auto outputs = inlineUnpackedCallTo( - *res_block->owningGraph(), - *getBatchOperator("where", inputs.size()), - inputs); - batch_map[n->outputs()[i]] = outputs; - } -} - -// clang-format off -// prim::Loop transformation: -// -// transformation example: 
-// @torch.jit.batch(batch_size=4) -// def batch_while(a, b): -// while a > b: -// a -= b -// return a -// -// original graph: -// graph(%a.1 : Dynamic -// %b : Dynamic) { -// %2 : int = prim::Constant[value={2147483647}]() -// %3 : Dynamic = aten::gt(%a.1, %b) -// %a : Dynamic = prim::Loop(%2, %3, %a.1) -// block0(%4 : Dynamic, %5 : Dynamic) { -// %a.2 : Dynamic = aten::sub[alpha={1}](%5, %b) -// %9 : Dynamic = aten::gt(%a.2, %b) -// -> (%9, %a.2) -// } -// return (%a); -// } -// -// transformed graph: -// graph(%a.1_data : Dynamic -// %a.1_mask : Dynamic -// %a.1_dims : Dynamic -// %b_data : Dynamic -// %b_mask : Dynamic -// %b_dims : Dynamic) { -// %6 : int = prim::Constant[value=2147483647]() -// %7 : Dynamic = aten::gt(%a.1_data, %b_data) -// %8 : Dynamic = aten::mul(%a.1_mask, %b_mask) -// %9 : Dynamic = aten::__or__(%a.1_dims, %b_dims) -// %10 : int = prim::TensorToNum(%7) -// %11 : Dynamic = aten::mul(%7, %8) -// %12 : Dynamic = aten::sum(%11) -// %13 : Dynamic = aten::gt[other={0}](%12) // cond_any -// %14 : int = prim::TensorToNum(%13) -// %62 : Dynamic, %63 : Dynamic, %64 : Dynamic, %a : Dynamic, %60 : Dynamic, %61 : Dynamic = prim::Loop(%6, %14, %7, %8, %9, %a.1_data, %a.1_mask, %a.1_dims) -// block0(%loop_num : int, %cond_data.2 : Dynamic, %cond_mask.3 : Dynamic, %cond_dims : Dynamic, %6_data : Dynamic, %6_mask : Dynamic, %6_dims : Dynamic) { -// %23 : Long() = prim::Constant[value={1}]() -// %alpha : float = prim::TensorToNum(%23) -// %data.1 : Dynamic = aten::sub(%6_data, %b_data, %alpha) -// %mask : Dynamic = aten::mul(%6_mask, %b_mask) -// %dims : Dynamic = aten::__or__(%6_dims, %b_dims) -// %28 : Dynamic = aten::gt(%data.1, %b_data) -// %29 : Dynamic = aten::mul(%mask, %b_mask) -// %30 : Dynamic = aten::__or__(%dims, %b_dims) -// %31 : int = prim::TensorToNum(%28) -// %32 : Dynamic = aten::type_as(%cond_mask.3, %cond_data.2) // update outputs (batch_where) -// %cond_mask.1 : Dynamic = aten::mul(%cond_data.2, %32) -// %34 : int = aten::dim(%cond_mask.1) -// %35 : int = prim::Constant[value=1]() -// %36 : int = aten::eq(%34, %35) -// %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%36) -// block0() { -// %40 : int = aten::dim(%data.1) -// %41 : int = prim::Constant[value=1]() -// %42 : int = aten::sub(%40, %41) -// %43 : int = prim::Constant[value=1]() -// %data.3 : Dynamic = prim::Loop(%42, %43, %cond_mask.1) -// block0(%_ : int, %46 : Dynamic) { -// %47 : int = prim::Constant[value=1]() -// %48 : int = aten::neg(%47) -// %data.2 : Dynamic = aten::unsqueeze(%46, %48) -// %50 : int = prim::Constant[value=1]() -// -> (%50, %data.2) -// } -// %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1) -// %cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask) -// -> (%cond_data.1, %cond_mask.2, %data.3) -// } -// block1() { -// -> (%cond_mask.1, %cond_mask.1, %cond_mask.1) -// } -// %res_data : Dynamic = aten::where(%cond_data, %data.1, %6_data) -// %res_mask : Dynamic = aten::where(%cond_mask, %mask, %6_mask) -// %res_dims : Dynamic = aten::__or__(%dims, %6_dims) -// %56 : Dynamic = aten::mul(%28, %29) -// %57 : Dynamic = aten::sum(%56) -// %58 : Dynamic = aten::gt[other={0}](%57) -// %59 : int = prim::TensorToNum(%58) -// -> (%59, %28, %29, %30, %res_data, %res_mask, %res_dims) -// } -// return (%a, %60, %61); -// } -// clang-format on -void ToBatch::visitLoop(Node* n, Block* block, Block* res_block) { - auto res_graph = res_block->owningGraph(); - // bool cond_is_tensor indicates whether cond is tensor - // cond_is_tensor = false, eg: for loop, 
n->inputs()[1] = byte() - // cond_is_tensor = true, eg: in some while loop, cond is a batched tensor, - // we need to add expanded cond to the inputs of - // loop node and block, and compute cond_any as - // cond for while loop - bool cond_is_tensor = (batch_map.find(n->inputs()[1]) != batch_map.end()); - - // create prim::Loop node for res_block - - // type of cond in loop should be int type - if (rn_env.at(n->inputs()[0])->type() != IntType::get()) { - rn_env[n->inputs()[0]] = - res_graph->insert(prim::Int, {rn_env.at(n->inputs()[0])}); - } - if (cond_is_tensor) { - auto cond = batch_map.at(n->inputs()[1]); - auto cond_any = inlineUnpackedCallTo( - *res_block->owningGraph(), *getBatchOperator("any"), cond); - rn_env[n->inputs()[1]] = res_graph->insert(prim::Bool, {cond_any[0]}); - } - for (size_t i = 2; i < n->inputs().size(); i++) { - auto input = n->inputs()[i]; - rn_env[input] = batch_map.at(input)[0]; - } - auto* r_node = res_graph->createClone(n, rn_fn, /*copy_blocks=*/false); - - // change inputs of prim::Loop - if (cond_is_tensor) { - for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) { - auto cond = batch_map.at(n->inputs()[1]); - r_node->insertInput(i + 2, cond[i]); - } - } - for (size_t i = 2; i < n->inputs().size(); i++) { - for (size_t j = 1; j < EXP_BTENSOR_SIZE; j++) { - r_node->insertInput( - (i - 2) * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE * cond_is_tensor + 2 + - j, - batch_map.at(n->inputs()[i])[j]); - } - } - res_block->appendNode(r_node); - - // create block for Loop node in res_block - // if cond is tensor: first 4 inputs of block: cond_any, cond_data, - // cond_mask, cond_dims - // if cond is not tensor: first 1 input of block: cond - auto loop_block = r_node->addBlock(); - - // add inputs - loop_block->addInput("loop_num"); - loop_block->inputs()[0]->setType(IntType::get()); - rn_env[n->blocks()[0]->inputs()[0]] = loop_block->inputs()[0]; - if (cond_is_tensor) { - for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) { - loop_block->addInput("cond_" + EXP_BTENSOR_NAME[i]); - } - } - for (size_t i = 1; i < n->blocks()[0]->inputs().size(); i++) { - auto input = n->blocks()[0]->inputs()[i]; - auto name = input->uniqueName(); - for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) { - loop_block->addInput(name + "_" + EXP_BTENSOR_NAME[j]); - } - batch_map[input] = - std::vector(loop_block->inputs() - .slice( - (i - 1) * EXP_BTENSOR_SIZE + 1 + - EXP_BTENSOR_SIZE * cond_is_tensor, - EXP_BTENSOR_SIZE) - .vec()); - } - - toBatch(n->blocks()[0], loop_block); - - WithInsertPoint guard(loop_block); - - // use where operator to update variables and add to outputs - for (size_t i = 0; i < n->outputs().size(); i++) { - std::vector inputs, outputs; - if (cond_is_tensor) { - for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) { - inputs.push_back(loop_block->inputs()[j + 1]); - } - auto data = batch_map.at(n->blocks()[0]->outputs()[i + 1]); - inputs.insert(inputs.end(), data.begin(), data.end()); - for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) { - inputs.push_back( - loop_block - ->inputs()[i * EXP_BTENSOR_SIZE + j + EXP_BTENSOR_SIZE + 1]); - } - outputs = inlineUnpackedCallTo( - *res_block->owningGraph(), *getBatchOperator("where"), inputs); - } else { - for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) { - inputs.push_back(loop_block->inputs()[i * EXP_BTENSOR_SIZE + j + 1]); - } - auto data = batch_map.at(n->blocks()[0]->outputs()[i + 1]); - inputs.insert(inputs.end(), data.begin(), data.end()); - outputs = inlineUnpackedCallTo( - *res_block->owningGraph(), *getBatchOperator("update"), inputs); - } - 
batch_map[n->outputs()[i]] = outputs; - for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) { - loop_block->registerOutput(outputs[j]); - } - } - - // update loop conditions - if (cond_is_tensor) { - auto cond = batch_map.at(n->blocks()[0]->outputs()[0]); - auto cond_any = inlineUnpackedCallTo( - *res_block->owningGraph(), *getBatchOperator("any"), cond); - auto to_bool_output = res_graph->insert(prim::Bool, {cond_any[0]}); - loop_block->insertOutput(0, to_bool_output); - for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) { - loop_block->insertOutput(i + 1, cond[i]); - } - } else { - auto cond = rn_env.at(n->blocks()[0]->outputs()[0]); - loop_block->insertOutput(0, cond); - } - - // change outputs of prim::Loop - auto size = r_node->outputs().size(); - for (size_t i = 0; i < size; i++) { - for (size_t j = 1; j < EXP_BTENSOR_SIZE; j++) { - r_node->insertOutput(i * EXP_BTENSOR_SIZE + j); - } - batch_map[n->outputs()[i]] = - r_node->outputs().slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE).vec(); - } - // add cond to outputs of loop node - if (cond_is_tensor) { - for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) { - r_node->insertOutput(i); - } - } -} - -void ToBatch::toBatch(Block* block, Block* res_block) { - WithInsertPoint guard(res_block); - - // change inputs of block-expand tensor to batchtensor eg: (data, mask, dims) - // eg: a -> a_data, a_mask, a_dims for block in prim::Loop, register inputs - // separately to deal with cond - if (!block->owningNode() || block->owningNode()->kind() != prim::Loop) { - auto size = block->inputs().size(); - for (size_t i = 0; i < size; i++) { - auto input = block->inputs()[i]; - auto name = input->uniqueName(); - for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) { - res_block->addInput(name + "_" + EXP_BTENSOR_NAME[j]); - } - batch_map[input] = - std::vector(res_block->inputs() - .slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE) - .vec()); - } - } - - for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) { - auto n = *it; - if (n->kind().is_aten()) { - visitAten(n, block, res_block); - } else if (n->kind().is_prim()) { - switch (n->kind()) { - case prim::Constant: - visitConstant(n, block, res_block); - break; - case prim::NumToTensor: - visitNumToTensor(n, block, res_block); - break; - case prim::Bool: - case prim::Float: - case prim::Int: - visitTensorToNum(n, block, res_block); - break; - case prim::ListConstruct: - visitListConstruct(n, block, res_block); - break; - case prim::If: - visitIf(n, block, res_block); - break; - case prim::Loop: - visitLoop(n, block, res_block); - break; - default: - throw std::runtime_error( - "NYI: node of prim kind other than [Constant, NumToTensor, TensorToNum, If, Loop] is not supported yet"); - } - } else { - throw std::runtime_error( - "NYI: node that is not aten or prim kind is not supported yet"); - } - } - // change outputs of block - expand tensor to batchtensor(data, mask, dims) - // for block in prim::Loop, register outputs separately to deal with cond and - // cond_any - // - // for block in prim::If, register outputs separately by combining - // outputs from two paths and return - if (!block->owningNode() || - (block->owningNode()->kind() != prim::Loop && - block->owningNode()->kind() != prim::If)) { - for (Value* output : block->outputs()) { - auto r_output = batch_map.at(output); - for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) { - res_block->registerOutput(r_output[i]); - } - } - } -} - -std::shared_ptr to_batch_graph(std::shared_ptr graph) { - // lower the tuple before the pass - if 
(graph->outputs().at(0)->type()->kind() == TupleType::Kind) { - graph = graph->copy(); - auto outs = createTupleUnpack(graph->outputs().at(0)); - graph->eraseOutput(0); - for (auto o : outs) - graph->registerOutput(o); - EliminateDeadCode(graph->block()); - } - std::shared_ptr res_graph = std::make_shared(); - ToBatch to_batch; - to_batch.toBatch(graph->block(), res_graph->block()); - - // methods should only have a single output, so we pack everything into a - // tuple - auto tup = - res_graph->insertNode(res_graph->createTuple(res_graph->outputs())); - while (res_graph->outputs().size() > 0) - res_graph->eraseOutput(res_graph->outputs().size() - 1); - res_graph->registerOutput(tup->output()); - EliminateDeadCode(res_graph->block()); - - return res_graph; -} - -void initRegisterBatchOpsBindings(PyObject* module) { - auto m = py::handle(module).cast(); - m.def("to_batch_graph", to_batch_graph); - m.def( - "register_batch_operator", - [](std::string name, std::shared_ptr graph) { - ToBatch::batch_operator_table[name].push_back(graph); - }); -} - -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/passes/to_batch.h b/torch/csrc/jit/passes/to_batch.h deleted file mode 100644 index 76bf53d..0000000 --- a/torch/csrc/jit/passes/to_batch.h +++ /dev/null @@ -1,47 +0,0 @@ -#pragma once - -#include -#include - -#include - -namespace torch { -namespace jit { - -class ToBatch { - private: - // number of tensors to represent a expanded BatchTensor. {data, mask, dims} - // for now. - const size_t EXP_BTENSOR_SIZE = 3; - const std::vector EXP_BTENSOR_NAME = {"data", "mask", "dims"}; - // mapping from tensor in original graph to {data, mask, dims} in new graph - std::unordered_map> batch_map; - // mapping from input in original graph to new input in new graph - used in - // createClone - std::unordered_map rn_env; - std::function rn_fn = [this](Value* v) { - return rn_env.at(v); - }; - - private: - std::shared_ptr getBatchOperator( - const std::string& name, - int64_t input_num = -1); - void visitAten(Node* n, Block* block, Block* res_block); - void visitConstant(Node* n, Block* block, Block* res_block); - void visitNumToTensor(Node* n, Block* block, Block* res_block); - void visitTensorToNum(Node* n, Block* block, Block* res_block); - void visitListConstruct(Node* n, Block* block, Block* res_block); - void visitIf(Node* n, Block* block, Block* res_block); - void visitLoop(Node* n, Block* block, Block* res_block); - - public: - static std::unordered_map>> - batch_operator_table; - TORCH_API void toBatch(Block* block, Block* res_block); -}; - -TORCH_API std::shared_ptr to_batch_graph(std::shared_ptr graph); -TORCH_API void initRegisterBatchOpsBindings(PyObject* module); -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index d3efa4d..f4d1c89 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -14,7 +14,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 9d30e88..ec895cf 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -57,7 +57,6 @@ _flatten = torch._C._jit_flatten _unflatten = torch._C._jit_unflatten _jit_script_compile = torch._C._jit_script_compile _jit_script_class_compile = torch._C._jit_script_class_compile -BatchTensor = torch._C._jit.BatchTensor Future = torch._C.Future _fork = torch._C.fork @@ -800,38 +799,6 @@ def _is_weak_type(cls): return cls in _jit_internal.weak_types -def 
batch(batch_size=1, optimize=True, _frames_up=0):
-    def decorator(fn):
-        if not _enabled:
-            return fn
-        import torch.jit.batchop
-        mod = script(fn, optimize, _frames_up)
-        res_graph = torch.to_batch_graph(mod.graph)
-        res_mod = ScriptModule()
-        res_mod._create_method_from_graph('forward', res_graph)
-
-        def wrapper(*args):
-            new_args = []
-            for arg in args:
-                if isinstance(arg, torch.Tensor):
-                    arg = BatchTensor(arg, batch_size)
-                if isinstance(arg, BatchTensor):
-                    new_args.extend([arg.get_data(), arg.get_mask(), arg.get_dims()])
-                else:
-                    new_args.append(arg)
-            res = res_mod(*new_args)
-            assert len(res) % 3 == 0
-            if len(res) % 3 != 0:
-                raise "non-batched-tensor output is not supported yet"
-            result = [BatchTensor(*res[i * 3: i * 3 + 3]) for i in range(len(res) // 3)]
-            if len(result) == 1:
-                return result[0]
-            return result
-        wrapper.__doc__ = fn.__doc__
-        return wrapper
-    return decorator
-
-
 # These OrderedDictWrapper classes replace the actual OrderedDicts in
 # module with versions that get/set properties inside of script::Module.
 # This allows us to reuse most of nn.Module while still storing the
diff --git a/torch/jit/batchop.py b/torch/jit/batchop.py
deleted file mode 100644
index b0af553..0000000
--- a/torch/jit/batchop.py
+++ /dev/null
@@ -1,528 +0,0 @@
-import torch
-from torch.jit import BatchTensor
-
-
-# TODO: there are some commented raise statements
-# when we support raise exception in script, we want to check them
-@torch.jit.script
-def batch_tanh(data, mask, dims):
-    data = torch.tanh(data)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_sigmoid(data, mask, dims):
-    data = torch.sigmoid(data)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_relu(data, mask, dims):
-    data = torch.relu(data)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_neg(data, mask, dims):
-    data = torch.neg(data)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_neg_scalar(data):
-    return torch.neg(data)
-
-
-@torch.jit.script
-def batch_add(data1, mask1, dims1, data2, mask2, dims2, alpha_):
-    alpha = float(alpha_)
-    data = torch.add(data1, data2, alpha=alpha)
-    mask = mask1 * mask2
-    dims = dims1.__or__(dims2)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_add_scalar(data, mask, dims, other, alpha_):
-    alpha = float(alpha_)
-    data = torch.add(data, other.type_as(data), alpha=alpha)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_sub(data1, mask1, dims1, data2, mask2, dims2, alpha_):
-    alpha = float(alpha_)
-    data = torch.sub(data1, data2, alpha=alpha)
-    mask = mask1 * mask2
-    dims = dims1.__or__(dims2)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_sub_scalar(data1, data2):
-    return data1 - data2
-
-
-@torch.jit.script
-def batch_mul(data1, mask1, dims1, data2, mask2, dims2):
-    data = torch.mul(data1, data2)
-    mask = mask1 * mask2
-    dims = dims1.__or__(dims2)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_mul_scalar(data1, data2):
-    return data1 * data2
-
-
-@torch.jit.script
-def batch_div(data, mask, dims, other):  # div(batchtensor, scalar)
-    data = torch.div(data, other)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_mm(data1, mask1, dims1, data2, mask2, dims2):
-    data1 = data1 * mask1.type_as(data1)
-    data2 = data2 * mask2.type_as(data2)
-    data = torch.bmm(data1, data2)
-    mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1))
-    dims = torch.cat((dims1[:1], dims2[1:dims2.size(0)]))
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_matmul(data1, mask1, dims1, data2, 
mask2, dims2): - d1 = data1.dim() - 1 - d2 = data2.dim() - 1 - data1 = data1 * mask1.type_as(data1) - data2 = data2 * mask2.type_as(data2) - if d1 == 1: - data1 = data1.unsqueeze(-2) - if d2 == 1: - data2 = data2.unsqueeze(-1) - data = torch.bmm(data1, data2) - mask = mask1 - dims = dims1 - if d1 == 1 and d2 == 1: - # if (batch1.dims[0] or batch2.dims[0]) and not batch1.mask.eq(batch2.mask).all(): - # raise ValueError("cannot contract non-matching dimensions") - data = data.squeeze(-1).squeeze(-1) - mask = mask1.narrow(1, 0, 1).squeeze(-1) - dims = dims1[:0] # empty tensor - if d1 == 2 and d2 == 1: - # if (batch1.dims[1] or batch2.dims[0]) and not batch1.mask[:, 0].eq(batch2.mask).all(): - # raise ValueError("cannot contract non-matching dimensions") - data = data.squeeze(-1) - mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1).unsqueeze(-1)).squeeze(-1) - dims = dims1[:1] - elif d1 == 1 and d2 == 2: - # if (batch1.dims[0] or batch2.dims[0]) and not batch1.mask.eq(batch2.mask[:, :, 0]).all(): - # raise ValueError("cannot contract non-matching dimensions") - data = data.squeeze(-2) - mask = torch.bmm(mask1.narrow(1, 0, 1).unsqueeze(-2), mask2.narrow(1, 0, 1)).squeeze(-2) - dims = dims2[1:dims2.size(0)] - elif d1 == 2 and d2 == 2: - # if (batch1.dims[1] or batch2.dims[0]) and not batch1.mask[:, 0].eq(batch2.mask[:, :, 0]).all(): - # raise ValueError("cannot contract non-matching dimensions") - mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1)) - dims = torch.cat((dims1[:1], dims2[1:dims2.size(0)])) - # else: - # raise NotImplementedError("matmul not implemented with batches of 3+D tensors") - return data, mask, dims - - -@torch.jit.script -def batch_select(data, mask, dims, dim_, index_): - dim = int(dim_) - index = int(index_) - # if dim == 0: - # raise ValueError("Cannot select 0 dim in BatchTensor") - data = data.select(dim, index) - if bool(dims[dim - 1]): - mask = mask.select(dim, index) - else: - mask = mask.select(dim, 0) - dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)])) - return data, mask, dims - - -@torch.jit.script -def batch_fmod(data, mask, dims, other_): - other = int(other_) - data = torch.fmod(data, other) - return data, mask, dims - - -@torch.jit.script -def batch_zeros_like(data, mask, dims): - res_data = torch.zeros_like(data) - return res_data, mask, dims - - -@torch.jit.script -def batch_index_select(data, mask, dims, dim_, index_data, index_mask, index_dims): - dim = int(dim_) - # if dim == 0: - # raise ValueError("Cannot index_select along 0 dim in BatchTensor") - batch_size = data.size(0) # TODO maybe index_mask will be used at some point - res_data = torch.zeros([0]) - res_mask = torch.zeros([0]) - for i in range(batch_size): - d = data[i].index_select(dim - 1, index_data[i]).unsqueeze(0) - if bool(dims[dim - 1]): - m = mask[i].index_select(dim - 1, index_data[i]).unsqueeze(0) - else: - m = mask[i].unsqueeze(0) - if i == 0: - res_data = d - res_mask = m - else: - res_data = torch.cat((res_data, d), 0) - res_mask = torch.cat((res_mask, m), 0) - return res_data, res_mask, dims - - -@torch.jit.script -def batch_view_as(data, mask, dims, data1, mask1, dims1): - # if data.size(0) != data1.size(0): - # raise ValueError("In view_as, tensor and target tensor should have the same batch_size") - # if not torch.equal(dims, dims1): - # raise ValueError("In batched view_as, dims and target dims should be the same") - data = data.view_as(data1) - mask = mask.view_as(mask1) - dims = dims1 - return data, mask, dims - - -# assume data, data1, data2 
-# assume data, data1, data2 have same size
-@torch.jit.script
-def batch_where(data, mask, dims, data1, mask1, dims1, data2, mask2, dims2):
-    data = data * mask.type_as(data)
-    cond_data = data
-    cond_mask = data
-    if data.dim() == 1:
-        for _ in range(data1.dim() - 1):
-            data = data.unsqueeze(data.dim())
-        cond_data = data.expand_as(data1)
-        cond_mask = data.expand_as(mask1)
-    res_data = torch.where(cond_data, data1, data2)
-    res_mask = torch.where(cond_mask, mask1, mask2)
-    res_dims = dims1.__or__(dims2)
-    return res_data, res_mask, res_dims
-
-
-@torch.jit.script
-def batch_where_scalar(cond, data1, mask1, dims1, data2, mask2, dims2):
-    # NOTE: cond is unconditionally zeroed here, so the second branch
-    # (data2/mask2/dims2) is always chosen.
-    cond = torch.zeros([1], dtype=torch.uint8)
-    res_data = torch.where(cond, data1, data2)
-    res_mask = torch.where(cond, mask1, mask2)
-    res_dims = torch.where(cond, dims1, dims2)
-    return res_data, res_mask, res_dims
-
-
-@torch.jit.script
-def batch_update(batch_data, batch_mask, batch_dims, new_data, new_mask, new_dims):
-    data = torch.where(new_mask, new_data, batch_data)
-    return data, new_mask, new_dims  # TODO: consider whether return new_mask and new_dims
-
-
-@torch.jit.script
-def batch_any(data, mask, dims):
-    return torch.gt(torch.sum(data * mask), 0)
-
-
-@torch.jit.script
-def batch_type_as(data, mask, dims, data1, mask1, dims1):
-    return data.type_as(data1), mask, dims
-
-
-@torch.jit.script
-def batch_gt(data, mask, dims, data1, mask1, dims1):
-    return torch.gt(data, data1), mask * mask1, dims.__or__(dims1)
-
-
-@torch.jit.script
-def batch_gt_scalar(data1, data2):
-    return torch.gt(data1, data2)
-
-
-@torch.jit.script
-def batch_gt_one_scalar(data, mask, dims, other_):
-    other = float(other_)
-    return torch.gt(data, other), mask, dims
-
-
-@torch.jit.script
-def batch_lt(data, mask, dims, data1, mask1, dims1):
-    return torch.lt(data, data1), mask * mask1, dims.__or__(dims1)
-
-
-@torch.jit.script
-def batch_eq(data, mask, dims, data1, mask1, dims1):
-    return torch.eq(data, data1), mask * mask1, dims.__or__(dims1)
-
-
-@torch.jit.script
-def batch_size(data, mask, dims, dim_):
-    dim = int(dim_)
-    return data.size(dim)
-
-
-@torch.jit.script
-def batch_dim(data, mask, dims):
-    return data.dim()
-
-
-@torch.jit.script
-def batch_squeeze(data, mask, dims, dim_):
-    if int(dim_) < 0:
-        dim_ = dim_ + data.dim()
-    dim = int(dim_)
-    # if dim == 0:
-    #     raise ValueError("cannot do squeeze along batch_dim")
-    data = data.squeeze(dim)
-    mask = mask.squeeze(dim)
-    dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_unsqueeze(data, mask, dims, dim_):
-    if int(dim_) < 0:
-        dim_ = dim_ + data.dim() + 1
-    dim = int(dim_)
-    # if dim == 0:
-    #     raise ValueError("cannot do unsqueeze along batch_dim")
-    data = data.unsqueeze(dim)
-    mask = mask.unsqueeze(dim)
-    dims = torch.cat((dims[:dim], torch.zeros([1], dtype=torch.uint8), dims[dim:dims.size(0)]))
-    return data, mask, dims
-
-
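The comparison ops deleted above (batch_gt, batch_lt, batch_eq) all share one algebra: validity masks intersect (a cell must be valid in both operands), while dynamic-dim flags union (a dim is dynamic if it is dynamic in either). A small eager illustration, with names and values chosen for this sketch rather than taken from the deleted API:

    import torch

    def eager_batch_eq(data1, mask1, dims1, data2, mask2, dims2):
        # Elementwise compare; intersect masks; union dynamic-dim flags.
        return torch.eq(data1, data2), mask1 * mask2, dims1 | dims2

    a = torch.tensor([[1.0, 2.0, 0.0]])
    b = torch.tensor([[1.0, 5.0, 0.0]])
    ma = torch.tensor([[1, 1, 0]], dtype=torch.uint8)  # last cell is padding in a
    mb = torch.tensor([[1, 1, 1]], dtype=torch.uint8)
    da = torch.tensor([1], dtype=torch.uint8)          # dim dynamic in a
    db = torch.tensor([0], dtype=torch.uint8)          # dim static in b
    eq, mask, dims = eager_batch_eq(a, ma, da, b, mb, db)
    assert mask.tolist() == [[1, 1, 0]] and dims.tolist() == [1]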
-@torch.jit.script
-def batch_argmax(data, mask, dims, dim_, keepdim_):
-    dim = int(dim_)
-    keepdim = bool(keepdim_)
-    # if dim == 0:
-    #     raise ValueError("cannot do argmax along batch_dim")
-    batch_size = data.size(0)
-    res_data = torch.zeros([0])
-    for i in range(batch_size):
-        if bool(dims[dim - 1]):
-            if dim - 1 != 0:
-                m = mask[i].transpose(0, dim - 1)
-            else:
-                m = mask[i]
-            valid_num = m.sum(0, keepdim=True)
-            while valid_num.dim() >= 1:
-                valid_num = valid_num[0]
-            d = data[i].unsqueeze(0).narrow(dim, 0, int(valid_num))
-        else:
-            d = data[i].unsqueeze(0)
-        d = d.argmax(dim, keepdim)
-        if i == 0:
-            res_data = d
-        else:
-            res_data = torch.cat([res_data, d], 0)
-    if keepdim:
-        mask = mask
-    else:
-        mask = mask.select(dim, 0)
-    dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
-    return res_data, mask, dims
-
-
-@torch.jit.script
-def batch_topk(data, mask, dims, k_, dim_, largest_, sorted_):
-    k = int(k_)
-    dim = int(dim_)
-    largest = bool(largest_)
-    sorted = bool(sorted_)
-    # if dim == 0:
-    #     raise ValueError("cannot do topk along batch_dim")
-    batch_size = data.size(0)
-    res_data = torch.zeros([0])
-    res_index = torch.zeros([0])
-    for i in range(batch_size):
-        if bool(dims[dim - 1]):
-            if dim - 1 != 0:
-                m = mask[i].transpose(0, dim - 1)
-            else:
-                m = mask[i]
-            valid_num = m.sum(0, keepdim=True)
-            while valid_num.dim() >= 1:
-                valid_num = valid_num[0]
-            d = data[i].unsqueeze(0).narrow(dim, 0, int(valid_num))
-        else:
-            d = data[i].unsqueeze(0)
-        d, idx = d.topk(k, dim, largest, sorted)
-        if i == 0:
-            res_data = d
-            res_index = idx
-        else:
-            res_data = torch.cat([res_data, d], 0)
-            res_index = torch.cat([res_index, idx], 0)
-    if bool(dims[dim - 1]):
-        mask = mask.narrow(dim, 0, k)
-    return res_data, mask, dims, res_index, mask, dims
-
-
-@torch.jit.script
-def batch_softmax(data, mask, dims, dim_):
-    dim = int(dim_)
-    # if dim == 0:
-    #     raise ValueError("cannot do softmax along batch_dim")
-    batch_size = data.size(0)
-    max_len = data.size(dim)
-    res_data = torch.zeros([0])
-    for i in range(batch_size):
-        if bool(dims[dim - 1]):
-            if dim - 1 != 0:
-                m = mask[i].transpose(0, dim - 1)
-            else:
-                m = mask[i]
-            valid_num = m.sum(0, keepdim=True)
-            while valid_num.dim() >= 1:
-                valid_num = valid_num[0]
-            valid_num = int(valid_num)
-            d = data[i].unsqueeze(0).narrow(dim, 0, valid_num).softmax(dim)
-            if valid_num < max_len:
-                d = torch.cat([d, data[i].unsqueeze(0).narrow(dim, valid_num, max_len - valid_num)], dim)
-        else:
-            d = data[i].unsqueeze(0).softmax(dim)
-        if i == 0:
-            res_data = d
-        else:
-            res_data = torch.cat([res_data, d], 0)
-    return res_data, mask, dims
-
-
-# size argument in dynamic dimension has to be -1
-# in static dimension, size has to be specified, -1 is not supported
-@torch.jit.script
-def batch_view(data, mask, dims, sizes):
-    batch_size = data.size(0)
-    # if(sizes[0] != batch_size and sizes[0] != -1 and sizes[0] != 1):
-    #     raise "first dim in view must be 1, -1, or batch size"
-    # for i in range(dims.size(0)):
-    #     if dims[0] == 1 and sizes[i + 1] != -1:
-    #         raise "size argument in dynamic dimension has to be -1"
-    sizes = sizes.type_as(torch.ones([1], dtype=torch.int))
-    data_sizes_ = torch.cat([torch.ones([1], dtype=torch.int) * batch_size, sizes.narrow(0, 1, sizes.size(0) - 1)], 0)
-    data_sizes = data_sizes_._tensor_to_list()
-    res_data = data.view(data_sizes)
-    mask_sizes_ = data_sizes_.narrow(0, 0, 1)
-    res_dims = data_sizes_.narrow(0, 0, 1)
-    for i_ in range(sizes.size(0) - 1):
-        i = i_ + 1
-        if bool(sizes[i] == -1):
-            cur_size_ = mask.size(i)
-            cur_dim = 1
-        else:
-            cur_size_ = 1
-            cur_dim = 0
-        mask_sizes_ = torch.cat([mask_sizes_, torch.ones([1], dtype=torch.int) * cur_size_])
-        res_dims = torch.cat([res_dims, torch.ones([1], dtype=torch.int) * cur_dim])
-    mask_sizes = mask_sizes_._tensor_to_list()
-    res_mask = mask.view(mask_sizes)
-    return res_data, res_mask, res_dims.narrow(0, 1, res_dims.size(0) - 1).type_as(dims)
-
-
-@torch.jit.script
-def batch_cat2(data1, mask1, dims1, data2, mask2, dims2, dim_):
-    dim = int(dim_)
-    data = torch.cat([data1, data2], dim)
-    if bool(dims1[dim - 1]):
-        mask = torch.cat([mask1, mask2], dim)
-    else:
-        mask = mask1
-    return data, mask, dims1
-
-
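batch_softmax above is the subtlest of these loops: the softmax must run over only the valid prefix of a dynamic dim, with the padded tail passed through untouched. A one-row eager sketch of that guard (the helper masked_softmax_1d is illustrative, not part of the deleted file):

    import torch

    def masked_softmax_1d(row, valid_num, dim=0):
        # Softmax the first valid_num entries; keep the padded remainder
        # as-is, mirroring the narrow/cat dance in batch_softmax above.
        head = row.narrow(dim, 0, valid_num).softmax(dim)
        tail = row.narrow(dim, valid_num, row.size(dim) - valid_num)
        return torch.cat([head, tail], dim)

    row = torch.tensor([1.0, 2.0, 3.0, 99.0])  # last entry is padding
    out = masked_softmax_1d(row, 3)
    # Only the valid prefix sums to 1; the padding value never enters the softmax.
    assert torch.isclose(out[:3].sum(), torch.tensor(1.0)) and out[3] == 99.0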
-@torch.jit.script
-def batch_cat3(data1, mask1, dims1, data2, mask2, dims2, data3, mask3, dims3, dim_):
-    dim = int(dim_)
-    data = torch.cat([data1, data2, data3], dim)
-    if bool(dims1[dim - 1]):
-        mask = torch.cat([mask1, mask2, mask3], dim)
-    else:
-        mask = mask1
-    return data, mask, dims1
-
-
-@torch.jit.script
-def batch_narrow(data, mask, dims, dimension_, start_, length_):
-    dimension = int(dimension_)
-    start = int(start_)
-    length = int(length_)
-    # if dimension == 0:
-    #     raise ValueError("cannot do narrow along batch_dim")
-    data = data.narrow(dimension, start, length)
-    if bool(dims[dimension - 1]):
-        mask = mask.narrow(dimension, start, length)
-    else:
-        mask = mask.narrow(dimension, 0, 1)
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_sum(data, mask, dims):
-    data = data * mask.type_as(data)
-    for _ in range(dims.size(0)):
-        data = data.sum(1)
-    mask = torch.ones([data.size(0)], dtype=torch.uint8)
-    dims = dims[:0]  # empty tensor
-    return data, mask, dims
-
-
-@torch.jit.script
-def batch_from_scalar_tensor(data):
-    data = data.unsqueeze(0)
-    mask = torch.ones([1], dtype=torch.uint8)
-    dims = torch.zeros([0], dtype=torch.uint8)
-    return data, mask, dims
-
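Finally, batch_sum shows why masking precedes reduction: padding must not enter an example's total, and the result has no dynamic dims left (dims[:0]). An eager sketch under the same illustrative assumptions as the earlier examples:

    import torch

    def eager_batch_sum(data, mask, dims):
        # Zero padding first so it cannot contribute to any example's total.
        data = data * mask.type_as(data)
        for _ in range(dims.size(0)):
            data = data.sum(1)       # fold every non-batch dim into the batch dim
        mask = torch.ones([data.size(0)], dtype=torch.uint8)
        return data, mask, dims[:0]  # per-example scalars, empty dims

    data = torch.tensor([[1.0, 2.0, 7.0]])
    mask = torch.tensor([[1, 1, 0]], dtype=torch.uint8)  # 7.0 is padding
    dims = torch.tensor([1], dtype=torch.uint8)
    total, _, _ = eager_batch_sum(data, mask, dims)
    assert total.tolist() == [3.0]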
-torch.register_batch_operator("narrow", batch_narrow.graph) -torch.register_batch_operator("sum", batch_sum.graph) -torch.register_batch_operator("batch_from_scalar_tensor", batch_from_scalar_tensor.graph) -- 2.7.4